import argparse
import os
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from time import time

from core.new_fid import calculate_fid
from core import metrics, dataloader, utils, samplers
from scripts.vae import models
from core.inception_score_script import *

import pandas as pd
import pickle as pkl

# Command-line interface. Arguments are parsed at import time and the
# resulting ``flags`` namespace is read as a module-level global below.
parser = argparse.ArgumentParser()
# Boolean switch: when given, the script runs on the CPU instead of "cuda:0".
parser.add_argument("--nocuda", action="store_true")
# Valued options, declared as (name, type, default):
#   --checkpoints  text file listing one checkpoint path per line
#   --dataset      'cifar' or 'celeba'
#   --ntries       number of scoring runs per checkpoint
#   --act_path     forwarded to calculate_fid as its first path entry
for _name, _type, _default in (
    ("--checkpoints", str, ""),
    ("--dataset", str, "cifar"),
    ("--ntries", int, 1),
    ("--act_path", str, ""),
):
    parser.add_argument(_name, type=_type, default=_default)
flags = parser.parse_args()


def get_discriminator(checkpoint_name, device):
    """Build the discriminator matching a checkpoint and move it to ``device``.

    The base network is selected by the global ``flags.dataset``; the wrapper
    class is selected by a marker substring found in ``checkpoint_name``.

    Args:
        checkpoint_name: Checkpoint path or name. Must contain one of
            'upper-bound', 'cross-ent' or 'markov-ent'.
        device: Target handed to ``.to(device)``.

    Returns:
        The wrapped discriminator, already moved to ``device``.

    Raises:
        NotImplementedError: If ``flags.dataset`` is neither 'cifar' nor
            'celeba', or if ``checkpoint_name`` contains none of the markers.
    """
    if flags.dataset == 'cifar':
        base = models.DiscriminatorCifar()
    elif flags.dataset == 'celeba':
        base = models.DiscriminatorCeleba()
    else:
        raise NotImplementedError('unrecognized dataset: %r' % (flags.dataset,))

    # Markers are tested in this fixed order; the first one present wins.
    if 'upper-bound' in checkpoint_name:
        wrapper = models.DiscriminatorUB
    elif 'cross-ent' in checkpoint_name:
        wrapper = models.DiscriminatorCCE
    elif 'markov-ent' in checkpoint_name:
        wrapper = models.DiscriminatorMCE
    else:
        raise NotImplementedError(
            "checkpoint name %r contains none of 'upper-bound', 'cross-ent', "
            "'markov-ent'" % (checkpoint_name,))
    return wrapper(base).to(device)


def get_scores(generator, discriminator, real_images, device):
    """Score samples drawn by ``samplers.disc_MH`` started from real images.

    The states that follow each chain's first entry are pooled, rescaled and
    evaluated with ``calculate_fid`` and ``inception_score``.

    Args:
        generator: Model handed to ``samplers.disc_MH``; must expose ``h_dim``.
        discriminator: Model handed to ``samplers.disc_MH``.
        real_images: Tensor of starting images. It is cloned before use.
        device: Device for the starting images and the metric computations.

    Returns:
        Tuple ``(fid, is_mh, ar)``: the ``calculate_fid`` result, the first
        element of the ``inception_score`` result, and
        ``len(accepts) / sum(accepts)`` over the flattened accept records.
    """
    start_points = real_images.clone().to(device)
    chains, accept_records = samplers.disc_MH(
        generator, discriminator, int(1e4), start_points, generator.h_dim)

    # Skip index 0 of every chain and ignore chains with nothing after it.
    per_chain = [torch.stack(chain[1:], 0) for chain in chains if len(chain) > 1]
    pooled = torch.cat(per_chain, 0)

    # ``np`` is not imported explicitly at the top of this file; presumably it
    # arrives through the wildcard import of core.inception_score_script.
    flat_accepts = np.concatenate([np.array(rec) for rec in accept_records], 0)
    # NOTE(review): this is count / sum. Whether that is an acceptance rate or
    # its inverse depends on what disc_MH stores in ``accepts`` -- confirm.
    ar = len(flat_accepts) / np.sum(flat_accepts)

    # Affine rescale x -> (x - 0.5) / 0.5 before scoring.
    pooled = (pooled - 0.5) / 0.5
    genset = torch.utils.data.TensorDataset(pooled)
    genloader = torch.utils.data.DataLoader(
        genset, batch_size=200, shuffle=True, num_workers=4)

    fid = calculate_fid(
        [flags.act_path, ''], ['', genloader], device, full=True, silent=True)
    is_result = inception_score(genset, device, resize=True, splits=1)
    is_mh = is_result[0]
    return fid, is_mh, ar


def _read_checkpoint_list(list_path):
    """Return the stripped, non-blank lines of the checkpoint list file."""
    with open(list_path, 'r') as thefile:
        stripped = (line.strip() for line in thefile)
        return [line for line in stripped if line]


def _append_scores(filename, scores):
    """Pickle ``scores`` followed by any rows already stored in ``filename``."""
    if os.path.exists(filename):
        # pickle can execute code on load; this file is expected to be one
        # written earlier by this script.
        with open(filename, 'rb') as thefile:
            scores.extend(pkl.load(thefile))
    with open(filename, 'wb') as thefile:
        pkl.dump(scores, thefile)


def main():
    """Score every checkpoint listed in ``flags.checkpoints`` and pickle the results.

    For each checkpoint the decoder and discriminator weights are loaded and
    ``get_scores`` is run ``flags.ntries`` times, each time starting from a
    different block of 1000 training images. Rows of the form
    ``[checkpoint, fid, is_mh, emp_ar]`` are merged into
    ``<dataset>-scores.pkl`` in the directory of the first listed checkpoint;
    the file is rewritten after every checkpoint.

    Raises:
        ValueError: If the checkpoint list contains no paths.
        NotImplementedError: If ``flags.dataset`` is not 'cifar' or 'celeba'.
    """
    # Blank lines are dropped so they never reach torch.load as '' paths.
    checkpoints = _read_checkpoint_list(flags.checkpoints)
    if not checkpoints:
        raise ValueError('no checkpoint paths found in %r' % (flags.checkpoints,))
    logs_dir = os.path.dirname(checkpoints[0])

    device = torch.device("cpu") if flags.nocuda else torch.device("cuda:0")
    if flags.dataset == 'cifar':
        dataroot = "../../data/cifar"
        image_size = 32
        trainset = datasets.CIFAR10(dataroot, train=True, download=True,
                                    transform=transforms.Compose([transforms.ToTensor()]))
        get_decoder = lambda dev: models.DecoderCifar(128).to(dev)
    elif flags.dataset == 'celeba':
        dataroot = "../../data/celeba"
        image_size = 64
        trainset = datasets.ImageFolder(root=dataroot,
                                        transform=transforms.Compose([
                                            transforms.Resize(image_size),
                                            transforms.CenterCrop(image_size),
                                            transforms.ToTensor()
                                        ]))
        get_decoder = lambda dev: models.DecoderCeleba(64).to(dev)
    else:
        raise NotImplementedError('unrecognized dataset')

    # Try n starts from training images [n * batch_size, (n + 1) * batch_size).
    batch_size = 1000
    im_true = torch.zeros((flags.ntries, batch_size, 3, image_size, image_size))
    for n in range(flags.ntries):
        for i in range(batch_size):
            im_true[n, i, :, :, :] = trainset[n*batch_size+i][0].view([1, 3, image_size, image_size])

    filename = os.path.join(logs_dir, '%s-scores.pkl' % flags.dataset)
    for i, checkpoint in enumerate(checkpoints):
        # Each checkpoint file holds a (decoder state, discriminator state) pair.
        decoder_dict, discriminator_dict = torch.load(checkpoint, map_location='cpu')
        decoder = get_decoder(device)
        discriminator = get_discriminator(checkpoint, device)
        decoder.load_state_dict(decoder_dict)
        discriminator.load_state_dict(discriminator_dict)
        scores = []
        for n in range(flags.ntries):
            start = time()
            fid_triplet, is_mh, emp_ar = get_scores(decoder, discriminator, im_true[n, :, :, :, :], device)
            scores.append([checkpoint, fid_triplet, is_mh, emp_ar])
            print('i:%d n:%d, %.2f' % (i, n, time()-start))
        # Persist after every checkpoint so a later failure keeps earlier results.
        _append_scores(filename, scores)


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
