from __future__ import print_function
import logging
import os
import sys
import datetime
import time
import random
import numpy as np
import argparse
import copy
import pickle
from tqdm import tqdm, trange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import *

# Tiny positive constant. It is not referenced in the visible part of this
# file; presumably a numerical-stability epsilon (e.g. to avoid division by
# zero or log(0)) -- TODO confirm where it is used.
EPS = 1e-24

def seed_everything(seed: int = 42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (via ``torch.manual_seed``
    and ``torch.cuda.manual_seed``), exports the seed as ``PYTHONHASHSEED``,
    and switches cuDNN benchmark mode on.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Every seeding entry point takes the seed as its only argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,  # type: ignore
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # The deterministic cuDNN switch is left disabled, as before:
    # torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = True  # type: ignore


def train(args):
    """Train a model with SGD, evaluate it periodically, and save checkpoints.

    The run is named after its hyper-parameters plus a time stamp. A log file
    with that name is written into ``args.saving_folder``; model state dicts
    are saved there whenever the epoch-average training loss drops below the
    next threshold in ``args.stop_loss`` and once more after the last epoch.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``data_aug``, ``label_smoothing``, ``smoothing`` (optional,
            falls back to 0.1), ``lr_scheduler``, ``gamma``, ``milestones``,
            ``T_0``, ``T_mult``, ``T_up``, ``saving_folder``, ``dataset``,
            ``model``, ``num_classes``, ``lr``, ``momentum``,
            ``weight_decay``, ``batch_size``, ``test_batch_size``, ``seed``,
            ``cutout``, ``max_iteration``, ``cuda``, ``parallel``,
            ``test_period`` and ``stop_loss``.

    Side effects:
        * ``args.epochs`` is overwritten with the number of epochs derived
          from ``args.max_iteration``.
        * The root logger is reconfigured to write to this run's log file.
        * Checkpoint (``.pth``) files are written under
          ``args.saving_folder``.

    Notes:
        ``get_data``, ``get_model``, ``test``, ``smooth_CrossEntropyLoss``
        and the scheduler classes are not defined in this file; presumably
        they come from ``from utils import *`` -- TODO confirm.
    """
    # ------------------------------------------------------------------
    # Run name and logger
    # ------------------------------------------------------------------
    data_augmentation_str = '_dataaug' if args.data_aug else ''
    label_smoothing_str = '_labelsmoothing' if args.label_smoothing else ''

    lrschedule_str = ''
    use_lrs = True
    if args.lr_scheduler == 'multistep' and args.gamma != 1.0:
        lrschedule_str = '_multistepLR' + str(args.gamma)
    elif args.lr_scheduler == 'cosine':
        lrschedule_str = '_cosine'
    elif args.lr_scheduler == 'cosine_warmup':
        lrschedule_str = '_cosine_warmupT0_' + str(args.T_0)
    else:
        # Unknown scheduler name, or multistep with gamma == 1.0 (a no-op).
        use_lrs = False

    # The script entry point publishes the start time as the module global
    # ``time_now``. Fall back to "now" so train() also works when this module
    # is imported instead of executed.
    run_stamp = globals().get('time_now')
    if run_stamp is None:
        run_stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')

    os.makedirs('./' + args.saving_folder, exist_ok=True)
    loggingFileName = (
        './' + args.saving_folder + args.dataset + '_'
        + args.model + '_SGD_baseline' + '_lr' + str(args.lr) + lrschedule_str
        + '_mom' + str(args.momentum) + '_wd' + str(args.weight_decay)
        + '_bs' + str(args.batch_size) + '_seed' + str(args.seed)
        + data_augmentation_str + label_smoothing_str + '_' + run_stamp
    )
    # Common prefix of every checkpoint written by this run.
    checkpoint_prefix = './' + args.saving_folder + loggingFileName.split('/')[-1]

    logger = logging.getLogger(__name__)
    # logging.basicConfig() does nothing once the root logger has handlers,
    # so a second train() call in the same process (e.g. ``--seeds 1 2 3``)
    # would keep writing to the first run's log file. Drop stale handlers
    # first so each run gets its own file.
    root_logger = logging.getLogger()
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()
    logging.basicConfig(filename=loggingFileName + '.log',
                        format='[%(asctime)s] - %(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        level=logging.DEBUG
                       )
    logger.info(args)

    # ------------------------------------------------------------------
    # Dataset / model / optimizer / criterion / lr scheduler
    # ------------------------------------------------------------------
    train_loader, test_loader = get_data(dataset=args.dataset,
                                         train_bs=args.batch_size,
                                         test_bs=args.test_batch_size,
                                         data_augmentation=args.data_aug,
                                         normalization=True,
                                         shuffle=True,
                                         cutout=args.cutout,
                                         model=args.model
                                        )
    print(len(train_loader.dataset))

    # The epoch budget is derived from the iteration budget; this overrides
    # whatever value args.epochs held before.
    args.epochs = (args.max_iteration-1) // (len(train_loader.dataset) // args.batch_size) + 1
    print("max epoch: ", args.epochs)

    model = get_model(args.model,
                      dataset=args.dataset,
                      num_classes=args.num_classes
                     )
    if args.cuda:
        model = model.cuda()
    if args.parallel:
        model = torch.nn.DataParallel(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.label_smoothing:
        # Honour the smoothing option when the namespace has one; 0.1 is the
        # value that used to be hard-coded here.
        criterion = smooth_CrossEntropyLoss(smoothing=getattr(args, 'smoothing', 0.1))
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.lr_scheduler == 'multistep' and args.gamma != 1.0:
        lrs = MultiStepLR(optimizer, args.milestones, gamma=args.gamma)
    elif args.lr_scheduler == 'cosine':
        lrs = CosineAnnealingLR(optimizer, args.epochs)
    elif args.lr_scheduler == 'cosine_warmup':
        # The warm-up schedule starts from 1% of the target learning rate,
        # so the optimizer is rebuilt with that initial value.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr*0.01,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        lrs = CosineAnnealingWarmUpRestarts(optimizer, args.T_0, T_mult=args.T_mult, eta_max=args.lr,
                                            T_up=args.T_up, gamma=args.gamma)

    #####    training starts here     #####

    # values to save
    train_loss_traj = []
    train_acc_traj = []

    test_loss_traj = []
    test_acc_traj = []

    lr_traj = []

    timer_train = 0
    timer_backward = 0
    timer_eig_calc = 0  # never updated in this function; logged as 0
    timer_test = 0
    iter_num = 0
    best_test_acc = 0
    best_test_acc_iter = 0
    cur_stop_idx = 0
    cur_stop_loss = args.stop_loss[cur_stop_idx]
    for i in range(args.epochs):
        N = 0            # samples seen in this epoch
        cur_loss = 0     # sum of per-sample losses in this epoch
        num_correct = 0
        with tqdm(total=len(train_loader.dataset)) as progressbar:
            for ii, (data_ii, label_ii) in enumerate(train_loader, 0):

                # Re-enter training mode on every step (an evaluation call
                # may have run since the previous one).
                model.train()
                if args.cuda:
                    X, y = data_ii.cuda(), label_ii.cuda()
                else:
                    # Respect --no-cuda: keep the batch on its current device.
                    X, y = data_ii, label_ii

                N_ii = data_ii.shape[0]
                N += N_ii

                start_time = time.time()

                output = model(X)
                loss = criterion(output, y)
                cur_loss += loss.item() * N_ii

                # calculate accuracy
                pred = torch.argmax(output, dim=1)
                num_correct += torch.sum(pred == y).item()

                # progressbar update
                progressbar.set_postfix(loss=cur_loss/N,
                                        acc=100. * num_correct / N,
                                        epoch=i,
                                        lr=optimizer.param_groups[0]['lr'])
                progressbar.update(y.size(0))

                # backward
                temp_start = time.time()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                timer_backward += time.time() - temp_start

                temp_start = time.time()

                # test: on the very first step and then every test_period steps
                if iter_num == 0 or (iter_num + 1) % args.test_period == 0:
                    test_loss, test_acc = test(model, test_loader, criterion)
                    test_loss_traj.append(test_loss)
                    test_acc_traj.append(test_acc)

                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        best_test_acc_iter = iter_num
                timer_test += time.time() - temp_start

                timer_train += time.time() - start_time
                iter_num += 1

        current_lr = optimizer.param_groups[0]['lr']
        train_loss_traj.append(cur_loss/N)
        train_acc_traj.append(num_correct/N)
        lr_traj.append(float(current_lr))
        log_input = (i+1,
                     timer_train,
                     current_lr,
                     train_loss_traj[-1],
                     train_acc_traj[-1],
                     test_loss_traj[-1],
                     test_acc_traj[-1],
                     best_test_acc,
                     best_test_acc_iter)
        logger.info("epoch: %3d, time: %7.1f, lr: %.4f, train loss: %.4f, train acc: %.4f, test loss: %.4f, test acc: %.4f, best test acc so far: %.4f at iter %d"%log_input)
        if use_lrs:
            lrs.step()

        # save models when loss is lower than stop_loss values
        # (at most one threshold is consumed per epoch)
        if train_loss_traj[-1] < cur_stop_loss:
            model_state = copy.deepcopy(model.state_dict())
            PATH = checkpoint_prefix+'_loss'+str(cur_stop_loss)+'.pth'
            print('loss is less than '+str(cur_stop_loss)+'... save the model in '+PATH)
            torch.save(model_state, PATH)
            cur_stop_idx += 1
            if cur_stop_idx == len(args.stop_loss):
                print('training finished...')
                break
            cur_stop_loss = args.stop_loss[cur_stop_idx]

    logger.info("time for train: %.4f, backward: %.4f, test: %.4f, eig: %.4f"%(timer_train,
                                                                               timer_backward,
                                                                               timer_test,
                                                                               timer_eig_calc))

    logger.info("max train acc: %.4f"%(max(train_acc_traj)))
    logger.info("max test acc: %.4f"%(max(test_acc_traj)))

    # The loss/accuracy trajectories are kept in memory only; the original
    # code for pickling them to loggingFileName + '.pickle' was disabled.

    # Save the final weights only when the epoch budget was exhausted, i.e.
    # the loop did not stop early on the last stop_loss threshold.
    if i == args.epochs-1:
        model_state = copy.deepcopy(model.state_dict())
        PATH = checkpoint_prefix+'_epoch'+str(args.epochs)+'.pth'
        print(PATH)
        torch.save(model_state, PATH)
    
    
def parse_bool_arg(value):
    """Convert a command-line token into a ``bool`` for argparse ``type=``.

    ``type=bool`` is a trap: ``bool("False")`` is ``True`` because any
    non-empty string is truthy. This converter accepts the usual spellings
    instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string token.

    Returns:
        ``True`` for true/t/yes/y/1/on, ``False`` for false/f/no/n/0/off and
        for the empty string (which ``bool`` also mapped to ``False``).
        Matching is case-insensitive and ignores surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Training ')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='choose gpu number')
    parser.add_argument('--method',
                        type=str,
                        default='SGD',
                        help='SGD')

    parser.add_argument('--dataset',
                        type=str,
                        default='cifar10',
                        help='cifar10/mnist/cifar100')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=500,
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--model',
                        type=str,
                        default='lenet',
                        help='lenet/vgg11_bn/3FCN/resnet20/resnet56/WRN164/WRN168')
    # NOTE: train() recomputes args.epochs from --max_iteration, so the value
    # given here does not control the length of the run.
    parser.add_argument('--epochs',
                        type=int,
                        default=1000,
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum',
                        default=0.9,
                        type=float,
                        help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay',
                        default=0.0005,
                        type=float,
                        help='weight decay (default: %(default)s)')

    # Accepts "--data_aug True/False" as before (now parsed correctly) and
    # also a bare "--data_aug", which means True.
    parser.add_argument('--data_aug',
                        default=False,
                        type=parse_bool_arg,
                        nargs='?',
                        const=True,
                        help='data augmentation (default: False)')

    parser.add_argument('--no-cuda',
                        action='store_true',
                        help='do we use gpu or not')
    parser.add_argument('--no-parallel',
                        action='store_true',
                        help='do we use parallel or not')

    parser.add_argument('--saving-folder',
                        type=str,
                        default='trained_model_cor_exp/',
                        help='choose saving name')
    parser.add_argument('--savemodels',
                        action='store_true',
                        help='save models')
    parser.add_argument('--name',
                        type=str,
                        default='noname',
                        help='choose saving name')
    parser.add_argument('--no-overwrite',
                        action='store_true',
                        help='do we rewrite or not')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help='random seed (default: 1)')

    parser.add_argument('--criterion',
                        type=str,
                        default='cross-entropy',
                        help='cross-entropy/mse/label_smoothing')

    parser.add_argument('--lr-scheduler',
                        type=str,
                        default='multistep',
                        help='cosine/multistep/cosine_warmup/')

    parser.add_argument("--milestones", nargs='*', type=int, default=None)
    parser.add_argument("--gamma", type=float, default=1.0)

    parser.add_argument('--test_period',
                        type=int,
                        default=200)

    parser.add_argument('--no-grad-normalize',
                        action='store_true',
                        help='use the unnormlized rho (with the gradient norm)')

    parser.add_argument('--no-label-smoothing',
                        action='store_true',
                        help='disable label smoothing')
    parser.add_argument("--smoothing", type=float, default=0.1)

    parser.add_argument('--cutout',
                        action='store_true',
                        help='do we use cutout or not')

    parser.add_argument("--T_0", type=int, default=200, help="the initial period of cosine warmup scheduler")
    parser.add_argument("--T_up", type=int, default=10, help="the epochs to warm up for cosine warmup scheduler")
    parser.add_argument("--T_mult", type=int, default=1,
                        help="the constant multiplied to the period after each period for cosine warmup scheduler")
    parser.add_argument("--seeds", nargs='*', type=int, default=None)
    parser.add_argument("--max_iteration", type=int, default=1000000, help="the number of maximum iterations")
    parser.add_argument("--stop_loss", nargs='*', type=float, default=[0.1])
    args = parser.parse_args()

    ## default: ON -- every "--no-X" switch is mirrored as a positive flag
    args.grad_normalize = not args.no_grad_normalize
    args.label_smoothing = not args.no_label_smoothing
    args.cuda = not args.no_cuda
    args.overwrite = not args.no_overwrite
    args.parallel = not args.no_parallel

    ## modification
    if args.label_smoothing:
        args.criterion = 'label_smoothing'

    if args.dataset in ('cifar10', 'mnist', 'mnist7x7'):
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        args.num_classes = 100
    else:
        raise ValueError("Unknown dataset")

    # time stamp; train() reads this module-level name to build file names
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    print(time_now)

    for arg in vars(args):
        print(arg, getattr(args, arg))
    # Example flags kept from the original file (note: this parser defines no
    # --rho option, so the example looks stale):
    # --method 'ASAM' --no-parallel --name 'ASAM_0.5' --milestones 60 120 160 --rho 0.5 --no-label-smoothing

    # Restrict the process to the selected GPU.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # set random seed to reproduce the work; --seeds runs one training per seed
    if args.seeds is None:
        run_seeds = [args.seed]
    else:
        print(args.seeds)
        run_seeds = args.seeds
    for seed in run_seeds:
        args.seed = seed
        seed_everything(args.seed)
        train(args)
