import torch
import random
import numpy as np
import os
import pickle
import warnings
from config import args
from hyperbolic_data_loader import hy_load_data_from_dgl_dataset_class
from euclidean_data_loader import eu_load_data_from_dgl_dataset_class
from utils.get_model_dict import get_model_dict
from utils.other_utils import load_checkpoint
from GKDonMS import graph_knowledge_distillation_on_manifold_structure



# Maps each supported ``args.dataset`` value to its display name. The display
# name is also the sub-directory of ./saved_models/ holding its pre-trained files.
SUPPORTED_DATASETS = {
    'airport': "Airport",
    'disease': "Disease",
    'cora': "Cora",
    'citeseer': "Citeseer",
    'wikics': "WikiCS",
    'amz': "AmzCBC",
}

# (model_dict key, checkpoint file name, human-readable label) for every
# pre-trained model restored before distillation, in load order.
_PRETRAINED_CHECKPOINTS = (
    ('et_model', "et_model.pt", "Euclidean teacher model"),
    ('ht_model', "ht_model.pt", "hyperbolic teacher model"),
    ('GEO', "GEO.pt", "GEO model"),
)


def _load_pretrained_models(model_dict, model_dir, device):
    """Restore each checkpoint in ``_PRETRAINED_CHECKPOINTS`` into ``model_dict``.

    Returns True when every file was found and loaded. Returns False, after
    printing a hint, at the first missing file. Checkpoints listed before that
    one have already been loaded at that point.
    """
    for key, file_name, label in _PRETRAINED_CHECKPOINTS:
        path = os.path.join(model_dir, file_name)
        if not os.path.isfile(path):
            print(f"Please put {label} in right path")
            return False
        print("\033[33m", f"Load {label}", "\033[0m")
        load_checkpoint(model_dict[key]['model'], path, device)
    return True


def main(args):
    """Run graph knowledge distillation on manifold structure for ``args.dataset``.

    Side effects on ``args``: sets ``device``, ``n_classes``, ``n_nodes`` and
    ``feat_dim`` (the last three only when the dataset is supported).

    Returns None. Prints a message and returns early when the dataset is
    unsupported or a required pre-trained file is missing under
    ./saved_models/<DatasetName>/.
    """
    # NOTE(review): kept as a module-level global so anything reading
    # ``dataset_name`` from this module keeps working; main itself only
    # needs a local.
    global dataset_name
    args.device = 'cuda:' + str(args.gpu) if int(args.gpu) >= 0 else 'cpu'
    device = args.device

    # Validate first, so an unsupported name fails fast instead of after
    # (or inside) data loading and model construction.
    if args.dataset not in SUPPORTED_DATASETS:
        print("Please select one of the six supported datasets.")
        return
    dataset_name = SUPPORTED_DATASETS[args.dataset]
    model_dir = os.path.join("./saved_models", dataset_name)

    # Data format for hyperbolic models: move every tensor entry to the device.
    # Only existing keys are reassigned, so mutating during iteration is safe.
    hyp_data = hy_load_data_from_dgl_dataset_class(args)
    for key, value in hyp_data.items():
        if torch.is_tensor(value):
            hyp_data[key] = value.to(device)
    # Data format for Euclidean models.
    euc_g = eu_load_data_from_dgl_dataset_class(args).to(device)

    # Assumes node labels are 0..C-1, so max + 1 is the class count.
    args.n_classes = int(torch.max(euc_g.ndata['label']) + 1)
    args.n_nodes, args.feat_dim = hyp_data['features'].shape
    # Creating models according to args.
    model_dict = get_model_dict(args, euc_g)

    if not _load_pretrained_models(model_dict, model_dir, device):
        return

    delta_path = os.path.join(model_dir, "node_delta_hyperbolicity.pkl")
    if not os.path.isfile(delta_path):
        print("Please put node delta hyperbolicity file in right path")
        return
    # SECURITY: pickle.load can execute arbitrary code; only load files that
    # were produced locally or come from a trusted source.
    with open(delta_path, 'rb') as file:
        model_dict['node_delta_hyperbolicity'] = pickle.load(file)

    graph_knowledge_distillation_on_manifold_structure(args, model_dict, hyp_data, euc_g)


if __name__ == '__main__':
    warnings.filterwarnings("ignore")

    # ---- Reproducibility: seed every RNG and force deterministic kernels ----
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # NOTE(review): hash randomization is fixed when the interpreter starts,
    # so this likely only affects child processes — confirm intent.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: stay enabled but disable autotuning and pick deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True

    # cuBLAS workspace setting expected by PyTorch's deterministic mode
    # (see PyTorch reproducibility notes).
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)

    main(args)

