from dataset.ImageNet_LT import ImageNetLTDataLoader
from models.utils import load_pretrained_weights
import os
import numpy as np
import random
from numpy.lib.scimath import log
import torch
import ignite
from torch._C import dtype

import matplotlib.pyplot as plt
import torchvision.utils as vutils
from utils.metrics import print_num_params
from core.trainer import eval_epoch
from core.trainer_DA import train_epoch_DA as train_epoch
from core.utils import loss_adjust_cross_entropy,cross_entropy, logit_adjust_ly
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
# This script only runs on a CUDA device with cuDNN enabled. Explicit raises
# are used instead of `assert`, which is stripped under `python -O` and would
# let the script continue into much less readable failures.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is required by this script but is not available.')
if not torch.backends.cudnn.enabled:
    raise RuntimeError('cuDNN is required by this script but is disabled.')
# Turn on cuDNN benchmark mode for the whole run.
torch.backends.cudnn.benchmark = True
# Every tensor and the model below are placed on this device.
device = "cuda"

# seed = 17 # random seed
# random.seed(seed)
# _ = torch.manual_seed(seed)
import argparse

# Command-line options as (flag, default, type) triples, in registration order.
# argparse derives the attribute name from the flag, so `--ARCH_EPOCH` is read
# back as `args.ARCH_EPOCH`.
# The boolean-like switches (dy, ly, wy, balance_val) are plain strings here;
# they are expected to be the text 'True' or 'False'.
# For an iNaturalist run, pass `--model ResNet50_INAT --dataset INAT` instead
# of relying on the defaults.
_CLI_OPTIONS = (
    ('--model', 'ResNet32', str),
    ('--dataset', 'Cifar100', str),
    ('--batch_size', 128, int),
    ('--lr', 0.1, float),
    ('--arch_lr', 0.01, float),
    ('--checkpoint_interval', 40, int),
    ('--model_file', None, str),
    ('--save_path', None, str),
    ('--epoch', 300, int),
    ('--train_rho', 0.01, float),
    ('--ARCH_EPOCH', 0, int),
    ('--ARCH_END', 1000, int),
    ('--ARCH_INTERVAL', 10, int),
    ('--ARCH_TRAIN_SAMPLE', 10, int),
    ('--ARCH_VAL_SAMPLE', 10, int),
    ('--ARCH_EPOCH_INTERVAL', 1, int),
    ('--dy', 'True', str),
    ('--ly', 'True', str),
    ('--checkpoint', 0, int),
    ('--ly_init', 'Zeros', str),
    ('--dy_init', 'Ones', str),
    ('--wy_init', 'Ones', str),
    ('--ly_tau', 1.0, float),
    ('--wy', 'False', str),
    ('--group_size', 10, int),
    ('--train_size', 4000, int),
    ('--val_size', 1000, int),
    ('--balance_val', 'True', str),
)

parser = argparse.ArgumentParser()
for _flag, _default, _type in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type)
args = parser.parse_args()

# The boolean-like switches arrive as strings; only the exact text 'True'
# turns them on, anything else becomes False.
for _switch in ('dy', 'ly', 'balance_val'):
    setattr(args, _switch, getattr(args, _switch) == 'True')

# Unpack the frequently used options into module-level names.
network_model = args.model    # key selecting the model/loader branch below
dataset = args.dataset        # key selecting the dataset loader below
batch_size = args.batch_size  # training batch size
lr = args.lr                  # learning rate of the inner (model) optimizer
arch_lr = args.arch_lr        # learning rate of the outer (hyper-parameter) optimizer
total_epoch = args.epoch      # total number of training epochs
train_rho = args.train_rho    # imbalance ratio: min/max

# Outer-optimization schedule.
ARCH_EPOCH, ARCH_END = args.ARCH_EPOCH, args.ARCH_END  # first / last epoch with outer optimization
ARCH_INTERVAL = args.ARCH_INTERVAL                     # iteration interval between hyper-parameter updates
ARCH_TRAIN_SAMPLE = args.ARCH_TRAIN_SAMPLE             # training batches used for one arch update
ARCH_VAL_SAMPLE = args.ARCH_VAL_SAMPLE                 # validation batches used for one arch update
# Outer optimization is only enabled on epochs i with (i + 1) % ARCH_EPOCH_INTERVAL == 0.
ARCH_EPOCH_INTERVAL = args.ARCH_EPOCH_INTERVAL

# Resolve the --dataset option to a loader (bound to `load_dataset`) and the
# matching number of classes. The import is done lazily so that only the
# selected dataset module is loaded.
if dataset == 'Cifar10':
    from dataset.cifar10 import load_cifar10 as load_dataset
    num_classes = 10
elif dataset == 'Cifar100':
    from dataset.cifar100 import load_cifar100 as load_dataset
    num_classes = 100
elif dataset == 'ImageNet':
    from dataset.ImageNet_LT import ImageNetLTDataLoader as load_dataset
    num_classes = 1000
elif dataset == 'INAT':
    from dataset.iNaturalist import INAT as load_dataset
    num_classes = 8142
else:
    # Without this branch an unknown name only surfaced later as a NameError
    # on `num_classes` / `load_dataset`.
    raise ValueError(
        f"Unsupported --dataset {dataset!r}; "
        "expected one of 'Cifar10', 'Cifar100', 'ImageNet', 'INAT'."
    )


# Build the network and its data loaders for the selected --model.
# Every branch must define: model, train_loader, val_loader, test_loader,
# num_train_samples and num_val_samples.
if network_model == 'Efficient':
    from models.EfficientNet_NEW import EfficientNet
    model = EfficientNet.from_pretrained('efficientnet-b0', load_weights=False, num_classes=num_classes)
    (train_loader, val_loader, test_loader, eval_train_loader, eval_val_loader,
     num_train_samples, num_val_samples) = load_dataset(batch_size=batch_size, train_rho=train_rho)
elif network_model == 'ResNet20':
    from models.ResNet import ResNet20
    model = ResNet20(num_classes=num_classes)
    (train_loader, val_loader, test_loader, eval_train_loader, eval_val_loader,
     num_train_samples, num_val_samples) = load_dataset(
        batch_size=batch_size, train_rho=train_rho, image_size=32)
elif network_model == 'ResNet32':
    from models.ResNet import ResNet32
    model = ResNet32(num_classes=num_classes)
    (train_loader, val_loader, test_loader, eval_train_loader, eval_val_loader,
     num_train_samples, num_val_samples) = load_dataset(
        balance_val=args.balance_val, val_size=args.val_size, train_size=args.train_size,
        batch_size=batch_size, train_rho=train_rho, image_size=32)
elif network_model == 'ResNet50_Image':
    # NOTE(review): this branch and the next ignore --batch_size and use a
    # fixed batch size of 64 -- confirm that is intentional.
    import torchvision.models as models
    model = models.resnet50(pretrained=False, num_classes=num_classes)
    train_loader = load_dataset('./data/ImageNet_LT/', training=True, batch_size=64)
    val_loader = train_loader.split_validation()
    test_loader = load_dataset('./data/ImageNet_LT/', training=False, batch_size=64)
    num_train_samples, num_val_samples = train_loader.get_train_val_size()
elif network_model == 'ResNet50_INAT':
    import torchvision.models as models
    model = models.resnet50(pretrained=False, num_classes=num_classes)
    train_set = load_dataset('./data/inat_2018/', './data/inat_2018/train2018.json', is_train=True, split=1)
    num_train_samples = train_set.get_class_size()
    train_loader = DataLoader(train_set, batch_size=64)
    val_set = load_dataset('./data/inat_2018/', './data/inat_2018/train2018.json', is_train=True, split=2)
    num_val_samples = val_set.get_class_size()
    val_loader = DataLoader(val_set, batch_size=64)
    test_set = load_dataset('./data/inat_2018/', './data/inat_2018/val2018.json', is_train=False, split=0)
    # The test split used to be handed to the evaluation code as a raw dataset;
    # wrap it like the train/val splits so it is iterated in batches.
    test_loader = DataLoader(test_set, batch_size=64)
else:
    # Without this branch an unknown name only surfaced later as a NameError
    # on `model`.
    raise ValueError(
        f"Unsupported --model {network_model!r}; expected one of 'Efficient', "
        "'ResNet20', 'ResNet32', 'ResNet50_Image', 'ResNet50_INAT'."
    )

print_num_params(model)

# Optionally resume from `<save_path>/epoch_<checkpoint>.pth`.
if args.checkpoint != 0:
    if args.save_path is None:
        # Otherwise the path below would silently become 'None/epoch_N.pth'.
        raise ValueError('--save_path must be set when resuming with --checkpoint.')
    # Load onto the CPU first; the model is moved to the GPU right after.
    checkpoint_state = torch.load(
        f'{args.save_path}/epoch_{args.checkpoint}.pth', map_location='cpu')
    # The training loop in this script saves the whole module with
    # torch.save(model, ...), which load_state_dict() cannot consume directly.
    # Accept both that format and a plain state dict.
    if isinstance(checkpoint_state, nn.Module):
        checkpoint_state = checkpoint_state.state_dict()
    model.load_state_dict(checkpoint_state)

model = model.to(device)

criterion = nn.CrossEntropyLoss()


# Scaled log-frequency vector, used further down as the 'Google'
# initialisation of ly.
# presumably num_train_samples holds per-class sample counts -- TODO confirm
tau = args.ly_tau
class_prior = num_train_samples / np.sum(num_train_samples)
pi = tau * log(class_prior)
print('Google pi: ', pi)

#computed_dy=[1.9738, 1.6991, 1.4067, 1.0689, 0.8530, 0.6294, 0.4417, 0.4953, 0.4201,0.3912]
# Initial value of dy: all ones with one entry per class group, or values
# read from the text file named by --dy_init.
if args.dy_init == 'Ones':
    dy = torch.ones([((num_classes - 1) // args.group_size) + 1], dtype=torch.float32, device=device)
else:
    # Only the first line of the file is read. It is expected to hold
    # whitespace-separated floats, optionally wrapped in square brackets.
    # The context manager closes the handle, which was previously leaked.
    with open(args.dy_init, mode='r') as dy_file:
        first_line = dy_file.readline()
    dy = first_line.replace('[', '').replace(']', '').replace('\n', '').split()
    print(dy)
    dy = np.array([float(a) for a in dy])
    dy = torch.tensor(dy, dtype=torch.float32, device=device)

# Initial value of ly: zeros with one entry per class group, the precomputed
# `pi` vector ('Google'), or values read from the text file named by --ly_init.
if args.ly_init == 'Zeros':
    ly = torch.zeros([((num_classes - 1) // args.group_size) + 1], dtype=torch.float32, device=device)
elif args.ly_init == 'Google':
    # NOTE(review): `pi` has as many entries as num_train_samples, while
    # 'Zeros' has one entry per group -- confirm the consumers handle both.
    ly = torch.tensor(pi, dtype=torch.float32, device=device)
else:
    # Only the first line of the file is read. It is expected to hold
    # whitespace-separated floats, optionally wrapped in square brackets.
    # The context manager closes the handle, which was previously leaked.
    with open(args.ly_init, mode='r') as ly_file:
        first_line = ly_file.readline()
    ly = first_line.replace('[', '').replace(']', '').replace('\n', '').split()
    ly = np.array([float(a) for a in ly])
    ly = torch.tensor(ly, dtype=torch.float32, device=device)

# Training weights: uniform ('Ones') or total/count weights scaled to unit
# L2 norm ('Pi').
if args.wy_init == 'Ones':
    w_train = torch.ones([num_classes], dtype=torch.float32, device=device)
elif args.wy_init == 'Pi':
    w_train = np.sum(num_train_samples) / num_train_samples
    w_train = w_train / np.linalg.norm(w_train)
    w_train = torch.tensor(w_train, dtype=torch.float32).cuda()
else:
    # Previously any other value left `w_train` undefined and the script
    # failed a few lines later with a NameError.
    raise ValueError(
        f"Unsupported --wy_init {args.wy_init!r}; expected 'Ones' or 'Pi'."
    )

# Validation weights always use the total/count form, scaled to unit L2 norm.
w_val = np.sum(num_val_samples) / num_val_samples
w_val = w_val / np.linalg.norm(w_val)
w_val = torch.tensor(w_val, dtype=torch.float32).cuda()

# dy and ly receive gradients only when their --dy / --ly switch is on;
# the class weights never do.
dy.requires_grad_(args.dy)
ly.requires_grad_(args.ly)
for _fixed_weight in (w_train, w_val):
    _fixed_weight.requires_grad_(False)

# Echo the initial weights and hyper-parameters.
print(w_train, w_val)
print(ly, dy)

from core.augmentation.augmentation import augment_dict

print(augment_dict)

# aug_p and aug_u have one row per entry of augment_dict and one column per
# class group. Every entry starts at 1 / len(augment_dict).
_num_augs = len(augment_dict)
_aug_shape = [_num_augs, ((num_classes - 1) // args.group_size) + 1]

aug_p = torch.ones(_aug_shape, dtype=torch.float32, device=device) / _num_augs
print(aug_p)  # printed before gradients are enabled
aug_p.requires_grad = True

aug_u = torch.ones(_aug_shape, dtype=torch.float32, device=device) / _num_augs
print(aug_u)  # printed before gradients are enabled
aug_u.requires_grad = True


# Both learning rates are multiplied by 0.2 at these scheduler steps.
_LR_MILESTONES = (210, 270)

# Inner optimizer: the network weights.
train_optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

# Outer optimizer: one parameter group per hyper-parameter tensor.
_hyper_param_groups = [{'params': tensor} for tensor in (dy, ly, aug_p, aug_u)]
val_optimizer = optim.SGD(_hyper_param_groups, lr=arch_lr, momentum=0.9, weight_decay=1e-4)

train_lr_scheduler = optim.lr_scheduler.MultiStepLR(
    train_optimizer, milestones=list(_LR_MILESTONES), gamma=0.2)
val_lr_scheduler = optim.lr_scheduler.MultiStepLR(
    val_optimizer, milestones=list(_LR_MILESTONES), gamma=0.2)

# Choose the output directory; without --save_path a timestamped directory
# under ./results is used.
if args.save_path is None:
    import time
    args.save_path = f'./results/{int(time.time())}'
# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(args.save_path, exist_ok=True)

# Record the effective configuration, one "name value" pair per line.
# The context manager closes the file even if a write fails.
with open(f'{args.save_path}/config.txt', mode='w') as config_log:
    for k, v in vars(args).items():
        config_log.write(str(k) + ' ' + str(v) + '\n')

# Log files that stay open for the whole run; they are written and closed by
# the training loop below.
logfile = open(f'{args.save_path}/logs.txt', mode='w')
dy_log = open(f'{args.save_path}/dy.txt', mode='w')
ly_log = open(f'{args.save_path}/ly.txt', mode='w')
acc_log = open(f'{args.save_path}/acc.txt', mode='w')

# Snapshot of the model before any training in this run.
torch.save(model, f'{args.save_path}/init_model.pth')
# Main loop: total_epoch + 1 iterations. Each iteration first evaluates the
# current model on the train/val/test loaders, then trains for one epoch, so
# the accuracies logged for epoch i are measured before that epoch's training.
# The try/finally closes the log files even if an epoch raises.
try:
    for i in range(total_epoch + 1):
        text, loss, train_acc = eval_epoch(
            train_loader, model, loss_adjust_cross_entropy, i, ' train_dataset',
            params=[dy, ly, w_train], num_classes=num_classes, class_wise=True,
            group_size=args.group_size)
        logfile.write(text + '\n')

        text, loss, val_acc = eval_epoch(
            val_loader, model, cross_entropy, i, ' val_dataset',
            params=[dy, ly, w_val], logit_adjust=None, num_classes=num_classes,
            class_wise=True, group_size=args.group_size)
        logfile.write(text + '\n')

        text, loss, test_acc = eval_epoch(
            test_loader, model, cross_entropy, i, ' test_dataset',
            params=[dy, ly], logit_adjust=None, num_classes=num_classes,
            class_wise=True, group_size=args.group_size)
        logfile.write(text + '\n')
        print(dy, ly, '\n')

        # The outer (hyper-parameter) step is enabled only inside
        # [ARCH_EPOCH, ARCH_END] and on every ARCH_EPOCH_INTERVAL-th epoch.
        train_epoch(i, model,
                    in_loader=train_loader, in_criterion=loss_adjust_cross_entropy,
                    in_optimizer=train_optimizer, in_params=[dy, ly, w_train, aug_p, aug_u],
                    is_out=(i >= ARCH_EPOCH) and (i <= ARCH_END) and ((i + 1) % ARCH_EPOCH_INTERVAL) == 0,
                    out_loader=val_loader, out_optimizer=val_optimizer,
                    out_criterion=cross_entropy, out_logit_adjust=None, out_params=[dy, ly, w_val],
                    num_classes=num_classes, group_size=args.group_size,
                    ARCH_EPOCH=ARCH_EPOCH, ARCH_INTERVAL=ARCH_INTERVAL,
                    ARCH_TRAIN_SAMPLE=ARCH_TRAIN_SAMPLE, ARCH_VAL_SAMPLE=ARCH_VAL_SAMPLE,
                    agumentation_list=augment_dict)

        # Record the hyper-parameters after this epoch and the accuracies
        # measured before it.
        logfile.write(str(dy) + str(ly) + '\n\n')
        dy_log.write(f'{dy.detach().cpu().numpy()}\n')
        ly_log.write(f'{ly.detach().cpu().numpy()}\n')
        acc_log.write(f'{train_acc} {val_acc} {test_acc}\n')
        # Flush every epoch so the logs can be followed during a long run.
        for _log in (logfile, dy_log, ly_log, acc_log):
            _log.flush()

        # NOTE(review): only the inner scheduler is stepped; val_lr_scheduler
        # is created earlier in this script but never advanced here --
        # confirm whether that is intentional.
        train_lr_scheduler.step()
        if i % args.checkpoint_interval == 0:
            torch.save(model, f'{args.save_path}/epoch_{i}.pth')
finally:
    logfile.close()
    dy_log.close()
    ly_log.close()
    acc_log.close()

# Final model, saved only when the loop completed without an error.
torch.save(model, f'{args.save_path}/loss_adjustment.pth')