# Adapting Preconditioning Matrix with Amortized Proximal Optimization

A clean PyTorch implementation. Note that the PyTorch `DataLoader` adds computational overhead compared with the JAX implementation.

## Example

### AlexNet on CIFAR-10

The code is located at `experiments/cifar`.

**CIFAR-10 Baseline**
```bash
python train.py \
    --lr 0.01 \
    --optimizer sgdm \
    --architecture alexnet \
    --wd 0.0005 \
    --data_name cifar10
```

**CIFAR-10 APO**
```bash
python train.py \
    --lr 0.01 \
    --meta_lr 0.0001 \
    --precond_lr 0.9 \
    --lamb_wsp 0.1 \
    --lamb_fsp 0.3 \
    --optimizer sgdm \
    --apo_precond 1 \
    --architecture alexnet \
    --wd 0.0005 \
    --data_name cifar10
```
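The APO run uses every baseline flag plus five APO-specific ones (`--meta_lr`, `--precond_lr`, `--lamb_wsp`, `--lamb_fsp`, `--apo_precond`). If you want to launch both runs from a script (for example, as the starting point of a hyperparameter sweep), the following is a minimal sketch that builds the two argument lists from the values above. The helper `build_cmd` is hypothetical and not part of this repo, and the sketch assumes `train.py` parses flags in an order-independent way (as `argparse` does), since the APO flags are appended after the baseline ones rather than in the exact order shown.

```python
import shlex

# Flags shared by both runs, copied from the baseline command.
BASELINE_FLAGS = {
    "lr": "0.01",
    "optimizer": "sgdm",
    "architecture": "alexnet",
    "wd": "0.0005",
    "data_name": "cifar10",
}

# Extra flags used only by the APO run, copied from the APO command.
APO_FLAGS = {
    "meta_lr": "0.0001",
    "precond_lr": "0.9",
    "lamb_wsp": "0.1",
    "lamb_fsp": "0.3",
    "apo_precond": "1",
}


def build_cmd(use_apo: bool) -> list:
    """Return an argv list for train.py (hypothetical helper, not part of the repo)."""
    flags = dict(BASELINE_FLAGS)
    if use_apo:
        # Assumption: flag order does not matter to train.py's argument parser.
        flags.update(APO_FLAGS)
    argv = ["python", "train.py"]
    for name, value in flags.items():
        argv += [f"--{name}", value]
    return argv


if __name__ == "__main__":
    for use_apo in (False, True):
        # shlex.join quotes each argument so the printed line is shell-safe.
        print(shlex.join(build_cmd(use_apo)))
```

To actually launch a run from `experiments/cifar`, pass the list to `subprocess.run(build_cmd(True), check=True)`; passing a list instead of a single string avoids shell quoting issues.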
