import torch
import sys
import os
import wandb
from datetime import datetime
import torch.optim as optim
import transformers
import pickle
import time
from tqdm import tqdm
import shutil
from pathlib import Path
from utils.wandb_history import get_run_prefix
from datetime import datetime

# Make the directory two levels above this file importable.
# ``append`` is required: ``extend`` would iterate over the path *string* and
# add every single character as a separate ``sys.path`` entry.
sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))


from utils.utils_batching_preproc import obtain_supernode_embeddings
from utils.utils_batching_preproc import create_batch_from_task_list, obtain_supernode_embeddings
from models.gnn_with_edge_attr import BipartiteMsgPassingGNN, BipartiteGAT
from models.multilayer_gnn import MultiLayerGNN, MultiLayerBipartiteGNN
from models.metaGNN import MetaGNN
from models.get_model import get_model, print_num_trainable_params
from models.model_eval_utils import accuracy
from models.general_gnn import SingleLayerGeneralGNN
from models.sentence_embedding import SentenceEmb
from utils.utils_batching_preproc import preprocess_task
from models.gnn_with_edge_attr import GNNWithSupernodePooling
from data.arxiv import get_arxiv_dataloader
from utils.collater import get_dataset_dataloader


class TrainerFS():
    def __init__(self, dataset, parameter):
        """Set up the W&B run, model, optimizer, output directories and dataloaders.

        Args:
            dataset: Forwarded to ``_build_dataloaders`` when the configured
                dataset is ``"arxiv"`` and to ``_build_dataloaders_legacy``
                otherwise.
            parameter: Configuration dict. It is stored by reference and its
                ``'prefix'`` entry is overwritten in place with the
                timestamped run name.

        Note:
            The process exits via ``sys.exit()`` if the logging directory for
            this run already exists.
        """
        # Build the timestamped run name exactly once, so the W&B run name and
        # the on-disk state/log directory names cannot drift apart when the
        # clock ticks between two separate ``datetime.now()`` calls.
        run_name = parameter["prefix"] + "_" + datetime.now().strftime("%d_%m_%Y_%H_%M_%S")

        wandb.init(project="graph-clip", name=run_name)
        print("---------Parameters---------")
        for k, v in parameter.items():
            print(k + ': ' + str(v))
        print("----------------------------")
        wandb.config.trainer_fs = True

        self.parameter = parameter

        self.ignore_label_embeddings = parameter['ignore_label_embeddings']
        self.is_zero_shot = parameter['zero_shot']

        # parameters
        self.batch_size = parameter['batch_size']
        self.learning_rate = parameter['learning_rate']
        self.dataset_len_cap = parameter['dataset_len_cap']
        self.invalidate_cache = parameter['invalidate_cache']
        self.early_stopping_patience = parameter['early_stopping_patience']

        # epoch
        self.epoch = parameter["epochs"]
        self.print_epoch = parameter['print_epoch']
        self.eval_epoch = parameter['eval_epoch']
        self.checkpoint_epoch = parameter['checkpoint_epoch']

        self.dataset_name = parameter['dataset']
        # classification_only = True means that we only train a simple classifier
        # (no text dot product at the end)
        self.classification_only = self.parameter["classification_only"]

        self.shots = parameter['n_shots']  # k shots!

        self.device = parameter['device']

        self.loss = torch.nn.CrossEntropyLoss()

        self.cos = torch.nn.CosineSimilarity(dim=1)

        bert_dim = 768

        self.emb_dim = parameter["emb_dim"]
        self.gnn_type = parameter["gnn_type"]
        self.original_features = parameter["original_features"]

        # From here on 'prefix' is the timestamped run name (mutates the caller's dict).
        self.parameter['prefix'] = run_name
        self.fix_datasets = self.parameter['fix_datasets_first']

        # Background GNN, wrapped so that supernode embeddings are pooled from its output.
        self.gnn_module = get_model(add_to_dim_in=0, emb_dim=self.parameter["emb_dim"],
                                    n_layer=self.parameter["n_layer"],
                                    bert_dim=bert_dim, input_dim=self.parameter["input_dim"],
                                    classification_only=True,
                                    gnn_type=parameter["gnn_type"])

        self.gnn_module = GNNWithSupernodePooling(self.gnn_module, obtain_supernode_embeddings)

        # Second-stage GNN operating on the metagraph; selected by the "second_gnn" option.
        self.second_gnn_type = self.parameter["second_gnn"]
        if self.second_gnn_type == "vanilla":
            self.metagraph_gnn = BipartiteMsgPassingGNN(edge_attr_dim=2, emb_dim=self.emb_dim)
        elif self.second_gnn_type == "Atten":
            self.metagraph_gnn = MetaGNN(edge_attr_dim=2, emb_dim=self.emb_dim,
                                         n_layers=self.parameter["meta_n_layer"])
        elif self.second_gnn_type == "gat":
            self.metagraph_gnn = BipartiteGAT(edge_attr_dim=2, emb_dim=self.emb_dim)
        elif self.second_gnn_type == "gat3":
            # 3-layer GAT
            self.metagraph_gnn = MultiLayerBipartiteGNN(
                module_list=torch.nn.ModuleList(
                    [BipartiteGAT(edge_attr_dim=2, emb_dim=self.emb_dim) for _ in range(3)]),
                transpose_edges_after_each_iter=True)
        else:
            raise NotImplementedError

        lin_final = torch.nn.Linear(self.parameter["emb_dim"], self.emb_dim).to(self.device)

        # MLP mapping BERT-sized label embeddings down to the model embedding size.
        mlp_bert = torch.nn.Sequential(torch.nn.Linear(bert_dim, 2 * bert_dim), torch.nn.ReLU(),
                                       torch.nn.Linear(2 * bert_dim, bert_dim),
                                       torch.nn.Linear(bert_dim, bert_dim),
                                       torch.nn.ReLU(), torch.nn.Linear(bert_dim, self.emb_dim))

        self.model = SingleLayerGeneralGNN(background_gnn=self.gnn_module, metagraph_gnn=self.metagraph_gnn,
                                           deepset_module=torch.nn.Identity(), label_mlp=mlp_bert,
                                           input_mlp=lin_final, params=self.parameter)
        self.model.to(self.device)
        print_num_trainable_params(self.model)

        bert_model_name = self.parameter["bert_emb_model"]
        self.Bert = SentenceEmb(bert_model_name, device=self.device)

        params = list(self.model.parameters())

        # Only parameters that require gradients are handed to the optimizer.
        self.optimizer = optim.AdamW(filter(lambda p: p.requires_grad, params),
                                     lr=self.learning_rate, weight_decay=self.parameter["weight_decay"])

        # Linear schedule with 0 warmup steps over ``self.epoch`` scheduler steps.
        self.scheduler = transformers.get_linear_schedule_with_warmup(self.optimizer, 0, self.epoch)

        wandb.config.params = parameter
        wandb.run.log_code(".")
        wandb.watch(self.model, log_freq=100)

        # Directories for the best state dict and for periodic checkpoints.
        self.state_dir = os.path.join(self.parameter['state_dir'], self.parameter['prefix'])
        os.makedirs(self.state_dir, exist_ok=True)
        self.ckpt_dir = os.path.join(self.state_dir, 'checkpoint')
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.state_dict_file = ''

        # logging
        self.logging_dir = os.path.join(self.parameter['log_dir'], self.parameter['prefix'], 'data')
        self.cache_dir = os.path.join(self.parameter['log_dir'], "cache")
        os.makedirs(self.cache_dir, exist_ok=True)

        # Refuse to reuse the logging directory of an earlier run.
        # NOTE(review): ``sys.exit()`` without an argument exits with status 0 — confirm
        # that is intended for this failure case.
        if os.path.isdir(self.logging_dir):
            print(self.logging_dir, "already exists!!!")
            sys.exit()
        os.makedirs(self.logging_dir)

        self.all_saveable_modules = {
            "model": self.model
        }
        # Optionally warm-start from the best state dict of an earlier run.
        self.pretrained_model_run = self.parameter["pretrained_model_run"]
        if self.pretrained_model_run != "":
            self.reload(self.pretrained_model_run)

        # Data loader creation.
        if self.dataset_name == "arxiv":
            self.train_dataloader, self.val_dataloader, self.test_dataloader = self._build_dataloaders(
                dataset, self.dataset_name)
        else:
            self.train_dataloader, self.val_dataloader, self.test_dataloader = self._build_dataloaders_legacy(
                dataset)

    def _build_dataloaders(self, dataset, dataset_name):
        """Create the train/val/test dataloaders for ``dataset_name``.

        Validation and test loaders use the fixed ``n_way``/``n_shots``/``n_query``
        values and 20 batches each. For the train loader each of those three
        values is replaced by ``range(value, upper + 1)`` whenever the matching
        ``*_upper`` option is positive, and the batch count is ``dataset_len_cap``.

        Args:
            dataset: Passed through unchanged to the dataloader factory.
            dataset_name: Only ``"arxiv"`` is supported.

        Returns:
            tuple: ``(train_dataloader, val_dataloader, test_dataloader)``.

        Raises:
            NotImplementedError: If ``dataset_name`` is not ``"arxiv"``.
        """
        cfg = self.parameter
        loader_args = {
            "root": os.path.join(cfg["root"], "arxiv"),
            "num_workers": cfg["workers"],
            "batch_size": cfg["batch_size"],
            "n_way": cfg["n_way"],
            "n_shot": cfg["n_shots"],
            "n_query": cfg["n_query"],
            "bert": self.Bert,
        }

        if dataset_name != "arxiv":
            raise NotImplementedError
        make_loader = get_arxiv_dataloader

        # Evaluation loaders are built first, with the fixed task sizes.
        val_loader = make_loader(dataset, split="val", batch_count=20, **loader_args)
        test_loader = make_loader(dataset, split="test", batch_count=20, **loader_args)

        # Train only: widen the task sizes to ranges where an upper bound is configured.
        upper_bound_keys = (
            ("n_way", "n_way_upper"),
            ("n_shot", "n_shots_upper"),
            ("n_query", "n_query_upper"),
        )
        for arg_name, upper_key in upper_bound_keys:
            if cfg[upper_key] > 0:
                loader_args[arg_name] = range(loader_args[arg_name], cfg[upper_key] + 1)

        train_loader = make_loader(dataset, split="train", batch_count=cfg["dataset_len_cap"], **loader_args)

        return train_loader, val_loader, test_loader

    def _build_dataloaders_legacy(self, datasets):
        """Preprocess (or load cached) few-shot splits and wrap them in dataloaders.

        Each split is run through ``preprocess_task`` task by task, tasks for
        which it returns ``None`` are dropped, and the result is pickled into
        ``self.cache_dir``. An existing cache file is reused unless
        ``self.invalidate_cache`` is set. Every split is then truncated to
        ``self.dataset_len_cap`` entries.

        Args:
            datasets: Indexable whose first element unpacks into
                ``(train, valid, test)`` datasets supporting ``len()`` and
                integer indexing.

        Returns:
            tuple: ``(train_dataloader, valid_dataloader, test_dataloader)``.
        """
        # For fewshot, a bit different preprocessing...
        # Fixing and preprocessing the training, testing and validation set.
        print(f"Fewshot preprocessing and Dataloader creation for {self.dataset_name}")

        bert_name = self.parameter["bert_emb_model"]
        if self.original_features:
            bert_name = "original-features"
        shot_text = self.shots

        def preprocess_dataset(dataset, split):
            # The cache key covers split, dataset, embedding model, dataset size and shot count.
            path = os.path.join(self.cache_dir, f"{split}_{self.dataset_name}_{bert_name}_{len(dataset)}_fewshot{shot_text}.pkl")

            if os.path.exists(path) and not self.invalidate_cache:
                # NOTE: unpickling executes arbitrary code; only point the cache
                # directory at files this program wrote itself.
                with open(path, "rb") as cache_file:  # closes the handle deterministically
                    processed_dataset = pickle.load(cache_file)
                print(f"Loaded {split} set with preprocessed embeddings from", path)
            else:
                print(f"Preprocessing {split} set for {self.dataset_name} with {bert_name} and {len(dataset)} examples")
                print("Will save to", path)
                t1 = time.time()
                processed_dataset = [preprocess_task(dataset[i], self.Bert) for i in tqdm(range(len(dataset)))]
                processed_dataset = [x for x in processed_dataset if x is not None]
                t2 = time.time()
                wandb.log({f"time_{split}_set": t2 - t1})
                with open(path, "wb") as cache_file:  # flushed and closed before reporting success
                    pickle.dump(processed_dataset, cache_file)
                print(f"Saved {split} dataset for {self.dataset_name} to {path}")

            return processed_dataset

        train_ds, valid_ds, test_ds = datasets[0]
        train_set = preprocess_dataset(train_ds, "train")
        valid_set = preprocess_dataset(valid_ds, "valid")
        test_set = preprocess_dataset(test_ds, "test")

        # Cap the size of the dataset according to the argument.
        train_set = train_set[:self.dataset_len_cap]
        valid_set = valid_set[:self.dataset_len_cap]
        test_set = test_set[:self.dataset_len_cap]

        # Batches are collated for the CPU device here.
        train_dataloader = get_dataset_dataloader(train_set, self.batch_size, torch.device("cpu"), self.Bert)
        valid_dataloader = get_dataset_dataloader(valid_set, self.batch_size, torch.device("cpu"), self.Bert)
        test_dataloader = get_dataset_dataloader(test_set, self.batch_size, torch.device("cpu"), self.Bert)

        return train_dataloader, valid_dataloader, test_dataloader

    def reload(self, wandb_run_path):
        """Load the model weights stored for an earlier W&B run.

        The file is looked up at ``<state_dir>/<get_run_prefix(run)>/state_dict``
        and its ``'model'`` entry is loaded into ``self.model``.

        Args:
            wandb_run_path: Run identifier handed to ``get_run_prefix``.

        Raises:
            RuntimeError: If no state dict file exists at the expected location.
        """
        print("Reload state dict from wandb run", wandb_run_path)
        run_dir = os.path.join(self.parameter['state_dir'], get_run_prefix(wandb_run_path))
        model_filename = os.path.join(run_dir, "state_dict")

        if not os.path.isfile(model_filename):
            raise RuntimeError('No state dict in {}!'.format(model_filename))

        saved_modules = torch.load(model_filename, map_location=self.device)
        self.model.load_state_dict(saved_modules['model'])
        print("---> Loaded state dict from ", model_filename)

    def move_to_device(self, bt_response):
        """Return the first six entries of ``bt_response`` moved to ``self.device``.

        Args:
            bt_response: Indexable with at least six items that each provide
                a ``.to(device)`` method.

        Returns:
            tuple: The six results of ``item.to(self.device)``, in order.
        """
        return tuple(bt_response[position].to(self.device) for position in range(6))

    def get_loss_and_acc(self, y_true_matrix, y_pred_matrix):
        """Compute the symmetric cross-entropy loss and the accuracy for a batch.

        The loss is the mean of two cross-entropy terms: one over the matrices
        as given, and one over the same matrices regrouped per task and
        transposed so that the class axis and the query axis swap roles.

        Args:
            y_true_matrix: 2-D tensor; ``shape[1]`` is taken as the number of
                classes and ``shape[0]`` must be divisible by ``self.batch_size``.
            y_pred_matrix: Predictions with the same shape as ``y_true_matrix``.

        Returns:
            tuple: ``(loss, acc)`` where ``acc`` is element ``[2]`` of the
            result of ``accuracy(y_true_matrix, y_pred_matrix)``.
        """
        n_rows, n_classes = y_true_matrix.shape[0], y_true_matrix.shape[1]
        # Assumes every task in the batch has the same number of query examples;
        # this does not hold when that number varies between tasks (e.g. fs-mol).
        queries_per_task = n_rows // self.batch_size

        def swap_query_and_class_axes(matrix):
            # (tasks * queries, classes) -> (tasks, queries, classes)
            #   -> (tasks, classes, queries) -> (tasks * classes, queries)
            per_task = matrix.reshape(self.batch_size, queries_per_task, n_classes)
            return per_task.transpose(1, 2).reshape(self.batch_size * n_classes, queries_per_task)

        query_wise_loss = self.loss(y_pred_matrix, y_true_matrix)

        # Transpose within each batch
        class_wise_loss = self.loss(swap_query_and_class_axes(y_pred_matrix),
                                    swap_query_and_class_axes(y_true_matrix))

        return (query_wise_loss + class_wise_loss) / 2, accuracy(y_true_matrix, y_pred_matrix)[2]

    def get_y_matrix(self, batch_res):
        """Run the model on one batch and return the label and prediction matrices.

        Args:
            batch_res: Sequence of exactly six items, unpacked as ``(graph,
                task_embeddings, y_true_matrix, metagraph_edges,
                metagraph_edge_attr, query_set_mask)``.

        Returns:
            tuple: ``(y_true_matrix, y_pred_matrix)`` as returned by the model.
        """
        graph, task_embeddings, y_true_matrix, metagraph_edges, metagraph_edge_attr, query_set_mask = batch_res
        # Call the module itself rather than ``.forward`` so that any registered
        # forward hooks are executed.
        y_true_matrix, y_pred_matrix = self.model(graph, task_embeddings, y_true_matrix, metagraph_edges,
                                                  metagraph_edge_attr, query_set_mask)
        return y_true_matrix, y_pred_matrix

    def save_checkpoint(self, epoch):
        state_dict = {key: value.state_dict() for key, value in self.all_saveable_modules.items()}
        torch.save(state_dict, os.path.join(self.ckpt_dir, 'state_dict_' + str(epoch) + '.ckpt'))

    def del_checkpoint(self, epoch):
        path = os.path.join(self.ckpt_dir, 'state_dict_' + str(epoch) + '.ckpt')
        if os.path.exists(path):
            os.remove(path)
        else:
            raise RuntimeError('No such checkpoint to delete: {}'.format(path))

    def save_best_state_dict(self, best_epoch):
        new_state_dict_path = os.path.join(self.state_dir, 'state_dict')
        shutil.copy(os.path.join(self.ckpt_dir, 'state_dict_' + str(best_epoch) + '.ckpt'),
                    os.path.join(self.state_dir, 'state_dict'))
        print("Saved best model to {}".format(new_state_dict_path))
        self.best_state_dict_path = new_state_dict_path

    def train(self):
        """Run the training loop with periodic validation, checkpointing and early stopping.

        Each epoch performs one pass of ``do_one_step`` over the train
        dataloader followed by a scheduler step. Every ``checkpoint_epoch``
        epochs a checkpoint is written; every ``eval_epoch`` epochs (epoch 0
        excluded in both cases) the model is evaluated on the validation and
        test dataloaders. Training stops early once validation accuracy has
        failed to improve ``early_stopping_patience`` evaluations in a row.

        Returns:
            tuple: ``(best_value, test_accuracy_on_best, best_epoch)``, i.e. the
            best validation accuracy, the test accuracy measured at that epoch
            and the epoch itself. All three stay ``0`` if no validation ran.
        """
        # initialization
        best_epoch = 0
        best_value = 0
        test_accuracy_on_best = 0
        bad_counts = 0
        # Whether a checkpoint file exists for ``best_epoch``.
        best_checkpoint_saved = False

        # training by epoch
        t_load, t_one_step = 0, 0
        pbar = tqdm(range(self.epoch))

        with torch.no_grad():
            test_loss, test_acc = self.do_one_step(self.test_dataloader, iseval=True)
            wandb.log({"start_test_acc": test_acc})  # Test accuracy before training (if using e.g. a pretrained model etc.)

        for e in pbar:
            self.gnn_module.train()
            t1 = time.time()
            t2 = time.time()
            loss, acc = self.do_one_step(self.train_dataloader, iseval=False)
            t3 = time.time()
            wandb.log({"step_time": t3 - t2}, step=e)
            wandb.log({"train_loss": loss, "train_acc": acc}, step=e)  # loss and acc here are both floats
            t_load += t2 - t1
            t_one_step += t3 - t2
            pbar.set_description("load: %s, step: %s" % (t_load / (e + 1), t_one_step / (e + 1)))
            self.scheduler.step()
            # print the loss on specific epoch
            if e % self.print_epoch == 0:
                print("\n Loss:", loss)
            # save checkpoint on specific epoch
            checkpoint_saved_this_epoch = False
            if e % self.checkpoint_epoch == 0 and e != 0:
                print('Epoch  {} has finished, saving...'.format(e))
                self.save_checkpoint(e)
                checkpoint_saved_this_epoch = True

            if e % self.eval_epoch == 0 and e != 0:
                all_valid_scores = []
                with torch.no_grad():
                    val_loss, val_acc = self.do_one_step(self.val_dataloader, iseval=True)
                    all_valid_scores.append([val_loss, val_acc])
                valid_scores = torch.tensor(all_valid_scores)
                mean_scores = torch.mean(valid_scores, dim=0)
                mean_loss = mean_scores[0]
                mean_acc = mean_scores[1]
                if mean_acc >= best_value:
                    best_value = mean_acc
                    test_accuracy_on_best = self.do_one_step(self.test_dataloader, iseval=True)[1]
                    best_epoch = e
                    bad_counts = 0
                    # The checkpoint and evaluation cadences are independent, so make
                    # sure the best epoch always has a checkpoint to copy at the end.
                    if not checkpoint_saved_this_epoch:
                        self.save_checkpoint(e)
                    best_checkpoint_saved = True
                else:
                    # Count this evaluation before reporting, so the message is not one behind.
                    bad_counts += 1
                    print("Validation accuracy did not improve now for {} validation checkpoints".format(bad_counts))
                    if bad_counts >= self.early_stopping_patience:
                        print("Early stopping at epoch {}".format(e))
                        break

                wandb.log({"mean_valid_loss": mean_loss, "mean_valid_acc": mean_acc, "all_valid_obj": all_valid_scores},
                          step=e)
                # Also evaluate on test set
                with torch.no_grad():
                    test_loss, test_acc = self.do_one_step(self.test_dataloader, iseval=True)
                    wandb.log({"test_acc": test_acc, "test_loss": test_loss}, step=e)
        print('Training has finished')
        print('\tBest epoch is {0} | {1} of valid set is {2:.3f}'.format(best_epoch, "acc", best_value))
        print("Testing accuracy is", test_accuracy_on_best)
        wandb.run.summary["best_epoch"] = best_epoch
        wandb.run.summary["final_validation_acc"] = best_value
        wandb.run.summary["final_test_acc"] = test_accuracy_on_best

        if best_checkpoint_saved:
            self.save_best_state_dict(best_epoch)
        else:
            # No validation round ran, so there is no "best" checkpoint to export.
            print("No validation was performed; skipping export of the best state dict")

        print('Finish')
        wandb.finish()
        # returns best-val-acc, best-test-acc, best-epoch
        return best_value, test_accuracy_on_best, best_epoch

    def do_one_step(self, dataloader, iseval=False, eff_len=None):
        """Run one full pass over ``dataloader`` and return its overall loss and accuracy.

        In training mode (``iseval=False``) an optimizer step is taken after
        every batch. In evaluation mode gradient tracking is disabled for the
        duration of the pass and no per-batch loss is computed. The returned
        metrics are computed once over the outputs of all batches concatenated
        along dimension 0.

        Args:
            dataloader: Iterable of batches; every element of a batch must
                provide ``.to(device)``.
            iseval: If True, evaluate only (no backward pass, no optimizer step).
            eff_len: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(loss, acc)`` with ``loss`` as a Python float and ``acc``
            as returned by ``get_loss_and_acc``.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        if not iseval:
            self.optimizer.zero_grad()

        true_chunks, pred_chunks = [], []
        # Used as a context manager, ``set_grad_enabled`` restores the caller's
        # grad mode on exit (also on exceptions) instead of forcing it to True.
        # Training keeps whatever mode the caller had.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not iseval):
            for batch in tqdm(dataloader, leave=False):
                batch = [i.to(self.device) for i in batch]
                yt, yp = self.get_y_matrix(batch)  # apply the model
                if not iseval:
                    loss, _ = self.get_loss_and_acc(yt, yp)  # get loss
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                # Detach so the accumulated outputs do not stay attached to autograd.
                true_chunks.append(yt.detach())
                pred_chunks.append(yp.detach())

        if not true_chunks:
            raise ValueError("dataloader yielded no batches")

        # Concatenate once instead of re-concatenating after every batch.
        ytrueall = torch.cat(true_chunks, dim=0)
        ypredall = torch.cat(pred_chunks, dim=0)
        # NOTE(review): ``get_loss_and_acc`` regroups rows using ``self.batch_size``;
        # confirm that this is the intended grouping for the concatenated outputs
        # of several batches.
        with torch.no_grad():
            loss_global, acc_global = self.get_loss_and_acc(ytrueall, ypredall)
        return loss_global.item(), acc_global


