import sys
import os
import argparse
import datetime

import numpy as np

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import time
import copy

import utils.utils as utils
import nns

import pandas as pd
import seaborn as sns

import random

def test_kernel(nets, bs = 100):
    """Score each network's features as a fixed kernel for linear regression.

    For every net, a linear head on ``[feature, 1]`` (the constant column acts
    as a bias) is fit by least squares on the module-level aggregated training
    data ``train_x_agg`` / ``train_y_agg``. The fit uses the pseudo-inverse of
    the accumulated Gram matrix. The head is then evaluated on ``test_x`` /
    ``test_y`` with the module-level ``loss_fn``.

    Args:
        nets: one network, or a list/tuple of networks. Each must have a
            ``device`` attribute and, when called with ``feature=True``, return
            a dict containing a ``'feature'`` tensor.
        bs: batch size for both the fitting pass and the evaluation pass.

    Returns:
        Tuple ``(mean_mse, mean_rmse)``: the test loss averaged over nets, and
        the square root of each net's test loss averaged over nets.

    Raises:
        ValueError: if the aggregated training set is empty.
    """
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    train_loader = DataLoader(TensorDataset(train_x_agg, train_y_agg), batch_size = bs, shuffle = False, drop_last = False)
    test_loader = DataLoader(TensorDataset(test_x, test_y), batch_size = bs, shuffle = False, drop_last = False)

    def _design(net, x):
        # Features plus a constant-1 column (bias term), moved to CPU where the
        # regression is solved.
        feature = net(x.to(net.device), feature = True)['feature']
        ones = torch.ones(len(x), 1, device = feature.device, dtype = feature.dtype)
        return torch.cat([feature, ones], dim = 1).to("cpu")

    mse_list = []
    with torch.no_grad():
        for net in nets:
            # Accumulate X^T X and X^T y as running sums rather than stacking
            # every per-batch product (which kept all of them alive at once).
            gram = None
            cross = None
            for x, y in train_loader:
                design = _design(net, x)
                batch_gram = design.T @ design
                batch_cross = design.T @ y.unsqueeze(1)
                gram = batch_gram if gram is None else gram + batch_gram
                cross = batch_cross if cross is None else cross + batch_cross
            if gram is None:
                raise ValueError("train_x_agg is empty; cannot fit the regression head")
            regressor = torch.linalg.pinv(gram) @ cross

            sq_err_sum = 0.0
            for x, y in test_loader:
                output = _design(net, x) @ regressor
                # Assumes loss_fn is mean-reduced (nn.MSELoss() in __main__):
                # re-weight by batch size so a short last batch is averaged correctly.
                sq_err_sum += loss_fn(output.reshape(-1), y).item() * len(x)
            mse_list.append(sq_err_sum / len(test_x))
    mse_arr = np.array(mse_list)
    return mse_arr.mean(), (mse_arr ** 0.5).mean()

def HSIC(K, L):
    """Empirical HSIC between two square Gram matrices of the same size.

    Computes ``trace(K H L H) / (n - 1) ** 2`` with the centering matrix
    ``H = I - 11^T / n``. The result is floored at 1e-12 so that callers
    dividing by (a root of) it, such as ``CKA``, never see zero or a negative
    value.

    The centering matrix is built with K's dtype and device, so float64 (or
    half-precision) kernels work. The floor is applied with ``clamp_min``, so
    the result keeps K's dtype and device and no host/device sync is forced.

    Args:
        K: square ``(n, n)`` tensor.
        L: square ``(n, n)`` tensor on the same device and with the same dtype as K.

    Returns:
        0-dim tensor holding the (floored) HSIC value.

    Raises:
        AssertionError: if K or L is not square, or their sizes differ.
    """
    assert K.shape[0] == K.shape[1]
    assert L.shape[0] == L.shape[1]
    assert K.shape[0] == L.shape[0]
    n = K.shape[0]
    # I - 1/n elementwise equals I - ones(n, n) / n.
    H = torch.eye(n, dtype=K.dtype, device=K.device) - 1.0 / n
    val = torch.trace(K @ H @ L @ H) / ((n - 1) ** 2)
    return val.clamp_min(1e-12)

def CKA(K, L):
    """Centered kernel alignment between two Gram matrices of equal shape.

    Returns ``HSIC(K, L) / (sqrt(HSIC(K, K)) * sqrt(HSIC(L, L)))`` as a
    0-dim tensor.

    Raises:
        AssertionError: if K and L differ in shape.
    """
    assert K.shape == L.shape
    cross_term = HSIC(K, L)
    norm_k = HSIC(K, K) ** 0.5
    norm_l = HSIC(L, L) ** 0.5
    return cross_term / (norm_k * norm_l)

if __name__ == "__main__":
    # At module scope this `global` statement has no runtime effect; it records
    # which module-level names test_kernel() reads.
    global device, loss_fn, train_x_agg, train_y_agg, test_x, test_y, public_x, public_y

    arg_command = sys.argv[1:]
    parser = argparse.ArgumentParser()

    # general
    parser.add_argument("--cuda", type=int, default=0) # start gpu number to use
    parser.add_argument("--ngpu", type=int, default=1) # nb of gpus
    parser.add_argument("--ver", type=int, default=-1)
    parser.add_argument("--load", type=str, default=None)

    # dataset and settings
    parser.add_argument("--data", type=str, default='datasets/toy3_50_hetero')
    parser.add_argument("--nn_type", type=str, nargs="+", default=None)
    parser.add_argument("--nn_ratio", type=float, nargs="+", default=1.)

    # true if want to evaluate the concatenated kernel
    parser.add_argument("--cat", action="store_true")

    FLAGS, _ = parser.parse_known_args(arg_command)

    if FLAGS.load is None:
        # The logger's file name is derived from --load, so no logger exists
        # yet. Report on stderr and exit with a non-zero status instead.
        sys.exit("No load.. You must load trained models")

    # data type and client numbers, parsed from a directory name such as
    # "toy3_50_hetero" (<data_type>_<client_num>_...)
    model_name = FLAGS.load.split('/')[-1]
    data_dir_name = os.path.basename(os.path.normpath(FLAGS.data))
    data_type = data_dir_name.split('_')[0]
    client_num = int(data_dir_name.split('_')[1])

    # logger
    utils.generate_dir('logs')
    FLAGS.log_fn = f"logs/log_test_{model_name}.txt"
    logger = utils.init_logger(FLAGS.log_fn)
    utils.log_arguments(logger, FLAGS)

    # device setting: client i goes to GPU (cuda + i % ngpu), wrapped to the visible GPUs
    n_visible_gpus = torch.cuda.device_count()
    if n_visible_gpus == 0:
        raise RuntimeError("No CUDA device is visible; this script places every network on a GPU")
    device = [f"cuda:{(FLAGS.cuda + i % FLAGS.ngpu) % n_visible_gpus}" for i in range(client_num)]

    # loss function
    loss_fn = nn.MSELoss()

    # data load: folder 0 holds the public/test split, folders 1..client_num the clients
    public_x, public_y = torch.load(FLAGS.data + '/0/train.pt')
    test_x, test_y = torch.load(FLAGS.data + '/0/test.pt')
    train_x = []
    train_y = []
    for i in range(1, client_num + 1):
        x_, y_ = torch.load(FLAGS.data + f'/{i}/train.pt')
        train_x.append(x_)
        train_y.append(y_)

    train_x_agg = torch.cat(train_x)
    train_y_agg = torch.cat(train_y)

    data_num_ratio = np.array([len(train_xp) for train_xp in train_x])
    data_num_ratio = data_num_ratio / np.sum(data_num_ratio)

    # network construction
    default_nn_types = {
        "toy3": ["FNN4_32", "FNN4_64", "FNN5_32", "FNN3_64"],
        "energy": ["FNN_ENERGY4_32", "FNN_ENERGY4_64", "FNN_ENERGY5_32", "FNN_ENERGY3_64"],
        "mnist": ["ResNet18_MNIST", "ResNet34_MNIST", "MobileNetv2_MNIST", "ResNet50_MNIST"],
        "utk": ["CNN1_UTK", "CNN2_UTK", "CNN3_UTK", "CNN4_UTK"],
        "imdb": ["ResNet18_IMDB", "ResNet34_IMDB", "MobileNetv2_IMDB", "ResNet50_IMDB"],
    }
    if FLAGS.nn_type is None:
        if data_type not in default_nn_types:
            raise ValueError(f"No default networks for data type '{data_type}'; pass --nn_type explicitly")
        FLAGS.nn_ratio = [0.3, 0.3, 0.2, 0.2]
        FLAGS.nn_type = default_nn_types[data_type]
    else:
        if not isinstance(FLAGS.nn_ratio, list):
            FLAGS.nn_ratio = [FLAGS.nn_ratio]
        if not isinstance(FLAGS.nn_type, list):
            FLAGS.nn_type = [FLAGS.nn_type]

    # Fixed-architecture models: class name in `nns` -> amount added to
    # hidden_layer_num (presumably each model's feature width), which is passed
    # to nns.net_agg.
    fixed_feature_dims = {
        "CNN1_UTK": 64,
        "CNN2_UTK": 64,
        "CNN3_UTK": 64,
        "CNN4_UTK": 64,
        "ResNet18_MNIST": 512,
        "ResNet34_MNIST": 512,
        "ResNet50_MNIST": 2048,
        "MobileNetv2_MNIST": 1280,
        "ResNet18_IMDB": 512,
        "ResNet34_IMDB": 512,
        "ResNet50_IMDB": 2048,
        "MobileNetv2_IMDB": 1280,
    }
    nn_num = [int(client_num * r) for r in FLAGS.nn_ratio]
    nets = []
    hidden_layer_num = 0
    for num, net_type in zip(nn_num, FLAGS.nn_type):
        for _ in range(num):
            if net_type.startswith("FNN_ENERGY"):
                # e.g. "FNN_ENERGY4_32" -> 4 layers, 32 hidden units
                num_layer, hidden_units = map(int, net_type[10:].split("_")[:2])
                nets.append(nns.FNN(num_layer, hidden_units, data = "energy"))
                hidden_layer_num += hidden_units
            elif net_type.startswith("FNN"):
                # e.g. "FNN4_32" -> 4 layers, 32 hidden units
                num_layer, hidden_units = map(int, net_type[3:].split("_")[:2])
                nets.append(nns.FNN(num_layer, hidden_units))
                hidden_layer_num += hidden_units
            elif net_type in fixed_feature_dims:
                nets.append(getattr(nns, net_type)())
                hidden_layer_num += fixed_feature_dims[net_type]
            else:
                # Skipping silently would misalign nets with the saved params below.
                raise ValueError(f"Unknown network type: {net_type}")

    for dr, net, d_ in zip(data_num_ratio, nets, device):
        net.data_ratio = dr
        net.device = d_

    # load pretrained model
    utils.log_msg(logger, "Load Pretrained Neural Networks.." + "\n")
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    params = torch.load(FLAGS.load, map_location = torch.device('cpu'))
    if len(params) != len(nets):
        utils.log_msg(logger, f"Warning: {len(params)} saved models but {len(nets)} constructed networks" + "\n")
    for net, param in zip(nets, params):
        net.load_state_dict(param)
        net.to(net.device)
    net_concat = nns.net_agg(nets, hidden_layer_num)

    # test pretrained model
    with torch.no_grad():
        for net in nets:
            net.eval()
        pretrain_test = []
        dt = TensorDataset(test_x, test_y)
        dtl = DataLoader(dt, batch_size = 500, shuffle = False, drop_last = False)
        for net in nets:
            pretrain_test_ = 0
            for x, y in dtl:
                output = net(x.to(net.device))["output"].reshape(-1)
                pretrain_test_ += loss_fn(output, y.to(net.device)).detach().cpu().numpy() * len(x)
            pretrain_test_ = pretrain_test_ / len(test_x)
            pretrain_test.append(pretrain_test_)
    utils.log_msg(logger, f"Performance (Avg) : MSE {np.array(pretrain_test).mean()} RMSE {(np.array(pretrain_test) ** 0.5).mean()}")

    # effectiveness of individual feature kernel
    o1, o2 = test_kernel(nets)
    utils.log_msg(logger, f"Local Kernels Performance (Avg) : MSE {o1} RMSE {o2}")

    # effectiveness of kernel derived from feature concatenation
    if FLAGS.cat:
        o1, o2 = test_kernel(net_concat)
        utils.log_msg(logger, f"Kernel Concatenation Performance : MSE {o1} RMSE {o2}")

    # heatmap of pairwise CKAs: one linear Gram matrix of test features per net
    kernels = []
    with torch.no_grad():
        for net in nets:
            o = []
            for x, y in dtl:
                o.append(net(x.to(net.device), feature = True)["feature"])
            o = torch.cat(o, dim = 0)
            k = (o @ o.T).to(f"cuda:{FLAGS.cuda}")
            kernels.append(k)

    # Size the matrix from the actual number of nets instead of a hard-coded 50.
    n_kernels = len(kernels)
    cka = [[CKA(kernels[i], kernels[j]).detach().cpu().item() for i in range(n_kernels)] for j in range(n_kernels)]
    cka_arr = np.array(cka)
    utils.log_msg(logger, f"Pairwise CKA :: Mean {cka_arr.mean():.4f} Std {cka_arr.std():.4f} Min {np.min(cka_arr):.4f}")
    cka_df = pd.DataFrame(cka)
    ax = sns.heatmap(cka_df, vmin=0.7, vmax=1, xticklabels=False, yticklabels=False)
    fig = ax.get_figure()
    utils.generate_dir('plt')
    fig.savefig(f'plt/heatmap_{model_name}.pdf', bbox_inches='tight')