import os
import time
from tqdm import tqdm, trange
import numpy as np
import torch
import random
import torch.nn.functional as F
import copy
from utils.loader import load_seed, load_device, load_data, load_model_params, load_model_optimizer, load_loss_fn, \
                         load_simple_model_optimizer, load_p_model_optimizer, load_simple_loss_fn
from utils.logger import Logger, set_log, start_log, train_log

import scipy.sparse as sp
from datetime import datetime

class Trainer(object):
    def __init__(self, config):
        """Store the configuration and load everything the trainer needs.

        Args:
            config: Experiment configuration object. It is passed whole to
                ``set_log`` and ``load_data``; its ``seed`` attribute is
                passed to ``load_seed``.
        """
        super().__init__()

        self.config = config

        # Logging locations come first, then the seed, then the device, and
        # only afterwards the data (same call order as before).
        log_folder_name, log_dir = set_log(config)
        self.log_folder_name = log_folder_name
        self.log_dir = log_dir

        self.seed = load_seed(config.seed)
        self.device = load_device()

        # load_data returns six values: features, labels, adjacency and the
        # train/valid/test masks, in that order.
        loaded = load_data(config)
        (self.x, self.y, self.adj,
         self.train_mask, self.valid_mask, self.test_mask) = loaded

    def update(self, inputs, target, idx, model, optimizer):
        """Run one optimisation step on the rows selected by ``idx``.

        The model is called with ``inputs`` only (no adjacency argument), and
        the loss is the negative log-likelihood between ``logits[idx]`` and the
        per-row argmax of ``target[idx]``.

        Args:
            inputs: Tensor passed to ``model``.
            target: 2-D tensor of per-row class scores; the class used in the
                loss is ``argmax(dim=1)`` of each selected row.
            idx: Index tensor selecting the rows that contribute to the loss.
            model: Module to train. NOTE(review): ``F.nll_loss`` expects
                log-probabilities, so the model presumably ends in a
                log-softmax -- confirm.
            optimizer: Optimizer holding ``model``'s parameters.

        Returns:
            float: The loss value of this step.
        """
        # Follow the device the model lives on instead of assuming CUDA, so
        # the step also works on CPU-only hosts. A parameter-less model keeps
        # the historical behaviour of moving everything to CUDA.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

        inputs = inputs.to(device)
        target = target.to(device)
        idx = idx.to(device)

        model.train()
        optimizer.zero_grad()

        logits = model(inputs)
        loss = F.nll_loss(logits[idx], target[idx].argmax(dim=1))

        loss.backward()
        optimizer.step()
        return loss.item()

    def update_soft(self, inputs, adj,target, idx_train, idx_remain = None, model=None, optimizer=None):
        """Run one optimisation step of a graph model.

        The loss is the negative log-likelihood on ``idx_train``; when
        ``idx_remain`` is given, the same loss computed on those rows is added
        to it. In both terms the class used is the per-row argmax of
        ``target``.

        Args:
            inputs: Tensor passed to ``model`` as its first argument.
            adj: Graph structure passed to ``model`` unchanged (it is not
                moved to another device here).
            target: 2-D tensor of per-row class scores.
            idx_train: Index tensor of the rows for the main loss term.
            idx_remain: Optional index tensor of additional rows whose loss is
                added to the main term.
            model: Module to train (required despite the ``None`` default).
            optimizer: Optimizer for ``model`` (required despite the ``None``
                default).

        Returns:
            float: The total loss value of this step.

        Raises:
            ValueError: If ``model`` or ``optimizer`` is not supplied.
        """
        if model is None or optimizer is None:
            raise ValueError('update_soft requires both `model` and `optimizer`')

        # Follow the device the model lives on instead of assuming CUDA, so
        # the step also works on CPU-only hosts. A parameter-less model keeps
        # the historical behaviour of moving everything to CUDA.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

        inputs = inputs.to(device)
        target = target.to(device)
        idx_train = idx_train.to(device)

        model.train()
        optimizer.zero_grad()

        logits = model(inputs, adj)
        loss = F.nll_loss(logits[idx_train], target[idx_train].argmax(dim=1))
        # Identity test: comparing a tensor to None with `!=` only works by
        # accident of the comparison fallback.
        if idx_remain is not None:
            idx_remain = idx_remain.to(device)
            loss = loss + F.nll_loss(logits[idx_remain], target[idx_remain].argmax(dim=1))

        loss.backward()
        optimizer.step()
        return loss.item()
    
    def evaluate(self, inputs, adj, target, idx, model):
        """Compute the accuracy of ``model`` on the rows selected by ``idx``.

        Args:
            inputs: Tensor passed to ``model`` as its first argument.
            adj: Graph structure passed to ``model`` unchanged.
            target: Tensor of class indices; it is compared element-wise with
                the predicted class indices.
            idx: Index tensor selecting the rows to score.
            model: Module to evaluate; it is switched to eval mode and left
                in that mode.

        Returns:
            tuple: ``(accuracy, preds)`` where ``accuracy`` is the fraction of
            correct predictions on ``idx`` as a float, and ``preds`` holds the
            predicted class index for every row of the model output (not only
            those in ``idx``).
        """
        # Follow the device the model lives on instead of assuming CUDA, so
        # evaluation also works on CPU-only hosts. A parameter-less model
        # keeps the historical behaviour of moving everything to CUDA.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

        inputs = inputs.to(device)
        target = target.to(device)
        idx = idx.to(device)

        model.eval()
        # Nothing is back-propagated from here, so skip building the autograd
        # graph for the forward pass.
        with torch.no_grad():
            logits = model(inputs, adj)
            all_preds = torch.max(logits, dim=1)[1]
            preds = torch.max(logits[idx], dim=1)[1]
            node_acc = (preds == target[idx]).float().mean().item()
        return node_acc, all_preds

    def predict(self, inputs, adj, model,tau=1):
        """Return ``exp(model(inputs, adj) / tau)`` without gradient tracking.

        Args:
            inputs: Tensor passed to ``model`` as its first argument.
            adj: Graph structure passed to ``model`` unchanged.
            model: Module to query; it is switched to eval mode and left in
                that mode.
            tau: Temperature dividing the model output before the
                exponential.

        Returns:
            torch.Tensor: The exponentiated, temperature-scaled model output,
            detached from the autograd graph. NOTE(review): if the model
            emits log-probabilities these are probabilities for ``tau == 1``;
            for other temperatures the rows are not renormalised here.
        """
        # Follow the device the model lives on instead of assuming CUDA, so
        # prediction also works on CPU-only hosts. A parameter-less model
        # keeps the historical behaviour of moving the inputs to CUDA.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

        inputs = inputs.to(device)

        model.eval()
        # The result is only used as data, so avoid building the autograd
        # graph at all rather than detaching it afterwards.
        with torch.no_grad():
            logits = model(inputs, adj) / tau
            logits = torch.exp(logits).detach()
        return logits



    def train(self, ts):
        """Pre-train the q-model, then alternate p-model and q-model training.

        The procedure is:

        1. Pre-train ``gnnq`` on the training labels and restore the weights
           (and optimizer state) of the epoch with the best dev accuracy.
        2. For ``config.em.iteration`` rounds: sample labels from ``gnnq``'s
           predictions and train ``gnnp`` on them, then train ``gnnq`` on
           ``gnnp``'s predictions. Training-node labels are always reset to
           the ground truth.

        After each round the best dev accuracy seen so far (over pre-training
        and all rounds) and its paired test accuracy are logged and printed.

        Args:
            ts: Identifier of this run; stored as ``config.exp_name`` and used
                as the stem of the log file name.

        Returns:
            None.
        """
        self.config.exp_name = ts
        self.ckpt = f'{ts}'
        print('\033[91m' + f'{self.ckpt}' + '\033[0m')

        # Prepare model, optimizer, and logger
        self.params = load_model_params(self.config)
        # The schedulers are returned by the loaders but are not stepped here.
        gnnq, q_optimizer, scheduler_q = load_simple_model_optimizer(self.params, self.config.train, self.device, True)
        # The p-model is fed label vectors (see inputs_p below), so its
        # 'nfeat' is set to the label count before it is built.
        self.params['nfeat'] = self.params['nlabel']
        gnnp, p_optimizer, scheduler_p = load_p_model_optimizer(self.params, self.config.train, self.device, True)

        logger = Logger(str(os.path.join(self.log_dir, f'{self.ckpt}.log')), mode='a')
        logger.log(f'{self.ckpt}', verbose=False)
        start_log(logger, self.config)
        train_log(logger, self.config)

        # Prepare data
        inputs, target, edges = self.x, self.y, self.adj
        # Labels arrive as per-row scores; keep class indices only.
        target = torch.argmax(target, dim=1).cuda()
        # Plain int so it can be used directly as a tensor dimension.
        nlabel = int(torch.max(target).item()) + 1
        # Masks -> index tensors (element-wise comparison kept as in the
        # original so non-bool masks behave the same).
        idx_train = torch.where(self.train_mask == True)[0]
        idx_dev = torch.where(self.valid_mask == True)[0]
        idx_test = torch.where(self.test_mask == True)[0]
        idx_remain = torch.cat([idx_dev, idx_test], dim=0)

        adj = edges.cuda()
        num_nodes, num_feats = inputs.shape[0], inputs.shape[1]
        # Persistent buffers that the closures below refill in place.
        inputs_q = torch.zeros((num_nodes, num_feats)).cuda()
        target_q = torch.zeros((num_nodes, nlabel)).cuda()
        inputs_p = torch.zeros((num_nodes, nlabel)).cuda()
        target_p = torch.zeros((num_nodes, nlabel)).cuda()

        # One-hot ground truth of the training nodes. It never changes, so it
        # is built once and copied into the buffers wherever needed.
        train_onehot = torch.zeros(idx_train.size(0), nlabel).type_as(target_q)
        train_onehot.scatter_(1, torch.unsqueeze(target[idx_train], 1), 1.0)

        def init_q_data():
            """Fill the q buffers with raw features and training labels."""
            inputs_q.copy_(inputs)
            target_q[idx_train] = train_onehot

        def update_p_data():
            """Refill the p buffers with labels sampled from gnnq."""
            preds = self.predict(inputs_q, adj, gnnq, self.config.em.temp)

            # One sampled class per node, written one-hot to input and target.
            idx_lb = torch.multinomial(preds, 1).squeeze(1)
            inputs_p.zero_().scatter_(1, torch.unsqueeze(idx_lb, 1), 1.0)
            target_p.zero_().scatter_(1, torch.unsqueeze(idx_lb, 1), 1.0)

            # Training nodes always use their ground-truth label.
            inputs_p[idx_train] = train_onehot
            target_p[idx_train] = train_onehot

        def update_q_data():
            """Refill the q targets with gnnp's predictions."""
            preds = self.predict(inputs_p, adj, gnnp)
            target_q.copy_(preds)
            target_q[idx_train] = train_onehot

        def pre_train(epoches):
            """Train gnnq on the training nodes and restore its best epoch."""
            best = 0.0
            # Stays None when no evaluated epoch beats `best` (for example
            # with fewer than two epochs); nothing is restored in that case.
            state = None
            init_q_data()
            results = []
            for epoch in range(epoches):
                self.update_soft(inputs_q, adj, target_q, idx_train, model=gnnq, optimizer=q_optimizer)
                if epoch > 0:
                    accuracy_dev, _ = self.evaluate(inputs_q, adj, target, idx_dev, model=gnnq)
                    accuracy_test, _ = self.evaluate(inputs_q, adj, target, idx_test, model=gnnq)
                    results.append((accuracy_dev, accuracy_test))
                    if accuracy_dev > best:
                        best = accuracy_dev
                        state = {
                            'model': copy.deepcopy(gnnq.state_dict()),
                            'optim': copy.deepcopy(q_optimizer.state_dict()),
                        }
            if state is not None:
                gnnq.load_state_dict(state['model'])
                q_optimizer.load_state_dict(state['optim'])
            return results

        def train_p(epoches):
            """Train gnnp on labels sampled from gnnq."""
            results = []
            update_p_data()
            for epoch in range(epoches):
                self.update_soft(inputs_p, adj, target_p, idx_train, idx_remain, model=gnnp, optimizer=p_optimizer)
                if epoch > 0:
                    accuracy_dev, _ = self.evaluate(inputs_p, adj, target, idx_dev, model=gnnp)
                    accuracy_test, _ = self.evaluate(inputs_p, adj, target, idx_test, model=gnnp)
                    results.append((accuracy_dev, accuracy_test))

            return results

        def train_q(epoches):
            """Train gnnq on gnnp's predictions."""
            results = []
            update_q_data()
            for epoch in range(epoches):
                self.update_soft(inputs_q, adj, target_q, idx_train, idx_remain, model=gnnq, optimizer=q_optimizer)
                if epoch > 0:
                    accuracy_dev, _ = self.evaluate(inputs_q, adj, target, idx_dev, model=gnnq)
                    accuracy_test, _ = self.evaluate(inputs_q, adj, target, idx_test, model=gnnq)
                    results.append((accuracy_dev, accuracy_test))

            return results

        def get_accuracy(results):
            """Return the first-best dev accuracy and its paired test accuracy."""
            best_dev, acc_test = 0.0, 0.0
            for d, t in results:
                if d > best_dev:
                    best_dev, acc_test = d, t
            return best_dev, acc_test

        # Training
        results = []
        results += pre_train(self.config.train.num_epochs)
        for k in range(self.config.em.iteration):
            results += train_p(self.config.train.num_epochs)
            results += train_q(self.config.train.num_epochs)

            # `results` is cumulative, so these are best-so-far values.
            dev_acc, test_acc = get_accuracy(results)
            logger.log(f'Iteration {k+1:02d} | best val: {dev_acc:.3e} | best test: {test_acc:.3e}', verbose=False)
            print(f'[Iteration {k+1:02d}] | best val: {dev_acc:.3e} | best test: {test_acc:.3e}', end = '\r')
