from apo_precond.layers import *
from apo_precond.optimizer import *
import functools
from experiments.optimizers.kfac import KFACOptimizer
from experiments.cifar.utils import *
from experiments.cifar.load_data import *
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from experiments.cifar.models import *
import argparse
import wandb
import math
import torch
import tqdm

def _str2bool(value):
    """Convert a command-line token such as ``"true"`` / ``"0"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser()
parser.add_argument("--experiment_name", type=str, default="apo_cifar10")

# Configuration for data loading
parser.add_argument("--data_name", type=str, default="cifar10")
parser.add_argument("--val_data_size", type=int, default=0)

# Configuration for architecture
parser.add_argument("--architecture", type=str, default="alexnet")

# Optimization & Regularization hyperparameters
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--optimizer", type=str, default="sgdm")
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--lr_decay_rate", type=float, default=0.5)
# One or more integer epochs, e.g. ``--lr_decay_schedule 60 120 160``.
# (``type=list`` would split a single token into its characters.)
parser.add_argument("--lr_decay_schedule", type=int, nargs="+", default=[60, 120, 160])
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--wd", type=float, default=0.0005)

# Parameters for K-FAC
parser.add_argument("--damping", type=float, default=1e-3)
parser.add_argument("--t_cov", default=10, type=int)
parser.add_argument("--t_inv", default=100, type=int)

# Parameters for APO
parser.add_argument("--warmup_step", type=int, default=3000)
parser.add_argument("--apo_precond", type=int, default=0)
parser.add_argument("--precond_lr", type=float, default=0.9)
parser.add_argument("--meta_lr", type=float, default=1e-3)
parser.add_argument("--meta_step", type=int, default=10)
parser.add_argument("--lamb_wsp", type=float, default=1.)
parser.add_argument("--lamb_fsp", type=float, default=1.)
parser.add_argument("--fsp_batch_size", type=int, default=128)

# Accepts ``--no_cuda``, ``--no_cuda True`` and ``--no_cuda False``.
parser.add_argument("--no_cuda", type=_str2bool, nargs="?", const=True, default=False)
parser.add_argument("--data_seed", type=int, default=0)
parser.add_argument("--model_seed", type=int, default=0)

parser.add_argument("--checkpoint_dir", type=str, default=None)
parser.add_argument("--save_freq", type=int, default=25)
parser.add_argument("--save_dir", type=str, default=None)

args = parser.parse_args()

# Use the GPU only when one is available and it was not disabled on the CLI.
cuda = torch.cuda.is_available() and not args.no_cuda
device = torch.device("cuda" if cuda else "cpu")
cudnn.benchmark = True

# Record the full parsed configuration with the wandb run.
wandb.init(project=args.experiment_name, config=vars(args))


def adjust_learning_rate(optimizer, epoch, lr, precond_lr, lr_decay_schedule, lr_decay_rate):
    """Write the step-decayed learning rate for ``epoch`` into every param group.

    The base ``lr`` is multiplied by ``lr_decay_rate`` once for each entry of
    ``lr_decay_schedule`` that ``epoch`` has reached. When the optimizer's class
    string contains ``"apo"``, ``precond_lr`` is decayed the same way and stored
    under the ``"precond_lr"`` key of each group as well.
    """
    is_apo = "apo" in str(optimizer.__class__)

    for milestone in lr_decay_schedule:
        # Milestones not reached yet contribute a neutral factor of 1.
        factor = lr_decay_rate if epoch >= milestone else 1.
        lr *= factor
        if is_apo:
            precond_lr *= factor

    for group in optimizer.param_groups:
        group["lr"] = lr
        if is_apo:
            group["precond_lr"] = precond_lr


def get_learning_rate(optimizer):
    """Return one learning-rate entry per parameter group of ``optimizer``.

    For an optimizer whose class string contains ``"apo"`` each entry is the
    result of ``optimizer.get_learning_rate()``; otherwise it is the group's
    ``"lr"`` value.
    """
    is_apo = "apo" in str(optimizer.__class__)
    return [
        optimizer.get_learning_rate() if is_apo else group["lr"]
        for group in optimizer.param_groups
    ]


def load_optimizer(optimizer_name, lr, wd):
    """Return a partially-applied ``torch.optim`` constructor for ``optimizer_name``.

    The returned ``functools.partial`` already carries the learning rate,
    weight decay and the optimizer-specific settings; call it with the
    parameters to optimize. Supported names: ``"sgd"``, ``"sgdm"``, ``"adam"``
    and ``"rmsprop"``.

    Raises:
        KeyError: If ``optimizer_name`` is not one of the supported names.
    """
    # name -> (optimizer class, settings specific to that variant)
    recipes = {
        "sgd": (torch.optim.SGD, {"momentum": 0.}),
        "sgdm": (torch.optim.SGD, {"momentum": 0.9}),
        "adam": (torch.optim.Adam, {"betas": (0.9, 0.999)}),
        "rmsprop": (torch.optim.RMSprop, {"momentum": 0.9}),
    }
    optim_cls, extra = recipes[optimizer_name]
    return functools.partial(optim_cls, lr=lr, weight_decay=wd, **extra)


def load_kfac_optimizer(lr, wd, damping=0., t_cov=1, t_inv=10):
    """Return a ``functools.partial`` of ``KFACOptimizer`` with its settings bound.

    ``t_cov`` and ``t_inv`` are forwarded as the ``TCov`` and ``TInv`` keyword
    arguments; momentum is fixed at 0.9. Call the result with the remaining
    positional argument(s) the optimizer expects.
    """
    kfac_kwargs = dict(
        lr=lr,
        momentum=0.9,
        weight_decay=wd,
        damping=damping,
        TCov=t_cov,
        TInv=t_inv,
    )
    return functools.partial(KFACOptimizer, **kfac_kwargs)


def evaluate(epoch, model, loader, name="valid"):
    """Evaluate ``model`` on ``loader`` and log the aggregated metrics to wandb.

    Args:
        epoch: Epoch index; shown in the progress bar and stored in the log record.
        model: Network to evaluate. It is switched to eval mode here.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        name: Prefix of the logged metric keys (e.g. ``"valid"`` or ``"test"``).

    Returns:
        The running average held by the ``top1`` meter.
    """
    top1 = AverageMeter("top1")
    top5 = AverageMeter("top5")
    losses = AverageMeter("losses")
    n_correct, n_seen = 0, 0

    # Per-sample labels and softmax vectors, collected for the calibration metrics.
    all_targets, all_confidences = [], []

    model.eval()
    with torch.no_grad():
        progress = tqdm.tqdm(loader)
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)

            all_targets.extend(targets.cpu().numpy().tolist())

            outputs = model(inputs)
            probabilities = F.softmax(outputs, dim=1).cpu().numpy()
            all_confidences.extend(probabilities.tolist())

            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            batch_loss = F.cross_entropy(outputs, targets)
            losses.update(batch_loss.item(), batch_size)

            # Values returned by the project's ``accuracy`` helper for k=1 and k=5.
            err1, err5 = accuracy(outputs.data, targets, topk=(1, 5))
            top1.update(err1.item(), batch_size)
            top5.update(err5.item(), batch_size)

            description = (
                "Epoch {:d} {} evaluate | loss = {:0.6f}, "
                "top1_acc = {:0.4f}, top5_acc = {:0.4f}, correct/total: {}/{}"
            ).format(epoch, name, losses.avg, top1.avg, top5.avg, n_correct, n_seen)
            progress.set_description(description)

    ece, aurc, eaurc = metric_ece_aurc_eaurc(all_confidences, all_targets, bin_size=0.1)

    wandb.log({
        "epoch": epoch,
        "{}/loss".format(name): losses.avg,
        "{}/top1".format(name): top1.avg,
        "{}/top5".format(name): top5.avg,
        "{}/ece".format(name): ece,
        "{}/aurc".format(name): aurc,
        "{}/eaurc".format(name): eaurc,
        "{}/correct".format(name): n_correct,
        "{}/total".format(name): n_seen,
    })

    return top1.avg


def train(epoch, model, loader, fsp_sampler, optimizer):
    """Train ``model`` for one pass over ``loader`` and log the epoch summary.

    Args:
        epoch: Epoch index; shown in the progress bar and stored in the log record.
        model: Network to train. It is switched to train mode here.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        fsp_sampler: Object whose ``next()`` returns an ``(inputs, _)`` batch; only
            used on APO meta steps.
        optimizer: Optimizer to step. With ``args.apo_precond`` set it must expose
            ``global_step`` and ``meta_step_eval``; with ``args.optimizer == "kfac"``
            it must expose ``steps``, ``TCov`` and ``acc_stats``.

    Raises:
        ValueError: If the average training loss of the epoch is NaN.
    """
    train_top1 = AverageMeter("top1")
    train_top5 = AverageMeter("top5")
    train_losses = AverageMeter("losses")
    train_meta_losses = AverageMeter("meta_losses")
    train_wsp = AverageMeter("wsp")
    train_fsp = AverageMeter("fsp")
    correct = 0
    total = 0

    model.train()
    # Read once: the learning rate is only adjusted between epochs.
    current_lr = get_learning_rate(optimizer)[0]

    p_bar = tqdm.tqdm(loader)
    for inputs, targets in p_bar:
        inputs, targets = inputs.to(device), targets.to(device)

        # A single forward pass and loss serve both the optional meta step and
        # the regular parameter update below.
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, targets)

        # APO meta step: on every step up to and including ``warmup_step``,
        # afterwards every ``meta_step`` steps.
        if args.apo_precond and (optimizer.global_step <= args.warmup_step
                                 or optimizer.global_step % args.meta_step == 0):
            # retain_graph keeps the graph alive for ``loss.backward()`` below.
            grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
            fsp_inputs, _ = fsp_sampler.next()
            fsp_outputs = model(fsp_inputs)
            apo_res = optimizer.meta_step_eval((inputs,), targets, F.cross_entropy, grads=grads,
                                               fsp_outputs=fsp_outputs, fsp_inputs=(fsp_inputs,))
            train_meta_losses.update(apo_res["meta_loss"], inputs.size(0))
            train_wsp.update(apo_res["prox_wsp"], inputs.size(0))
            train_fsp.update(apo_res["prox_fsp"], inputs.size(0))

        if args.optimizer in ["kfac"] and optimizer.steps % optimizer.TCov == 0:
            # Backpropagate a loss on labels sampled from the model's own
            # predictions while ``acc_stats`` is switched on.
            optimizer.acc_stats = True
            with torch.no_grad():
                probabilities = F.softmax(outputs.cpu().data, dim=1)
                # squeeze(1) keeps the batch dimension even for a batch of one,
                # and .to(device) keeps CPU-only runs working (was ``.cuda()``).
                sampled_y = torch.multinomial(probabilities, 1).squeeze(1).to(device)
            loss_sample = F.cross_entropy(outputs, sampled_y)
            loss_sample.backward(retain_graph=True)
            optimizer.acc_stats = False

        train_losses.update(loss.item(), inputs.size(0))
        err1, err5 = accuracy(outputs.data, targets, topk=(1, 5))
        train_top1.update(err1.item(), inputs.size(0))
        train_top5.update(err5.item(), inputs.size(0))

        # Clear gradients (including those of the sampled-label backward pass)
        # before the real update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs, 1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        p_bar.set_description(
            "Epoch {:d} train | lr = {:.1e}, loss = {:0.6f}, meta_loss = {:0.6f}, "
            "wsp = {:0.6f}, fsp = {:0.6f} top1_acc = {:0.4f}, top5_acc = {:0.4f}, correct/total: {}/{}".format(
                epoch,
                current_lr,
                train_losses.avg,
                train_meta_losses.avg,
                train_wsp.avg,
                train_fsp.avg,
                train_top1.avg,
                train_top5.avg,
                correct, total
            )
        )

    log_dict = {
        "epoch": epoch,
        "lr": current_lr,
        "train/loss": train_losses.avg,
        "train/meta_loss": train_meta_losses.avg,
        "train/top1": train_top1.avg,
        "train/top5": train_top5.avg,
        "train/wsp": train_wsp.avg,
        "train/fsp": train_fsp.avg,
        "train/correct": correct,
        "train/total": total
    }
    wandb.log(log_dict)

    # Abort once the epoch has been logged if training has diverged.
    if math.isnan(train_losses.avg):
        raise ValueError("The Loss is NaN. Exiting the training...")


def train_network(model, loader, val_loader, test_loader, fsp_sampler, optimizer, epochs, save_freq=100):
    """Drive the full training run: adjust the LR, train, then evaluate each epoch.

    The loop covers epoch indices ``0`` through ``epochs`` inclusive, i.e.
    ``epochs + 1`` passes. Validation is skipped when ``val_loader`` is
    ``None``; the test set is evaluated after every epoch. ``save_freq`` is
    accepted but not used in this function.
    """
    for epoch in range(epochs + 1):
        adjust_learning_rate(
            optimizer,
            epoch,
            args.lr,
            args.precond_lr,
            args.lr_decay_schedule,
            args.lr_decay_rate,
        )
        train(epoch, model, loader, fsp_sampler, optimizer)

        if val_loader is not None:
            evaluate(epoch, model, val_loader, name="valid")
        evaluate(epoch, model, test_loader, name="test")


def main():
    """Build the data loaders, model and optimizer from ``args`` and start training.

    Raises:
        Exception: If ``args.data_name`` is neither ``"cifar10"`` nor ``"cifar100"``.
        NotImplementedError: If the architecture or optimizer name is not supported.
    """
    # Arguments shared by both loader builds; only the batch size differs.
    shared_data_kwargs = dict(
        dataset_name=args.data_name,
        val_data_size=args.val_data_size,
        data_augment=True,
    )

    seed_everything(args.data_seed)
    train_loader, val_loader, test_loader = load_data(batch_size=args.batch_size, **shared_data_kwargs)

    # Re-seed with the same data seed before building the loader for the FSP sampler.
    seed_everything(args.data_seed)
    fsp_train_loader, _, _ = load_data(batch_size=args.fsp_batch_size, **shared_data_kwargs)

    class_counts = {"cifar10": 10, "cifar100": 100}
    if args.data_name not in class_counts:
        raise Exception("Invalid dataset received")
    num_classes = class_counts[args.data_name]

    seed_everything(args.model_seed)
    # Constructors are wrapped in lambdas so only the selected one is looked up and run.
    model_builders = {
        "lenet": lambda: LeNet(num_classes=num_classes),
        "alexnet": lambda: alexnet(num_classes=num_classes),
        "vgg16": lambda: vgg16(num_classes=num_classes),
        "resnet18": lambda: resnet18(num_classes=num_classes),
    }
    if args.architecture not in model_builders:
        raise NotImplementedError("Architecture {} not implemented".format(args.architecture))
    model = model_builders[args.architecture]()

    if args.optimizer not in ["kfac", "ekfac"]:
        replace_layers(model)
    model = model.to(device)

    if args.optimizer in ["kfac"]:
        make_optimizer = load_kfac_optimizer(lr=args.lr, wd=args.wd, damping=args.damping,
                                             t_cov=args.t_cov, t_inv=args.t_inv)
        optimizer = make_optimizer(model)
    elif args.optimizer in ["sgd", "sgdm", "rmsprop", "adam"]:
        make_optimizer = load_optimizer(args.optimizer, lr=args.lr, wd=args.wd)
        optimizer = make_optimizer(model.parameters())
    else:
        raise NotImplementedError("Optimizer {} not implemented".format(args.optimizer))

    # With APO enabled, the optimizer built above is replaced by the APO wrapper,
    # which receives the base optimizer's name.
    if args.apo_precond:
        optimizer = ApoPrecondOptimizer(model, lr=args.lr, weight_decay=args.wd,
                                        meta_lr=args.meta_lr, warmup_step=args.warmup_step,
                                        lamb_wsp=args.lamb_wsp, lamb_fsp=args.lamb_fsp,
                                        initial_optimizer=args.optimizer)

    fsp_sampler = DataSampler(fsp_train_loader, device)
    train_network(model, train_loader, val_loader, test_loader, fsp_sampler, optimizer,
                  args.epochs, args.save_freq)


# Start the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
