import dgl
import numpy as np
import os
import socket
import time
import random
import glob
import argparse, json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt
# Default to float64 at import time so that dataset pre-processing runs in
# double precision; train_val_pipeline switches the default to torch.float
# (float32) right before it builds the model that is actually trained.
torch.set_default_dtype(torch.float64)
from nets.OGBMOL_graph_classification.load_net import gnn_model
from data.data import LoadData


def gpu_setup(use_gpu, gpu_id):
    """Expose a single GPU to CUDA and return the torch device to run on.

    Args:
        use_gpu: Truthy when the caller wants to run on a GPU.
        gpu_id: Identifier written (as a string) to ``CUDA_VISIBLE_DEVICES``.

    Returns:
        ``torch.device("cuda")`` when CUDA is available and ``use_gpu`` is
        truthy, otherwise ``torch.device("cpu")``.
    """
    env = os.environ
    # Order devices by PCI bus id, then restrict visibility to the chosen GPU.
    env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    cuda_selected = torch.cuda.is_available() and use_gpu
    if not cuda_selected:
        print('cuda not available')
        return torch.device("cpu")

    # Only one device is visible, so it is always index 0.
    print('cuda available with GPU:', torch.cuda.get_device_name(0))
    return torch.device("cuda")


def view_model_param(MODEL_NAME, net_params):
    """Build the model once and report how many scalar parameters it holds.

    Args:
        MODEL_NAME: Model identifier forwarded to ``gnn_model``.
        net_params: Network hyper-parameters forwarded to ``gnn_model``.

    Returns:
        The sum, over all parameter tensors, of the product of their
        dimensions (0 when the model has no parameters).
    """
    net = gnn_model(MODEL_NAME, net_params)
    print("MODEL DETAILS:")
    # Each tensor contributes the product of its shape entries.
    n_entries = sum(np.prod(list(tensor.data.size())) for tensor in net.parameters())
    print('MODEL/Total parameters:', MODEL_NAME, n_entries)
    return n_entries


def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):
    """Train ``MODEL_NAME`` on ``dataset`` and report the test score at the best validation epoch.

    Steps, in order: attach the positional encodings selected by
    ``net_params['pe_init']``; for SAN / GraphiT optionally add full-graph
    connectivity; dump the configuration to ``write_config_file + '.txt'``;
    seed the RNGs; train with Adam + ``ReduceLROnPlateau`` (stepped on the
    validation loss) while logging to TensorBoard and checkpointing every
    epoch (only the two most recent checkpoints are kept); finally write the
    results to ``write_file_name + '.txt'``.

    Training stops early when the learning rate drops below
    ``params['min_lr']``, when ``params['max_time']`` hours have elapsed, or on
    Ctrl+C.

    Args:
        MODEL_NAME: Model identifier forwarded to ``gnn_model``.
        dataset: Dataset object; this function reads ``name``, ``train``,
            ``val``, ``test``, ``evaluator`` and ``collate`` and calls its
            positional-encoding / full-graph helper methods.
        params: Optimisation settings; keys read here are ``seed``,
            ``init_lr``, ``weight_decay``, ``lr_reduce_factor``,
            ``lr_schedule_patience``, ``batch_size``, ``epochs``, ``min_lr``
            and ``max_time``.
        net_params: Network settings; keys read here are ``pe_init``,
            ``pos_enc_dim``, ``device``, ``total_param`` and, for
            SAN / GraphiT, ``full_graph`` (plus ``p_steps`` and ``gamma`` for
            GraphiT).
        dirs: 5-tuple ``(root_log_dir, root_ckpt_dir, write_file_name,
            write_config_file, viz_dir)``.

    Returns:
        The test score recorded at the epoch with the highest validation
        score, or ``None`` when no epoch finished (zero epochs requested or an
        interrupt during the first epoch).
    """
    t0 = time.time()
    per_epoch_time = []
    DATASET_NAME = dataset.name

    # ----- positional encodings -------------------------------------------
    if net_params['pe_init'] == 'lap_pe':
        tt = time.time()
        print("[!] -LapPE: Initializing graph positional encoding with Laplacian PE.")
        dataset._add_lap_positional_encodings(net_params['pos_enc_dim'])
        print("[!] Time taken: ", time.time() - tt)
    elif net_params['pe_init'] == 'rand_walk':
        tt = time.time()
        print("[!] -LSPE: Initializing graph positional encoding with rand walk features.")
        dataset._init_positional_encodings(net_params['pos_enc_dim'], net_params['pe_init'])
        print("[!] Time taken: ", time.time() - tt)
        tt = time.time()
        print("[!] -LSPE (For viz later): Adding lapeigvecs to key 'eigvec' for every graph.")
        dataset._add_eig_vecs(net_params['pos_enc_dim'])
        print("[!] Time taken: ", time.time() - tt)
    elif net_params['pe_init'] == 'map':
        tt = time.time()
        print("[!] -MAP: Initializing graph positional encoding with MAP.")
        dataset._add_map_positional_encodings(net_params['pos_enc_dim'])
        print("[!] Time taken: ", time.time() - tt)
    elif net_params['pe_init'] == 'map_ablation':
        tt = time.time()
        print("[!] -MAP: Initializing graph positional encoding with partial MAP.")
        dataset._map_ablation(net_params['pos_enc_dim'], use_unique_sign=False, use_unique_basis=True, use_eig_val=True)
        print("[!] Time taken: ", time.time() - tt)

    # ----- optional full-graph connectivity (SAN / GraphiT only) -----------
    if MODEL_NAME in ['SAN', 'GraphiT'] and net_params['full_graph']:
        st = time.time()
        print("[!] Adding full graph connectivity..")
        if MODEL_NAME == 'SAN':
            dataset._make_full_graph()
        else:
            dataset._make_full_graph((net_params['p_steps'], net_params['gamma']))
        print('Time taken to add full graph connectivity: ', time.time() - st)

    trainset, valset, testset = dataset.train, dataset.val, dataset.test
    evaluator = dataset.evaluator
    root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir = dirs
    device = net_params['device']

    # Write the network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""".format(
            DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))

    log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
    writer = SummaryWriter(log_dir=log_dir)

    # setting seeds
    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    if device.type == 'cuda':
        torch.cuda.manual_seed(params['seed'])
        torch.cuda.manual_seed_all(params['seed'])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print("Training Graphs: ", len(trainset))
    print("Validation Graphs: ", len(valset))
    print("Test Graphs: ", len(testset))

    torch.set_default_dtype(torch.float)  # pre-process with double, train with float
    model = gnn_model(MODEL_NAME, net_params)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=params['lr_reduce_factor'],
                                                     patience=params['lr_schedule_patience'],
                                                     verbose=True)
    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_accs, epoch_val_accs, epoch_test_accs = [], [], []

    # import train functions for all GNNs
    from train.train_OGBMOL_graph_classification import train_epoch_sparse as train_epoch, \
        evaluate_network_sparse as evaluate_network

    train_loader = DataLoader(trainset, num_workers=4, batch_size=params['batch_size'], shuffle=True,
                              collate_fn=dataset.collate, pin_memory=True)
    val_loader = DataLoader(valset, num_workers=4, batch_size=params['batch_size'], shuffle=False,
                            collate_fn=dataset.collate, pin_memory=True)
    test_loader = DataLoader(testset, num_workers=4, batch_size=params['batch_size'], shuffle=False,
                             collate_fn=dataset.collate, pin_memory=True)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs'])) as t:
            for epoch in t:
                t.set_description('Epoch %d' % epoch)
                start = time.time()
                epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader,
                                                                           epoch, evaluator)
                epoch_val_loss, epoch_val_acc, __ = evaluate_network(model, device, val_loader, epoch, evaluator)
                _, epoch_test_acc, __ = evaluate_network(model, device, test_loader, epoch, evaluator)
                # Drop the third return value right away; it is not needed per epoch.
                del __
                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                epoch_train_accs.append(epoch_train_acc)
                epoch_val_accs.append(epoch_val_acc)
                epoch_test_accs.append(epoch_test_acc)
                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('train/_avg_prec', epoch_train_acc, epoch)
                writer.add_scalar('val/_avg_prec', epoch_val_acc, epoch)
                writer.add_scalar('test/_avg_prec', epoch_test_acc, epoch)
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
                if dataset.name in ["ogbg-moltox21", "ogbg-molhiv", "ogbg-moltoxcast"]:
                    t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
                                  train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                                  train_AUC=epoch_train_acc, val_AUC=epoch_val_acc,
                                  test_AUC=epoch_test_acc)
                elif dataset.name == "ogbg-molpcba":
                    t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
                                  train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                                  train_AP=epoch_train_acc, val_AP=epoch_val_acc,
                                  test_AP=epoch_test_acc)
                per_epoch_time.append(time.time() - start)

                # Saving checkpoint
                ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")
                os.makedirs(ckpt_dir, exist_ok=True)
                torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))
                # Keep only the checkpoints of the current and the previous epoch.
                for file in glob.glob(ckpt_dir + '/*.pkl'):
                    epoch_nb = file.split('_')[-1]
                    epoch_nb = int(epoch_nb.split('.')[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

                scheduler.step(epoch_val_loss)
                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("!! LR EQUAL TO MIN LR SET.")
                    break
                # Stop training after params['max_time'] hours
                if time.time() - t0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time']))
                    break
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    # An interrupt can land between the per-epoch appends, so only trust the
    # epochs for which train, val and test scores were all recorded.
    n_done = min(len(epoch_train_accs), len(epoch_val_accs), len(epoch_test_accs))
    if n_done == 0:
        # Nothing to evaluate or select from: `epoch` may be unbound and the
        # score lists are empty, so bail out cleanly instead of crashing.
        print('No epoch completed; skipping final evaluation and result files.')
        writer.close()
        return None

    _, _, g_outs_test = evaluate_network(model, device, test_loader, epoch, evaluator)

    # OGB: Test scores at best val epoch
    completed_val_accs = epoch_val_accs[:n_done]
    epoch_best = completed_val_accs.index(max(completed_val_accs))
    test_acc = epoch_test_accs[epoch_best]
    train_acc = epoch_train_accs[epoch_best]
    val_acc = epoch_val_accs[epoch_best]
    if dataset.name in ["ogbg-moltox21", "ogbg-molhiv", "ogbg-moltoxcast"]:
        print("Test AUC: {:.4f}".format(test_acc))
        print("Train AUC: {:.4f}".format(train_acc))
        print("Val AUC: {:.4f}".format(val_acc))
    elif dataset.name == "ogbg-molpcba":
        print("Test Avg Precision: {:.4f}".format(test_acc))
        print("Train Avg Precision: {:.4f}".format(train_acc))
        print("Val Avg Precision: {:.4f}".format(val_acc))
    print("Convergence Time (Epochs): {:.4f}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    if net_params['pe_init'] == 'rand_walk' and g_outs_test is not None:
        # Visualize actual and predicted/learned eigenvecs
        from utils.plot_util import plot_graph_eigvec
        os.makedirs(viz_dir, exist_ok=True)
        # NOTE(review): hard-coded sample indices; assumes g_outs_test holds
        # at least 154 graphs -- confirm for small datasets.
        sample_graph_ids = [153, 103, 123]
        for f_idx, graph_id in enumerate(sample_graph_ids):
            # Test graphs
            g_dgl = g_outs_test[graph_id]
            f = plt.figure(f_idx, figsize=(12, 6))
            plt1 = f.add_subplot(121)
            plot_graph_eigvec(plt1, graph_id, g_dgl, feature_key='eigvec', actual_eigvecs=True)
            plt2 = f.add_subplot(122)
            plot_graph_eigvec(plt2, graph_id, g_dgl, feature_key='p', predicted_eigvecs=True)
            f.savefig(viz_dir + '/test ' + str(graph_id) + '.jpg')
            # Release the figure so repeated calls neither leak memory nor
            # draw on top of a previously used figure number.
            plt.close(f)

    writer.close()

    # Write the results in out_dir/results folder
    if dataset.name in ["ogbg-moltox21", "ogbg-molhiv", "ogbg-moltoxcast"]:
        with open(write_file_name + '.txt', 'w') as f:
            f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
        FINAL RESULTS\nTEST AUC: {:.4f}\nTRAIN AUC: {:.4f}\nVAL AUC: {:.4f}\n\n
        Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n""" \
                    .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                            test_acc, train_acc, val_acc, epoch, (time.time() - t0) / 3600, np.mean(per_epoch_time)))
    elif dataset.name == "ogbg-molpcba":
        with open(write_file_name + '.txt', 'w') as f:
            f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
        FINAL RESULTS\nTEST AVG PRECISION: {:.4f}\nTRAIN AVG PRECISION: {:.4f}\nVAL AVG PRECISION: {:.4f}\n\n
        Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n""" \
                    .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                            test_acc, train_acc, val_acc, epoch, (time.time() - t0) / 3600, np.mean(per_epoch_time)))
    return test_acc


def _str_to_bool(text):
    """Return True only for the exact string 'True' (the CLI convention of this script)."""
    return text == 'True'


# (option name, cast) pairs for CLI overrides of config['params'], in the
# order in which they are applied.
_PARAM_CASTS = (
    ('seed', int),
    ('epochs', int),
    ('batch_size', int),
    ('init_lr', float),
    ('lr_reduce_factor', float),
    ('lr_schedule_patience', int),
    ('min_lr', float),
    ('weight_decay', float),
    ('print_epoch_interval', int),
    ('max_time', float),
)

# (option name, cast) pairs for CLI overrides of config['net_params'].
_NET_PARAM_CASTS = (
    ('L', int),
    ('hidden_dim', int),
    ('out_dim', int),
    ('residual', _str_to_bool),
    ('edge_feat', _str_to_bool),
    ('readout', str),
    ('kernel', int),
    ('n_heads', int),
    ('gated', _str_to_bool),
    ('in_feat_dropout', float),
    ('dropout', float),
    ('layer_norm', _str_to_bool),
    ('batch_norm', _str_to_bool),
    ('sage_aggregator', str),
    ('data_mode', str),
    ('num_pool', int),
    ('gnn_per_block', int),
    ('embedding_dim', int),
    ('pool_ratio', float),
    ('linkpred', _str_to_bool),
    ('cat', _str_to_bool),
    ('self_loop', _str_to_bool),
    ('pos_enc_dim', int),
    ('alpha_loss', float),
    ('lambda_loss', float),
    ('pe_init', str),
    ('sign_inv_net', str),
    ('sign_inv_layers', int),
    ('sign_inv_activation', str),
    # Cast like every other *_dim option; it used to be stored as a raw string.
    ('phi_out_dim', int),
)


def _apply_overrides(args, target, casts):
    """Copy every CLI option that was actually supplied into ``target``, applying its cast."""
    for key, cast in casts:
        raw = getattr(args, key)
        if raw is not None:
            target[key] = cast(raw)


def _build_arg_parser():
    """Create the command-line parser; every option except --config is an optional override."""
    parser = argparse.ArgumentParser()
    # The config file is mandatory: everything else only overrides its content.
    parser.add_argument('--config', required=True,
                        help="Please give a config.json file with training/model/data/param details")
    for name in ('gpu_id', 'model', 'dataset', 'out_dir', 'seed', 'epochs', 'batch_size', 'init_lr',
                 'lr_reduce_factor', 'lr_schedule_patience', 'min_lr', 'weight_decay',
                 'print_epoch_interval', 'L', 'hidden_dim', 'out_dim', 'residual', 'edge_feat',
                 'readout', 'kernel', 'n_heads', 'gated', 'in_feat_dropout', 'dropout', 'layer_norm',
                 'batch_norm', 'sage_aggregator', 'data_mode', 'num_pool', 'gnn_per_block',
                 'embedding_dim', 'pool_ratio', 'linkpred', 'cat', 'self_loop', 'max_time',
                 'pos_enc_dim', 'alpha_loss', 'lambda_loss', 'pe_init'):
        parser.add_argument('--' + name, help="Please give a value for " + name)
    parser.add_argument('--sign_inv_net', help="Please give a value for sign inv net")
    parser.add_argument('--sign_inv_layers', help="Please give a value for sign inv layers")
    parser.add_argument('--sign_inv_activation', help="Please give a value for sign inv activation function")
    parser.add_argument('--phi_out_dim', help="Please give a value for phi_out_dim")
    return parser


def main():
    """Parse the CLI, merge it over the JSON config and launch ``train_val_pipeline``.

    The JSON file given by ``--config`` supplies ``gpu``, ``model``,
    ``dataset``, ``out_dir``, ``params`` and ``net_params``; any other option
    passed on the command line overrides the corresponding entry. Passing
    ``--gpu_id`` also forces ``config['gpu']['use']`` to ``True``.

    Output locations (logs, checkpoints, results, configs, viz) are created
    under ``out_dir`` and share one run tag made of model name, dataset name,
    GPU id and a single timestamp.
    """
    args = _build_arg_parser().parse_args()
    with open(args.config) as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])

    # model, dataset, out_dir
    MODEL_NAME = args.model if args.model is not None else config['model']
    DATASET_NAME = args.dataset if args.dataset is not None else config['dataset']
    dataset = LoadData(DATASET_NAME)
    out_dir = args.out_dir if args.out_dir is not None else config['out_dir']

    # parameters
    params = config['params']
    _apply_overrides(args, params, _PARAM_CASTS)

    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    _apply_overrides(args, net_params, _NET_PARAM_CASTS)

    # OGBMOL*
    num_classes = dataset.dataset.num_tasks  # provided by OGB dataset class
    net_params['n_classes'] = num_classes

    if MODEL_NAME == 'PNA':
        # Per-node sums of the (transposed) adjacency rows over all training
        # graphs, used for PNA's average-degree scalers.
        D = torch.cat([torch.sparse.sum(g.adjacency_matrix(transpose=True), dim=-1).to_dense()
                       for g, label in dataset.train])
        net_params['avg_d'] = dict(lin=torch.mean(D),
                                   exp=torch.mean(torch.exp(torch.div(1, D)) - 1),
                                   log=torch.mean(torch.log(D + 1)))

    # Build the run tag once so that all five output paths carry the same
    # timestamp even when a second boundary is crossed while they are built.
    run_tag = (MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_"
               + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y'))
    # os.path.join gives the same paths as before when out_dir ends with a
    # separator and also works when it does not.
    root_log_dir = os.path.join(out_dir, 'logs', run_tag)
    root_ckpt_dir = os.path.join(out_dir, 'checkpoints', run_tag)
    write_file_name = os.path.join(out_dir, 'results', 'result_' + run_tag)
    write_config_file = os.path.join(out_dir, 'configs', 'config_' + run_tag)
    viz_dir = os.path.join(out_dir, 'viz', run_tag)
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir

    # The pipeline opens files directly inside these two folders.
    os.makedirs(os.path.join(out_dir, 'results'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'configs'), exist_ok=True)

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)


# Start the training run only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
