import os
import sys
import glob
sys.path.append('three_regimes_on_the_sphere')


import torch
import numpy as np
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100
from tqdm.auto import tqdm

import nets
from fourier_slicer import FourierSlicer


# --- Experiment configuration: fill these in before running. ---
# lrs: list of PLR values to evaluate; each must be a key of ``ckpts_by_lr``.
lrs = None
# ckpts_by_lr: dict mapping each PLR in ``lrs`` to a list of checkpoint paths.
ckpts_by_lr = None
# save_path: destination of the accuracy array written with ``np.save``.
save_path = None

# Explicit raises instead of ``assert`` so the checks survive ``python -O``.
# save_path is checked here too, so a missing path fails before evaluation
# starts rather than after all checkpoints have been processed.
if lrs is None:
    raise ValueError('specify a list of PLRs')
if ckpts_by_lr is None:
    raise ValueError('specify a dict of checkpoints: PLR -> list of checkpoints')
if save_path is None:
    raise ValueError('specify a path to save the test accuracies to')

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')


# A single network instance is built once and moved to ``device``;
# ``load_model`` swaps checkpoint weights into this same instance.
model = nets.resnet_si.make_resnet18k(k=32, num_classes=10).to(device)


def load_model(path):
    """Load the ``"state_dict"`` entry of checkpoint *path* into ``model``.

    The checkpoint is deserialized onto the CPU first (``map_location``), so
    checkpoints saved on a GPU can be read regardless of the current device.
    Returns ``None``; the global ``model`` is updated in place.
    """
    cpu = torch.device('cpu')
    checkpoint = torch.load(path, map_location=cpu)
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


# Split ResNet18SI's test-time pipeline: the dataset applies the first two
# transforms, while the third (bound to ``normalize``) is kept separate so it
# can be applied explicitly to images right before they reach the model.
_eval_transforms = nets.resnet_si.ResNet18SI.transform_test.transforms
test_set = CIFAR10(
    '~/datasets/cifar10/',
    train=False,
    download=True,
    transform=T.Compose(_eval_transforms[:2]),
)
normalize = _eval_transforms[2]

# shuffle=False keeps batches in dataset order, so predictions concatenated
# across batches line up index-for-index with ``test_set.targets``.
test_loader = DataLoader(
    test_set,
    batch_size=2048,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# NOTE(review): the first argument (32) presumably is the image side length
# (CIFAR-10 images are 32x32) and each (lo, hi) pair a frequency band —
# confirm against FourierSlicer.
_freq_blocks = [(0, 1), (1, 9), (9, 25), (25, 33)]
slicer = FourierSlicer(32, blocks=_freq_blocks, pres_low_freq=0, mask_mode=False)


def _predict_views(images):
    """Return ``model``'s argmax predictions on every view of one batch.

    Row 0 holds predictions on the normalized input images; the following rows
    correspond, in order, to the tensors yielded by ``slicer(images)``.
    The result is a CPU tensor of shape ``(1 + number_of_views, batch_size)``.
    """
    rows = [model(normalize(images)).argmax(dim=-1).cpu()]
    for j, rec_images in enumerate(slicer(images)):
        # Which slicer outputs still need normalizing depends on mask_mode.
        # NOTE(review): presumably the other outputs are already in the
        # model's normalized input space — confirm against FourierSlicer.
        if (j == 0 and not slicer.mask_mode) or (j > 0 and slicer.mask_mode):
            rec_images = normalize(rec_images)
        rows.append(model(rec_images).argmax(dim=-1).cpu())
    return torch.stack(rows, dim=0)


def _evaluate_checkpoint(path, targets):
    """Load checkpoint *path* and return its per-view test accuracies.

    ``targets`` holds the test labels as a ``(1, N)`` tensor in loader order.
    Returns a NumPy array with one accuracy per row of ``_predict_views``.
    """
    load_model(path)
    model.eval()
    batch_preds = []
    # One no_grad scope covers both the slicer and all forward passes.
    with torch.no_grad():
        for images, _ in test_loader:
            # non_blocking pairs with the loader's pin_memory=True.
            images = images.to(device, non_blocking=True)
            batch_preds.append(_predict_views(images))
    all_preds = torch.cat(batch_preds, dim=1)
    return (all_preds == targets).to(torch.float).mean(dim=1).numpy()


# Labels in dataset order, built once and shaped (1, N) so they broadcast
# against the (n_views, N) prediction matrix.
targets = torch.tensor(test_set.targets).reshape(1, -1)

# Assumes slicer(images) yields one tensor per entry of slicer.blocks,
# plus the prediction on the unsliced images.
n_views = len(slicer.blocks) + 1
# Size by the longest checkpoint list so PLRs with more checkpoints than the
# first one cannot overflow the array; missing entries stay NaN instead of
# masquerading as a real 0% accuracy.
max_ckpts = max((len(ckpts_by_lr[lr]) for lr in lrs), default=0)
test_accs = np.full((len(lrs), max_ckpts, n_views), np.nan)
for i, lr in enumerate(tqdm(lrs)):
    for k, ckpt in enumerate(ckpts_by_lr[lr]):
        test_accs[i, k] = _evaluate_checkpoint(ckpt, targets)

np.save(save_path, test_accs)
