
import os
import torch
import argparse
import itertools
import numpy as np
# from unet import Unet
from tqdm import tqdm
import torch.optim as optim
from cfg.diffusion import GaussianDiffusion
from torchvision.utils import save_image
from cfg.utils import get_named_beta_schedule
from cfg.embedding import ConditionalEmbedding, MNISTEmbedding
from cfg.Scheduler import GradualWarmupScheduler

import sys; sys.path.append('../retrain_trick'); sys.path.append('../Morpho-MNIST')
#from dataloader_cifar import load_data, transback
# from gen_retrain_trick import load_data, transback, RetrainTrickDataset
from cfg.dataloader_pickle import PickleDataset, transback, load_data
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size

import torch.nn.functional as F

from path_constant import project_root

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from cfg.unet import Unet
import copy
from cfg.embedding import JointEmbedding2, JointConditionalEmbedding
from torchvision import transforms

from napkin_mnist.train_classifiers import load_classifiers


def load_data(dataset: "PickleDataset", batchsize: int) -> DataLoader:
    """Wrap ``dataset`` in a shuffling ``DataLoader``.

    NOTE: this definition shadows the ``load_data`` name imported from
    ``cfg.dataloader_pickle`` at the top of the file; every call below this
    point resolves to this function.

    Args:
        dataset: Map-style dataset to iterate over.
        batchsize: Number of samples per batch.

    Returns:
        A ``DataLoader`` that reshuffles on every pass and drops the final
        incomplete batch, so every yielded batch has exactly ``batchsize``
        samples. No sampler is returned.
    """
    return DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=True,
        drop_last=True,
    )



def cycler(loader):
    """Yield batches from ``loader`` forever, restarting it whenever it runs out.

    The loader is re-iterated from scratch on every pass (nothing is cached),
    so a shuffling loader produces a new order each time round.
    """
    while True:
        yield from loader



def get_parent_embedding(datakey, cemblayer, batch):
    """Embed the parents of ``datakey`` for conditioning the diffusion model.

    Causal graph used by this script:
        W1 -> W2a
        W1 -> W2b
        W2a -> X <- W2b
        X -> Y

    Reads the module-level globals ``device`` and ``params.threshold``.

    Args:
        datakey: Variable being generated. ``"X"`` is conditioned on
            ``batch['W2a']`` and ``batch['W2b']``; ``"Y"`` is conditioned on
            ``batch['X']``. Any other key has no parent embedding.
        cemblayer: Embedding module called on the parent value(s); only used
            when ``datakey`` is ``"X"`` or ``"Y"``.
        batch: Mapping from variable name to tensor.

    Returns:
        The parent embedding, or ``None`` when ``datakey`` is neither ``"X"``
        nor ``"Y"``.
    """
    if datakey == "X":
        cemb = cemblayer(batch['W2a'].to(device), batch['W2b'].to(device))
    elif datakey == "Y":
        cemb = cemblayer(batch['X'].to(device))
    else:
        return None

    # Conditioning dropout (probability ``params.threshold``) must only be
    # active while training. ``F.dropout1d`` defaults to ``training=True``, so
    # the flag is passed explicitly; otherwise the embedding would still be
    # randomly dropped/rescaled after ``cemblayer.eval()`` during sampling.
    # NOTE(review): presumably cemb is 2-D (batch, cdim), in which case
    # dropout1d zeroes whole rows, i.e. whole samples - confirm.
    return F.dropout1d(cemb, params.threshold, training=cemblayer.training)


def sample(datakey, diffusion, cemblayer, batch_size, parent_batch):
    """Generate ``batch_size`` images of ``datakey`` conditioned on ``parent_batch``.

    Switches ``diffusion.model`` (and ``cemblayer`` when given) to eval mode
    and leaves them there; the caller is responsible for switching back to
    train mode. Reads the module-level global ``params`` (``ddim``,
    ``num_steps``, ``eta``, ``select``).

    Args:
        datakey: Variable to generate; forwarded to ``get_parent_embedding``.
        diffusion: Diffusion wrapper exposing ``model``, ``sample`` and
            ``ddim_sample``.
        cemblayer: Parent-embedding module, or ``None`` when the variable has
            no parents.
        batch_size: Number of images to generate.
        parent_batch: Mapping holding the parent values to condition on.

    Returns:
        Contiguous tensor of shape ``(-1, 3, 32, 32)`` obtained by passing the
        generated samples through ``transback``.
    """
    diffusion.model.eval()
    if cemblayer is not None:
        cemblayer.eval()

    with torch.no_grad():
        cemb = get_parent_embedding(datakey, cemblayer, parent_batch)

        # The sample shape is hard-coded to 3-channel 32x32 images.
        genshape = (batch_size, 3, 32, 32)
        if params.ddim:
            generated = diffusion.ddim_sample(
                genshape, params.num_steps, params.eta, params.select, cemb=cemb
            )
        else:
            generated = diffusion.sample(genshape, cemb=cemb)

        # NOTE(review): transback presumably maps model outputs back to image
        # range - confirm against cfg.dataloader_pickle.
        final_imgs = transback(generated).reshape(-1, 3, 32, 32).contiguous()

    return final_imgs


# Maximum number of intervention samples (W2a == 3, W2b == 0) used per evaluation.
_MAX_INTV_SAMPLES = 200
# An evaluation pass runs every this many epochs.
_EVAL_EVERY = 5


def _class_distribution(classifier, images):
    """Return the fraction of ``images`` that ``classifier`` assigns to each class.

    The histogram is padded to the number of logits so that classes which are
    never predicted still get an explicit zero entry (and can be indexed).
    """
    logits = classifier(images)
    counts = torch.bincount(torch.argmax(logits, dim=1), minlength=logits.shape[1])
    return counts / counts.sum()


def train(net, cemblayer, params, moddir, samdir, true_classifiers):
    """Train one conditional diffusion model, resuming from ``moddir`` if possible.

    Reads the module-level globals ``device`` and ``dataloader``.

    Every ``_EVAL_EVERY`` epochs the samples of that epoch with ``W2a == 3`` and
    ``W2b == 0`` (at most ``_MAX_INTV_SAMPLES``) are used as conditioning for
    ``sample``; the class histograms of the real ``Y`` images and of the
    generated images are printed, and a checkpoint may be written.

    Args:
        net: Denoising network handed to ``GaussianDiffusion``.
        cemblayer: Parent-embedding module trained jointly with ``net``, or
            ``None`` when the variable has no parents.
        params: Parsed hyper-parameters (``T``, ``dtype``, ``w``, ``v``, ``lr``,
            ``epoch``, ``multiplier``, ``datakey``, ...).
        moddir: Directory holding ``last_epoch.pt`` and ``ckpt_<n>_checkpoint.pt``.
        samdir: Directory created for generated samples.
        true_classifiers: Mapping with the ``'Y_color'`` and ``'Y_digit'``
            classifiers used for evaluation.

    Returns:
        Tuple ``(diffusion, cemblayer)``.
    """
    # Resume from the most recent checkpoint, if there is one.
    lastpath = os.path.join(moddir, 'last_epoch.pt')
    if os.path.exists(lastpath):
        lastepc = torch.load(lastpath)['last_epoch']
        checkpoint = torch.load(os.path.join(moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        if cemblayer is not None:
            cemblayer.load_state_dict(checkpoint['cemblayer'])
    else:
        lastepc = 0

    betas = get_named_beta_schedule(num_diffusion_timesteps=params.T)
    diffusion = GaussianDiffusion(
        dtype=params.dtype,
        model=net,
        betas=betas,
        w=params.w,
        v=params.v,
        device=device,
    )

    # The denoiser and the parent embedding (when present) share one optimizer.
    model_params = list(diffusion.model.parameters())
    if cemblayer is not None:
        model_params += list(cemblayer.parameters())
    optimizer = torch.optim.AdamW(model_params, lr=params.lr, weight_decay=1e-4)

    # Warm-up for the first tenth of the run, then cosine annealing.
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=params.epoch,
        eta_min=0,
        last_epoch=-1,
    )
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer,
        multiplier=params.multiplier,
        warm_epoch=params.epoch // 10,
        after_scheduler=cosineScheduler,
        last_epoch=lastepc,
    )
    if lastepc != 0:
        optimizer.load_state_dict(checkpoint['optimizer'])
        warmUpScheduler.load_state_dict(checkpoint['scheduler'])

    for epc in range(lastepc, params.epoch):
        # sample() leaves the modules in eval mode, so re-enable train mode.
        diffusion.model.train()
        if cemblayer is not None:
            cemblayer.train()

        evaluate_now = (epc + 1) % _EVAL_EVERY == 0
        intv_batch = {lb: [] for lb in ['W1', 'W2a', 'W2b', 'X', 'Y']}
        n_intv = 0

        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for batch in tqdmDataLoader:
                optimizer.zero_grad()
                x_0 = batch[params.datakey].to(device)
                cemb = get_parent_embedding(params.datakey, cemblayer, batch)

                loss = diffusion.trainloss(x_0, cemb=cemb)
                loss.backward()
                optimizer.step()

                # Collect the intervention subset, but only on epochs that
                # evaluate and only until enough samples are available: just
                # the first _MAX_INTV_SAMPLES are ever used.
                if evaluate_now and n_intv < _MAX_INTV_SAMPLES:
                    truidx = torch.where((batch['W2b'] == 0) & (batch['W2a'] == 3))
                    for key in intv_batch:
                        intv_batch[key].append(batch[key][truidx])
                    n_intv += truidx[0].numel()

                tqdmDataLoader.set_postfix(
                    ordered_dict={
                        "epoch": epc + 1,
                        "loss: ": loss.item(),
                        "batch per device: ": x_0.shape[0],
                        "img shape: ": x_0.shape[1:],
                        "LR": optimizer.param_groups[0]["lr"],
                    }
                )
        warmUpScheduler.step()

        # Evaluation and (conditional) checkpointing.
        if evaluate_now:
            os.makedirs(moddir, exist_ok=True)
            os.makedirs(samdir, exist_ok=True)

            if n_intv == 0:
                print('No samples with W2a == 3 and W2b == 0 this epoch; skipping evaluation for epoch', epc + 1)
            else:
                val_batch = {
                    lb: torch.cat(chunks, dim=0)[0:_MAX_INTV_SAMPLES]
                    for lb, chunks in intv_batch.items()
                }
                final_imgs = sample(params.datakey, diffusion, cemblayer,
                                    val_batch['X'].shape[0], val_batch)

                # Only class histograms are needed, so no autograd graph is built.
                with torch.no_grad():
                    true_y = val_batch['Y'].to(device)

                    colprob = _class_distribution(true_classifiers['Y_color'], true_y)
                    print('---> True color P(Y|X) from classifier', colprob)
                    print('---> Pred color P(Y|X) from classifier',
                          _class_distribution(true_classifiers['Y_color'], final_imgs))

                    print('xxx-> True digit P(Y|X) from classifier',
                          _class_distribution(true_classifiers['Y_digit'], true_y))
                    print('xxx-> Pred digit P(Y|X) from classifier',
                          _class_distribution(true_classifiers['Y_digit'], final_imgs))

                # save_image(final_imgs, os.path.join(samdir, f'generated_{epc+1}_pict.png'), nrow = 3)

                # NOTE(review): the gate uses the colour histogram of the *real*
                # Y images, not of the generated ones, and the original
                # condition tested index 4 twice - confirm which classes and
                # which distribution were intended.
                if colprob[3] < 0.10 and colprob[4] < 0.10:
                    checkpoint = {
                        'net': diffusion.model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': warmUpScheduler.state_dict(),
                    }
                    if cemblayer is not None:
                        checkpoint['cemblayer'] = cemblayer.state_dict()

                    torch.save({'last_epoch': epc + 1}, os.path.join(moddir, 'last_epoch.pt'))
                    torch.save(checkpoint, os.path.join(moddir, f'ckpt_{epc+1}_checkpoint.pt'))

        torch.cuda.empty_cache()

    return diffusion, cemblayer




if __name__ == '__main__':
    # Script entry point. It is deliberately kept flat (not wrapped in a
    # main() function): the functions above read ``params``, ``device``,
    # ``dataloader`` and ``val_cycler`` as module-level globals.

    def _str2bool(value):
        """Parse a command-line boolean: 'true', '1', 'yes' (any case) -> True, else False."""
        return str(value).lower() in ['true', '1', 'yes']

    # several hyperparameters for model
    parser = argparse.ArgumentParser(description='test for diffusion model')
    parser.add_argument('--train_pkl', type=str, default=f"{project_root}/napkin_mnist/base_data/napkin_mnist_train.pkl")
    parser.add_argument('--val_pkl', type=str, default=f"{project_root}/napkin_mnist/base_data/napkin_mnist_train.pkl")
    parser.add_argument('--datakey', type=str, help='which of the data keys is the one we want to generate')
    parser.add_argument('--condkey', type=str, help='which of the data keys is the one we use for conditioning')
    parser.add_argument('--batchsize',type=int,default=256,help='batch size per device for training Unet model')
    parser.add_argument('--numworkers',type=int,default=4,help='num workers for training Unet model')
    parser.add_argument('--inch',type=int,default=3,help='input channels for Unet model')
    parser.add_argument('--modch',type=int,default=64,help='model channels for Unet model')
    parser.add_argument('--T',type=int,default=1000,help='timesteps for Unet model')
    parser.add_argument('--outch',type=int,default=3,help='output channels for Unet model')
    # ``type=list`` would split the argument into single-character strings;
    # take the multipliers as space-separated integers instead (--chmul 1 2 2 2).
    parser.add_argument('--chmul',type=int,nargs='+',default=[1,2,2,2],help='architecture parameters training Unet model')
    parser.add_argument('--numres',type=int,default=2,help='number of resblocks for each block in Unet model')
    parser.add_argument('--cdim',type=int,default=64,help='dimension of conditional embedding')
    # ``type=bool`` treats every non-empty string (including "False") as True.
    parser.add_argument('--useconv',type=_str2bool,default=True,help='whether use convlution in downsample')
    parser.add_argument('--droprate',type=float,default=0.1,help='dropout rate for model')
    parser.add_argument('--dtype',default=torch.float32)
    parser.add_argument('--lr',type=float,default=2e-4,help='learning rate')
    parser.add_argument('--w',type=float,default=1.8,help='hyperparameters for classifier-free guidance strength')
    parser.add_argument('--v',type=float,default=0.3,help='hyperparameters for the variance of posterior distribution')
    parser.add_argument('--epoch',type=int,default=400,help='epochs for training')
    parser.add_argument('--multiplier',type=float,default=2.5,help='multiplier for warmup')
    parser.add_argument('--threshold',type=float,default=0.1,help='threshold for classifier-free guidance')
    parser.add_argument('--interval',type=int,default=50,help='epoch interval between two evaluations')
    parser.add_argument('--moddir',type=str,default=f'{project_root}/Baselines/DiffusionBasedCausalModels/imgcond_model',help='model addresses')
    parser.add_argument('--samdir',type=str,default=f'{project_root}/Baselines/DiffusionBasedCausalModels/imgcond_sample',help='sample addresses')
    parser.add_argument('--genbatch',type=int,default=80,help='batch size for sampling process')
    # parser.add_argument('--clsnum',type=int,default=1000,help='num of label classes')
    parser.add_argument('--num_steps',type=int,default=50,help='sampling steps for DDIM')
    parser.add_argument('--eta',type=float,default=0,help='eta for variance during DDIM sampling process')
    parser.add_argument('--select',type=str,default='linear',help='selection stragies for DDIM')
    parser.add_argument('--ddim',type=_str2bool,default=True,help='whether to use ddim')  #default was false
    parser.add_argument('--local_rank',default=-1,type=int,help='node rank for distributed training')

    # Unknown arguments are tolerated rather than rejected.
    params, unknown = parser.parse_known_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load data
    train_data = PickleDataset(params.train_pkl)
    val_data = PickleDataset(params.val_pkl)
    dataloader = load_data(train_data, params.batchsize)
    # device_count() is 0 on CPU-only hosts; clamp it so the CPU fallback
    # selected above does not die with a ZeroDivisionError.
    val_loader = load_data(val_data, params.genbatch // max(torch.cuda.device_count(), 1))

    # Sanity check: print the shape of the first W1 batch.
    for batch in dataloader:
        print(batch['W1'].shape)
        break

    val_cycler = cycler(val_loader)

    # initialize models: one U-Net per generated variable
    net = {}
    for lb in ['W1', 'X', 'Y']:
        net[lb] = Unet(
            in_ch=params.inch,  # here it is 3 for W1,X,Y
            mod_ch=params.modch,
            out_ch=params.outch,  # here it is 3 for W1,X,Y
            ch_mul=params.chmul,
            num_res_blocks=params.numres,
            cdim=params.cdim,
            use_conv=params.useconv,
            droprate=params.droprate,
            dtype=params.dtype,
            use_cemb=(lb != "W1"),  # W1 has no parents, hence no conditioning
        )

    # Load classifiers for evaluation
    save_dir = f'{project_root}/napkin_mnist/saved_classifier_models/'
    save_name = 'napkin_classifer'
    true_classifiers = load_classifiers(save_dir, save_name)
    true_classifiers['Y_color'] = true_classifiers['Y_color'].to(device)
    true_classifiers['Y_digit'] = true_classifiers['Y_digit'].to(device)

    cemblayer = {}

    cemblayer['W1'] = None  # No parent

    # X is taking W2a,W2b as input.
    cemblayer['X'] = JointConditionalEmbedding(num_labels_0=10, num_labels_1=6,
                                               d_model=params.cdim,
                                               dim=params.cdim).to(device)

    # Y is taking image X as input.
    cemblayer['Y'] = MNISTEmbedding(3, params.cdim, hw=32).to(device)

    # for lb in ['W1','X', 'Y']:
    for lb in ['Y']:
        moddir = os.path.join(params.moddir, lb)
        samdir = os.path.join(params.samdir, lb)
        params.datakey = lb
        print(moddir, samdir)
        # The embedding layers and classifiers are placed on ``device`` above;
        # put the network there too before its parameters reach the optimizer
        # (a no-op if it is already on that device).
        train(net[lb].to(device), cemblayer[lb], params, moddir, samdir, true_classifiers)


# # export PYTHONPATH="${PYTHONPATH}:/root/PycharmProjects/IDGEN/Baselines/DiffusionBasedCausalModels/napkin"