import argparse
import logging
import os
import time

import torch
import yaml

import models
import tools


logger = logging.getLogger(__name__)


def get_args(argv=None):
    """Parse the command-line options for GloRo training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of the command-line entry point.

    Returns:
        argparse.Namespace holding the parsed options. ``--depth``,
        ``--width`` and ``--epochs`` are ``None`` unless given, meaning
        "use the value from the config file".

    Raises:
        SystemExit: raised by argparse when the arguments are invalid or
            the required ``--config`` option is missing.
    """
    # Pass the title as ``description``: the first positional argument of
    # ArgumentParser is ``prog``, which would replace the program name in
    # the usage line.
    parser = argparse.ArgumentParser(
        description='Training Globally-Robust Neural Networks')

    # main() opens this path unconditionally, so it has to be supplied.
    parser.add_argument('--config',
                        type=str,
                        required=True,
                        help='path to the config yaml file')
    # Optional overrides; a bare flag (no value) also yields None (const=None).
    parser.add_argument('--depth', type=int, const=None, nargs='?', help='override for model depth')
    parser.add_argument('--width', type=int, const=None, nargs='?', help='override for model width')
    parser.add_argument('--epochs', type=int, const=None, nargs='?', help='override for training epochs')
    # checkpoint saving
    parser.add_argument('--work_dir', default='./checkpoint/', type=str)
    parser.add_argument('--ckpt_prefix', default='', type=str)
    parser.add_argument('--max_save', default=3, type=int)
    # distributed training (NOTE(review): not read by main() in this file)
    parser.add_argument('--launcher',
                        default='slurm',
                        type=str,
                        help='should be either `slurm` or `pytorch`')
    parser.add_argument('--local_rank', type=int, default=0)

    # Auxiliary specifications
    parser.add_argument('--auxiliary-dir', default=None, type=str)
    parser.add_argument('--auxiliary', default=None, type=str)
    parser.add_argument('--fraction', default=0.7, type=float)

    return parser.parse_args(argv)


def _percent(numerator, denominator):
    """Return ``100 * numerator / denominator``, or 0.0 when ``denominator`` is 0."""
    if not denominator:
        return 0.0
    return 100.0 * numerator / denominator


def _train_one_epoch(model, loader, optimizer, scheduler, train_fn, eps,
                     use_lln, grad_clip_val):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network being trained (moved to CUDA by the caller).
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: torch optimizer over ``model``'s parameters.
        scheduler: project LR scheduler; ``scheduler.step(optimizer)`` is
            called once per batch before the backward pass and
            ``optimizer.step()`` (presumably it sets this iteration's
            learning rate -- confirm in ``tools.lr_scheduler``).
        train_fn: ``models.margin_layer`` or ``models.margin_layer_v2``;
            called with ``return_loss=True`` and expected to return
            ``(y, y_, loss)``.
        eps: robustness radius passed to ``train_fn`` for this epoch.
        use_lln: forwarded to ``train_fn``.
        grad_clip_val: max gradient norm, or ``None`` to disable clipping.

    Returns:
        ``(correct, correct_vra, total)``: samples whose argmax of ``y`` /
        ``y_`` matches the label (clean / "Robust" columns of the log), and
        the number of samples seen.
    """
    model.train()
    correct = correct_vra = total = 0
    for inputs, targets in loader:
        optimizer.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        y, y_, loss = train_fn(model,
                               x=inputs,
                               label=targets,
                               eps=eps,
                               use_lln=use_lln,
                               return_loss=True)

        scheduler.step(optimizer)
        loss.backward()
        if grad_clip_val is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_val)
        optimizer.step()

        correct += y.argmax(1).eq(targets).sum().item()
        correct_vra += y_.argmax(1).eq(targets).sum().item()
        total += targets.size(0)
    return correct, correct_vra, total


def _evaluate(model, loader, eps, use_lln, sub_lipschitz):
    """Count clean and robust correct predictions on ``loader`` without gradients.

    ``sub_lipschitz`` is passed to ``models.margin_layer`` as ``subL`` so the
    (expensive) bound is computed once per epoch by the caller rather than
    per batch.

    Returns:
        ``(correct, correct_vra, total)`` with the same meaning as in
        ``_train_one_epoch``.
    """
    model.eval()
    correct = correct_vra = total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
            y, y_, _ = models.margin_layer(model,
                                           x=inputs,
                                           label=targets,
                                           eps=eps,
                                           use_lln=use_lln,
                                           subL=sub_lipschitz,
                                           return_loss=False)
            correct += y.argmax(1).eq(targets).sum().item()
            correct_vra += y_.argmax(1).eq(targets).sum().item()
            total += targets.size(0)
    return correct, correct_vra, total


def _lipschitz_per_module(model):
    """Map each named submodule (except the root and ``head``) to its ``lipschitz()`` value.

    Tensor results are converted to Python numbers. Submodules that have no
    callable ``lipschitz`` attribute are skipped instead of raising, so a
    missing method cannot abort the final checkpoint.
    """
    lc_dict = {}
    for name, module in model.named_modules():
        if not name or name == 'head':
            continue
        lipschitz = getattr(module, 'lipschitz', None)
        if not callable(lipschitz):
            continue
        lc = lipschitz()
        if isinstance(lc, torch.Tensor):
            lc = lc.item()
        lc_dict[name] = lc
    return lc_dict


def _save_checkpoint(state, work_dir, prefix, epoch, max_save):
    """Write ``state`` as ``<work_dir>/<prefix>_<epoch>.pth`` and drop the one ``max_save`` epochs older.

    Saving and deleting are best effort: failures are reported and training
    continues.
    """
    path = os.path.join(work_dir, f'{prefix}_{epoch}.pth')
    try:
        torch.save(state, path)
    except PermissionError as err:
        print(f'Error saving checkpoint! ({err})')
    if epoch >= max_save:
        stale = os.path.join(work_dir, f'{prefix}_{epoch - max_save}.pth')
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass  # already gone (e.g. its save failed) -- nothing to do
        except OSError as err:
            print(f'Error removing old checkpoint {stale}: {err}')


def main():
    """Train a GloroNet from a YAML config, logging metrics and checkpointing.

    Reads the ``model``, ``training``, ``dataset`` and ``gloro`` sections of
    the file given by ``--config`` and applies the CLI overrides. It then
    trains for ``training.epochs`` epochs. After every epoch it appends a
    tab-separated metrics row to ``<work_dir>/<ckpt_prefix>.log`` and writes
    ``<work_dir>/<ckpt_prefix>_<epoch>.pth``, keeping at most ``--max_save``
    checkpoint files. The last checkpoint additionally stores per-module
    Lipschitz values under ``lc_dict``.
    """
    args = get_args()

    with open(args.config, 'r') as f:
        cfg_text = f.read()
    # NOTE(review): yaml.Loader can construct arbitrary Python objects from
    # tags; only load trusted config files (prefer yaml.safe_load if the
    # configs need no custom tags).
    cfg = yaml.load(cfg_text, Loader=yaml.Loader)

    model_cfg = cfg['model']
    train_cfg = cfg['training']
    dataset_cfg = cfg['dataset']
    gloro_cfg = cfg['gloro']

    # Process config overrides; only an absent value (None) is ignored.
    if args.depth is not None:
        model_cfg['depth'] = args.depth
    if args.width is not None:
        model_cfg['width'] = args.width
    if args.epochs is not None:
        train_cfg['epochs'] = args.epochs

    if args.ckpt_prefix == '':
        depth, width = model_cfg['depth'], model_cfg['width']
        args.ckpt_prefix = (
            f"{model_cfg['arch']}-{depth}x{width}_{dataset_cfg['name']}"
            f"-{args.auxiliary}-{int(100 * args.fraction):d}"
            f"-e{train_cfg['epochs']}")

    # Echo the raw config file so it is captured in the run's stdout.
    print(cfg_text)

    print(f'Use checkpoint prefix: {args.ckpt_prefix}')

    if args.auxiliary_dir and args.auxiliary and args.auxiliary != 'None':
        aux_path = os.path.join(args.auxiliary_dir, f'{args.auxiliary}.npz')
    else:
        aux_path = None
    train_loader, _, val_loader, _ = tools.data_loader(
        data_name=dataset_cfg['name'],
        batch_size=train_cfg['batch_size'],
        num_classes=dataset_cfg['num_classes'],
        auxiliary=aux_path,
        fraction=args.fraction,
        seed=dataset_cfg.get('seed', 2023))  # seed defaults to 2023 if not configured

    model = models.GloroNet(**model_cfg, **dataset_cfg)
    print(model)
    model = model.cuda()

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=train_cfg['lr'],
                                  weight_decay=train_cfg['weight_decay'])
    scheduler = tools.lr_scheduler(iter_per_epoch=len(train_loader),
                                   max_epoch=train_cfg['epochs'],
                                   warmup_epoch=train_cfg['warmup_epochs'])

    def eps_fn(epoch):
        # Training radius: gloro.eps times a factor that ramps linearly from
        # min_eps to max_eps over the first half of training, then stays at
        # max_eps.
        ratio = min(epoch / train_cfg['epochs'] * 2, 1)
        ratio = gloro_cfg['min_eps'] + (gloro_cfg['max_eps'] -
                                        gloro_cfg['min_eps']) * ratio
        return gloro_cfg['eps'] * ratio

    os.makedirs(args.work_dir, exist_ok=True)

    if gloro_cfg['trades_loss']:
        train_fn = models.margin_layer
    else:
        train_fn = models.margin_layer_v2
    # grad_clip_val is only looked up when clipping is enabled.
    grad_clip_val = train_cfg['grad_clip_val'] if train_cfg['grad_clip'] else None

    print('Begin Training')
    logfile = os.path.join(args.work_dir, f'{args.ckpt_prefix}.log')
    if os.path.exists(logfile):
        os.remove(logfile)
    logging.basicConfig(format='%(message)s', level=logging.INFO, filename=logfile)
    logger.info('Epoch \t Train Acc \t Train Robust \t Test Acc \t Test Robust \t Sub Lipschitz \t Time')

    training_logs = []
    t = time.time()
    for epoch in range(train_cfg['epochs']):
        eps = eps_fn(epoch)
        model.set_num_lc_iter(model_cfg['num_lc_iter'])
        correct, correct_vra, total = _train_one_epoch(
            model, train_loader, optimizer, scheduler, train_fn, eps,
            use_lln=model_cfg['use_lln'], grad_clip_val=grad_clip_val)

        model.set_num_lc_iter(500)  # let the power method converge
        # Compute sub_lipschitz once here and reuse it for every validation batch.
        sub_lipschitz = model.sub_lipschitz().item()

        val_correct, val_correct_vra, val_total = _evaluate(
            model, val_loader, eps=gloro_cfg['eps'],
            use_lln=model_cfg['use_lln'], sub_lipschitz=sub_lipschitz)

        acc_train = _percent(correct, total)
        acc_vra_train = _percent(correct_vra, total)
        acc_val = _percent(val_correct, val_total)
        acc_vra_val = _percent(val_correct_vra, val_total)

        now = time.time()
        used = now - t
        t = now

        logger.info('%d \t %.2f \t %.2f \t %.2f \t %.2f \t %.2f \t %.2f',
                    epoch, acc_train, acc_vra_train, acc_val, acc_vra_val,
                    sub_lipschitz, used / 60)
        training_logs.append(dict(epoch=epoch,
                                  train_acc=acc_train,
                                  train_vra=acc_vra_train,
                                  val_acc=acc_val,
                                  val_vra=acc_vra_val,
                                  sub_lipschitz=sub_lipschitz,
                                  minutes=used / 60))

        state = dict(backbone=model.state_dict(),
                     optimizer=optimizer.state_dict(),
                     start_epoch=epoch + 1,
                     current_iter=scheduler.current_iter,
                     training_logs=training_logs,
                     configs=cfg)

        if epoch == train_cfg['epochs'] - 1:
            state['lc_dict'] = _lipschitz_per_module(model)

        _save_checkpoint(state, args.work_dir, args.ckpt_prefix, epoch,
                         args.max_save)


if __name__ == '__main__':
    # Script entry point: parse CLI options, then train and checkpoint.
    main()
