import sys
import os
import argparse
import datetime

import numpy as np

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import time
import copy

import utils.utils as utils
import nns

def assign_devices(client_num, start_gpu, ngpu, gpu_count):
    """Return one torch device string per client.

    Client ``i`` is mapped to GPU ``(start_gpu + i % ngpu) % gpu_count``, i.e.
    clients cycle over ``ngpu`` GPU indices starting at ``start_gpu`` and the
    result wraps around the number of visible GPUs.

    Args:
        client_num: number of clients (length of the returned list).
        start_gpu: index of the first GPU to use.
        ngpu: how many GPUs to cycle over.
        gpu_count: number of GPUs visible to torch.

    Returns:
        A list of ``'cuda:<k>'`` strings, or ``'cpu'`` for every client when
        ``gpu_count`` is 0 (the modulo would otherwise divide by zero).
    """
    if gpu_count <= 0:
        return ['cpu'] * client_num
    return [f'cuda:{int((start_gpu + i % ngpu) % gpu_count)}' for i in range(client_num)]


if __name__ == "__main__":

    # Timestamp (month-day-hour-minute-second) used to name this run's log file.
    time_ = datetime.datetime.now()
    name = f"{time_.month}-{time_.day}-{time_.hour}-{time_.minute}-{time_.second}"

    arg_command = sys.argv[1:]
    parser = argparse.ArgumentParser()
    # general
    parser.add_argument("--cuda", type=int, default=0) # start gpu number to use
    parser.add_argument("--ngpu", type=int, default=1) # nb of gpus
    parser.add_argument("--ver", type=int, default=-1) # version
    parser.add_argument("--load", type=str, default=None) # model load
    parser.add_argument("--save", action="store_true")

    # dataset and settings
    parser.add_argument("--data", type=str, default='datasets/toy3_50_hetero') # dataset
    parser.add_argument("--nn_type", type=str, nargs="+", default=None) # neural network types
    parser.add_argument("--nn_ratio", type=float, nargs="+", default=1.) # neural network ratios

    # configuration for pretraining
    parser.add_argument("--optim_p", type=str, default="SGD") # optimizer for pretrain
    parser.add_argument("--lr_p", type=float, default=1e-2) # learning rate for pretrain
    parser.add_argument("--bs_p", type=int, default=10) # batch size for pretrain

    # configuration for collaborative learning
    parser.add_argument("--epochs", type=int, default=500) # communication rounds for FL
    parser.add_argument("--optim", type=str, default="Adam") # optimizer for FL
    parser.add_argument("--local_epochs", type=int, default=10) # local epochs for FL
    parser.add_argument("--distil_epochs", type=int, default=5) # distillation epochs for FL
    parser.add_argument("--lr", type=float, default=2e-4) # learning rate for FL
    parser.add_argument("--bs", type=int, default=10) # batch size for FL (local)
    parser.add_argument("--bs_public", type=int, default=32) # batch size for FL (public)
    parser.add_argument("--sample_public", type=int, default=500) # sample size of public data for each communication round
    parser.add_argument("--alpha", type=float, default=1) # distillation coefficient

    # Unknown command-line arguments are tolerated and discarded.
    FLAGS, _ = parser.parse_known_args(arg_command)

    # data type and client numbers
    # The second '/'-separated component of --data is read as
    # "<data_type>_<client_num>_...", e.g. 'datasets/toy3_50_hetero' -> ('toy3', 50).
    algorithm = 'FedMD'
    dataset_name = FLAGS.data.split('/')[1]
    data_type = dataset_name.split('_')[0]
    client_num = int(dataset_name.split('_')[1])

    # logger
    utils.generate_dir('logs')
    utils.generate_dir('logs/baselines')
    FLAGS.log_fn = f'logs/baselines/{algorithm}_{data_type}_{name}.txt'
    logger = utils.init_logger(FLAGS.log_fn)
    utils.log_arguments(logger, FLAGS)

    # device and version setting
    # One device string per client; falls back to CPU when no GPU is visible.
    device = assign_devices(client_num, FLAGS.cuda, FLAGS.ngpu, torch.cuda.device_count())
    # --ver defaults to the starting GPU index when left at -1.
    ver = FLAGS.cuda if FLAGS.ver == -1 else FLAGS.ver

    # Mean-squared-error objective used by every training and evaluation loop below.
    loss_fn = nn.MSELoss()

    # Sub-folder 0 of the dataset directory provides the public and test tensors.
    public_x, public_y = torch.load(FLAGS.data + '/0/train.pt')
    test_x, test_y = torch.load(FLAGS.data + '/0/test.pt')

    # Sub-folders 1..client_num provide one (x, y) training pair per client.
    train_x, train_y = [], []
    for client_id in range(1, client_num + 1):
        client_x, client_y = torch.load(FLAGS.data + f'/{client_id}/train.pt')
        train_x += [client_x]
        train_y += [client_y]

    # Concatenation of all clients' training tensors.
    train_x_agg = torch.cat(train_x)
    train_y_agg = torch.cat(train_y)

    # Share of the total number of training samples held by each client.
    client_sizes = np.array(list(map(len, train_x)))
    data_num_ratio = client_sizes / client_sizes.sum()

    # network construction
    if FLAGS.nn_type is None:
        # No architectures requested: use the default four-way mix for the dataset.
        default_nn_types = {
            "toy3": ["FNN4_32", "FNN4_64", "FNN5_32", "FNN3_64"],
            "energy": ["FNN_ENERGY4_32", "FNN_ENERGY4_64", "FNN_ENERGY5_32", "FNN_ENERGY3_64"],
            "mnist": ["ResNet18_MNIST", "ResNet34_MNIST", "MobileNetv2_MNIST", "ResNet50_MNIST"],
            "utk": ["CNN1_UTK", "CNN2_UTK", "CNN3_UTK", "CNN4_UTK"],
            "imdb": ["ResNet18_IMDB", "ResNet34_IMDB", "MobileNetv2_IMDB", "ResNet50_IMDB"],
        }
        if data_type not in default_nn_types:
            # Previously this fell through with nn_type=None and failed later
            # with an unrelated TypeError.
            raise ValueError(f"No default --nn_type for data type '{data_type}'; pass --nn_type explicitly")
        FLAGS.nn_ratio = [0.3, 0.3, 0.2, 0.2]
        FLAGS.nn_type = default_nn_types[data_type]
    else:
        # --nn_ratio keeps its scalar default when not given; normalise to lists.
        if not isinstance(FLAGS.nn_ratio, list):
            FLAGS.nn_ratio = [FLAGS.nn_ratio]
        if not isinstance(FLAGS.nn_type, list):
            FLAGS.nn_type = [FLAGS.nn_type]

    # Number of clients per architecture. The small tolerance keeps products
    # such as 0.58 * 50 == 28.999999999999996 from truncating to 28; genuinely
    # fractional counts are still rounded down as before.
    nn_num = [int(client_num * r + 1e-9) for r in FLAGS.nn_ratio]

    # Value added to hidden_layer_num for each fixed architecture; the
    # constructor in `nns` has the same name as the key.
    fixed_hidden_units = {
        "CNN1_UTK": 64,
        "CNN2_UTK": 64,
        "CNN3_UTK": 64,
        "CNN4_UTK": 64,
        "ResNet18_MNIST": 512,
        "ResNet34_MNIST": 512,
        "ResNet50_MNIST": 2048,
        "MobileNetv2_MNIST": 1280,
        "ResNet18_IMDB": 512,
        "ResNet34_IMDB": 512,
        "ResNet50_IMDB": 2048,
        "MobileNetv2_IMDB": 1280,
    }

    nets = []
    hidden_layer_num = 0  # running sum of the per-model widths
    for num, net in zip(nn_num, FLAGS.nn_type):
        is_fnn = net.startswith("FNN")
        if is_fnn:
            # "FNN<layers>_<units>" or "FNN_ENERGY<layers>_<units>"
            is_energy = net.startswith("FNN_ENERGY")
            prefix = "FNN_ENERGY" if is_energy else "FNN"
            spec = net[len(prefix):].split("_")
            num_layer = int(spec[0])
            hidden_units = int(spec[1])
            fnn_kwargs = {"data": "energy"} if is_energy else {}
        elif net not in fixed_hidden_units:
            # Unknown names used to be skipped silently, leaving fewer models than clients.
            raise ValueError(f"Unknown network type: {net}")

        for _ in range(num):
            if is_fnn:
                nets.append(nns.FNN(num_layer, hidden_units, **fnn_kwargs))
                hidden_layer_num += hidden_units
            else:
                nets.append(getattr(nns, net)())
                hidden_layer_num += fixed_hidden_units[net]

    # Test-set loader shared by all evaluation loops.
    dt = TensorDataset(test_x, test_y)
    dlt = DataLoader(dt, batch_size=1000, shuffle=False, drop_last=False)

    if len(nets) != client_num:
        # The zip below stops at the shortest sequence, so a mismatch leaves
        # clients without a model or models without data_ratio/device.
        utils.log_msg(logger, f"Warning: built {len(nets)} networks for {client_num} clients; check --nn_ratio")

    # Attach each client's data share and device string to its model.
    for dr, net, d_ in zip(data_num_ratio, nets, device):
        net.data_ratio = dr
        net.device = d_

    # pretraining phase
    if FLAGS.load is None:
        utils.log_msg(logger, "Pretrain local models.." + "\n")
        for net in nets:
            net.to(net.device)

        for net, train_xp, train_yp in zip(nets, train_x, train_y):
            # 5-folds cross-validation
            # The client's data is shuffled once into chunks of len/5 samples
            # (remainder dropped); chunks 0..4 each serve once as validation set.
            v_loss_list = []
            param_list = []
            fold_loader = DataLoader(TensorDataset(train_xp, train_yp), batch_size=int(len(train_xp) / 5), shuffle=True, drop_last=True)
            train_x_ = []
            train_y_ = []
            for x, y in fold_loader:
                train_x_.append(x)
                train_y_.append(y)
            for r in range(5):
                train_x_t = torch.cat(train_x_[:r] + train_x_[(r+1):])
                train_y_t = torch.cat(train_y_[:r] + train_y_[(r+1):])
                train_x_v = train_x_[r]
                train_y_v = train_y_[r]

                # Every fold starts from the same (current) weights of `net`.
                net_candidate = copy.deepcopy(net)

                if FLAGS.optim_p == "SGD":
                    optimizer_candidate = torch.optim.SGD(params=net_candidate.parameters(), lr=FLAGS.lr_p, momentum=0.9, weight_decay=5e-4)
                elif FLAGS.optim_p == "Adam":
                    optimizer_candidate = torch.optim.Adam(params=net_candidate.parameters(), lr=FLAGS.lr_p, weight_decay=5e-4)
                else:
                    # Otherwise the optimizer would be undefined (or stale from a previous fold).
                    raise ValueError(f"Unsupported pretraining optimizer: {FLAGS.optim_p}")
                d = TensorDataset(train_x_t, train_y_t)
                dl = DataLoader(d, batch_size=FLAGS.bs_p, shuffle=True, drop_last=True)
                dv = TensorDataset(train_x_v, train_y_v)
                dlv = DataLoader(dv, batch_size=500, shuffle=False, drop_last=False)

                # Start from +inf and from the candidate's initial weights so that
                # `param` is always defined for this fold. A fixed starting
                # threshold left it unset (or carried over from an earlier
                # fold/model) whenever the validation loss never dropped below it.
                best_v_loss = float("inf")
                param = copy.deepcopy(net_candidate.state_dict())
                worse_iter = 0
                # Train epoch by epoch until validation loss fails to improve 10 times in a row.
                while True:
                    net_candidate.train()
                    for x, y in dl:
                        optimizer_candidate.zero_grad()
                        x = x.to(net.device)
                        y = y.to(net.device)
                        output = net_candidate(x)['output']
                        loss = loss_fn(output.reshape(-1), y)
                        loss.backward()
                        optimizer_candidate.step()
                    net_candidate.eval()
                    v_loss = 0
                    with torch.no_grad():
                        for xv, yv in dlv:
                            # Sample-weighted sum of batch losses, computed on CPU.
                            v_loss += len(xv) * loss_fn(net_candidate(xv.to(net.device))['output'].reshape(-1).detach().cpu(), yv).numpy()
                    v_loss = v_loss / len(train_x_v)
                    if v_loss < best_v_loss:
                        param = copy.deepcopy(net_candidate.state_dict())
                        best_v_loss = v_loss
                        worse_iter = 0
                    else:
                        worse_iter += 1
                    if worse_iter >= 10:
                        v_loss_list.append(best_v_loss)
                        param_list.append(param)
                        break
            # Keep the weights of the fold with the lowest validation loss.
            net.load_state_dict(param_list[np.argmin(np.array(v_loss_list))])
        if FLAGS.save:
            utils.generate_dir('model_save')
            saving_dir = 'model_save/' + FLAGS.data.split('/')[1]
            utils.generate_dir(saving_dir)
            torch.save([copy.deepcopy(net.state_dict()) for net in nets], saving_dir + f'/Pretrain_{data_type}_{client_num}_{ver}.pt')
    else:
        utils.log_msg(logger, "Load Neural Networks.." + "\n")
        # NOTE(review): torch.load unpickles the file given via --load; only
        # load checkpoints from a trusted source.
        params = torch.load(FLAGS.load, map_location=torch.device('cpu'))
        for net, param in zip(nets, params):
            net.load_state_dict(param)
            net.to(net.device)

    # Evaluate each pretrained model on the test loader and log the averages.
    pretrain_mse = []
    with torch.no_grad():
        for model in nets:
            model.eval()
            weighted_sum = 0
            for batch_x, batch_y in dlt:
                # Predictions are moved to CPU so the loss is computed against the CPU targets.
                pred = model(batch_x.to(model.device))['output'].reshape(-1).detach().cpu()
                weighted_sum += len(batch_x) * loss_fn(pred, batch_y).numpy()
            # Sample-weighted mean of the batch losses over the whole test set.
            pretrain_mse.append(float(weighted_sum) / len(test_x))
            model.train()
    pretrain_mse = np.array(pretrain_mse)
    utils.log_msg(logger, f"Pretrained Local Model Performance (Avg) : MSE {pretrain_mse.mean()} RMSE {(pretrain_mse ** 0.5).mean()}")

    # Federated Learning Procedure
    def build_optimizer(model):
        """Return a fresh optimizer for `model`, chosen by --optim with learning rate --lr."""
        if FLAGS.optim == "SGD":
            return torch.optim.SGD(params=model.parameters(), lr=FLAGS.lr, momentum=0.9, weight_decay=5e-4)
        if FLAGS.optim == "Adam":
            return torch.optim.Adam(params=model.parameters(), lr=FLAGS.lr, weight_decay=5e-4)
        # Previously an unsupported name left `optimizer` undefined.
        raise ValueError(f"Unsupported optimizer: {FLAGS.optim}")

    utils.log_msg(logger, "Start Federated Learning.." + "\n")
    for ep in range(1, FLAGS.epochs + 1):
        for net in nets:
            net.eval()

        # To compute consensus prediction
        with torch.no_grad():
            # random sample of public data (at most --sample_public items)
            sample_size = min(FLAGS.sample_public, len(public_x))
            idx = torch.randperm(len(public_x))[:sample_size]
            sample_x = public_x[idx]  # indexed once instead of once per chunk and model
            # NOTE(review): the in-place accumulation below assumes public_y has a
            # floating-point dtype -- confirm against the dataset.
            ensemble = torch.zeros_like(public_y[idx])
            for net in nets:
                # Predict in chunks of 1000. Chunk bounds follow the actual sample
                # length, so no empty batch is fed to the model when the public
                # set is smaller than --sample_public.
                ensemble_ = [
                    net(sample_x[start:start + 1000].to(net.device))['output'].reshape(-1).cpu()
                    for start in range(0, len(sample_x), 1000)
                ]
                # Consensus = sum of each model's predictions weighted by its data_ratio.
                ensemble += net.data_ratio * torch.cat(ensemble_, dim=0)
            dp = TensorDataset(sample_x, ensemble)
            dlp = DataLoader(dp, batch_size=FLAGS.bs_public, shuffle=True, drop_last=True)

        # distillation phase: fit every model to the consensus on the sampled public data
        for net in nets:
            optimizer = build_optimizer(net)
            net.train()
            for ep_d in range(FLAGS.distil_epochs):
                for x, y in dlp:
                    optimizer.zero_grad()
                    x = x.to(net.device)
                    y = y.to(net.device)
                    loss = FLAGS.alpha * loss_fn(net(x)["output"].reshape(-1), y)
                    loss.backward()
                    optimizer.step()
            net.eval()

        # test phase (every 10 rounds, after distillation and before local training)
        if ep % 10 == 0:
            MSE_list = []
            for net in nets:
                net.eval()
                mse = 0
                with torch.no_grad():
                    for x, y in dlt:
                        mse += len(x) * loss_fn(net(x.to(net.device))['output'].reshape(-1).detach().cpu(), y).numpy()
                MSE_list.append(float(mse) / len(test_x))
                net.train()
            utils.log_msg(logger, f"Epoch {ep} Test Loss : MSE {np.array(MSE_list).mean()} RMSE {(np.array(MSE_list) ** 0.5).mean()}")

            # NOTE(review): checkpoints are written here regardless of --save,
            # unlike the pretraining checkpoint -- confirm this is intended.
            utils.generate_dir('model_save')
            saving_dir = 'model_save/' + FLAGS.data.split('/')[1]
            utils.generate_dir(saving_dir)
            torch.save([copy.deepcopy(net.state_dict()) for net in nets], saving_dir + f'/{algorithm}_{data_type}_{client_num}_{ver}_{ep}.pt')

        # local training phase: each model trains on its own client's data
        for net, client_x, client_y in zip(nets, train_x, train_y):
            net.to(net.device)
            optimizer = build_optimizer(net)
            d = TensorDataset(client_x, client_y)
            dl = DataLoader(d, batch_size=FLAGS.bs, shuffle=True, drop_last=True)
            net.train()
            for ep_l in range(FLAGS.local_epochs):
                for x, y in dl:
                    optimizer.zero_grad()
                    x = x.to(net.device)
                    y = y.to(net.device)
                    loss = loss_fn(net(x)["output"].reshape(-1), y)
                    loss.backward()
                    optimizer.step()