
# Acknowledgement: adapted from https://github.com/facebookresearch/InvarianceUnitTests/blob/main/scripts/models.py

from tkinter import Y
import torch
from torch.autograd import grad

from tqdm import tqdm 
import math

import numpy as np
import torch.nn as nn
import torch.optim as optim
from sklearn import linear_model

from .ERM import ERM
from .IGA import IGA


class LinearModel(torch.nn.Module):
    """Bias-free linear map ``x -> x @ W.T`` with Xavier-uniform initialisation."""

    def __init__(self, input_dim, output_dim):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=False)
        torch.nn.init.xavier_uniform_(self.linear.weight)

    @property
    def weight(self):
        """The ``(output_dim, input_dim)`` weight parameter of the linear layer.

        Exposed so callers can read ``model.weight`` exactly as they would on a
        plain ``torch.nn.Linear``; the parameter itself stays registered under
        ``linear.weight`` (state_dict keys are unchanged).
        """
        return self.linear.weight

    def forward(self, x):
        return self.linear(x)



class EIIL(IGA):
    """Environment Inference for Invariant Learning (EIIL) on top of ``IGA``.

    ``split`` pools the given environments and re-partitions them into two
    inferred environments by maximising a gradient penalty with respect to
    soft per-sample assignments; ``train`` then fine-tunes ``self.model`` with
    the average loss plus the inter-environment gradient-variance penalty.
    """

    def __init__(self, input_dim, output_dim, lam, num_clusters=2, type='regression', device="cpu", task='simulation', arch='linear'):
        """Build the learner.

        Args:
            input_dim: number of input features.
            output_dim: number of model outputs.
            lam: weight of the average (ERM) loss in the ``train`` objective.
            num_clusters: stored for reference only; ``split`` always infers
                two environments.
            type: ``'regression'`` (MSE loss) or ``'classification'``
                (BCE-with-logits loss); forwarded to ``IGA``.
            device: device the model and training data are placed on.
            task: task name; selects the optimizer used in ``train``.
            arch: architecture name forwarded to ``IGA``.

        Raises:
            Exception: if ``self.type`` (presumably set by ``IGA`` from
                ``type`` -- confirm) is neither regression nor classification.
        """
        super(EIIL, self).__init__(input_dim=input_dim, output_dim=output_dim, lam=lam, type=type, arch=arch)

        self.device = device
        self.task = task
        self.num_clusters = num_clusters

        # Put the model on the configured device: ``solve`` and ``train`` move
        # their data to ``self.device``, so a CPU-only model would mismatch.
        self.model = LinearModel(input_dim, output_dim)
        if self.device is not None:
            self.model = self.model.to(self.device)

        if self.type == 'regression':
            self.loss = torch.nn.MSELoss()
        elif self.type == 'classification':
            self.loss = torch.nn.BCEWithLogitsLoss()
        else:
            raise Exception('Not Implemented...')

        self.mask = None
        self.lam = lam
        self.erm_model = None
        self.domains = None

    def _model_weight(self):
        """Return the weight parameter of ``self.model``.

        Works for models exposing ``weight`` directly (e.g. ``torch.nn.Linear``)
        and for models keeping it on a ``linear`` submodule (``LinearModel``).

        Raises:
            AttributeError: if neither attribute exists.
        """
        weight = getattr(self.model, 'weight', None)
        if weight is None:
            linear = getattr(self.model, 'linear', None)
            weight = getattr(linear, 'weight', None)
        if weight is None:
            raise AttributeError('model exposes neither `weight` nor `linear.weight`')
        return weight

    def combine_envs(self, envs):
        """Pool a list of ``(X, y)`` environments into a single dataset.

        Returns:
            ``(X, y)`` concatenated along dim 0, with ``X`` reshaped to
            ``(-1, X.shape[1])`` and ``y`` reshaped to a column ``(-1, 1)``.
        """
        X = torch.cat([X_e for X_e, _ in envs], dim=0)
        y = torch.cat([y_e for _, y_e in envs], dim=0)
        return X.reshape(-1, X.shape[1]), y.reshape(-1, 1)

    def split(self, envs, iters=10000):
        """Infer two environments from the pooled data.

        Soft assignments ``sigmoid(env_w)`` are optimised with Adam to maximise
        the mean of two penalties: the squared gradient, w.r.t. a dummy scale
        ``self.w``, of the assignment-weighted per-sample error of a fixed
        predictor ``X @ diag(weight) @ self.w``. Samples whose assignment is
        > 0.5 form env 0, the rest env 1.

        Assumes ``output_dim == 1`` (the weight is flattened into the diagonal
        feature map ``self.phi``).

        Args:
            envs: list of ``(X, y)`` tensors to pool and re-split.
            iters: number of optimisation steps for the assignments.

        Returns:
            list of two ``(X, y)`` tuples; either one may be empty.
        """
        X, y = self.combine_envs(envs)
        print('size of pooled envs: ', str(len(X)))

        # Keep every tensor on the data's device; the model may live elsewhere.
        device = X.device
        env_w = torch.randn(len(y), device=device).requires_grad_()
        w_optimizer = torch.optim.Adam([env_w], lr=0.001)

        # NOTE(review): MSE is used even when type == 'classification' -- confirm intended.
        loss = torch.nn.MSELoss(reduction='none')

        # reshape(-1) rather than squeeze(): with input_dim == 1 squeeze would
        # give a 0-d tensor and torch.diag would fail.
        weight = self._model_weight().detach().clone().to(device)
        self.phi = torch.nn.Parameter(torch.diag(weight.reshape(-1)))
        self.w = torch.ones(X.shape[1], 1, device=device)
        self.w.requires_grad = True

        # The predictor is fixed, so the per-sample error is computed once and
        # its graph is retained across iterations.
        error = loss(X @ self.phi @ self.w, y).squeeze()

        print('learning soft environment assignments')
        with tqdm(total=iters,
                  position=1,
                  bar_format='{desc}',
                  desc='negative penalty: ',
                  ) as desc:
            for _ in tqdm(range(iters)):
                probs = env_w.sigmoid()
                # penalty for env a (soft weights p)
                error_a = (error * probs).mean()
                penalty_a = grad(error_a, self.w, create_graph=True)[0].pow(2).mean()
                # penalty for env b (soft weights 1 - p)
                error_b = (error * (1 - probs)).mean()
                penalty_b = grad(error_b, self.w, create_graph=True)[0].pow(2).mean()
                # negate: minimising this maximises the mean penalty
                npenalty = - torch.stack([penalty_a, penalty_b]).mean()
                desc.set_description('negative penalty: ' + str(npenalty))

                w_optimizer.zero_grad()
                npenalty.backward(retain_graph=True)
                w_optimizer.step()

        with torch.no_grad():
            final_probs = env_w.sigmoid()
        idx0 = final_probs > .5
        idx1 = final_probs <= .5
        envs = [(X[idx0], y[idx0]), (X[idx1], y[idx1])]
        print('size of env 0: ' + str(len(X[idx0])))
        print('size of env 1: ' + str(len(X[idx1])))

        if len(env_w) < 33:
            print('weights: ' + str(env_w.sigmoid()))

        return envs

    def solve(self, data, epochs=1000, lr=1e-3):
        """Infer environments from ``data`` and fine-tune the model on them.

        Args:
            data: list of ``(X, y)`` tensors; the first is used for the
                pre-training error printout, all are pooled by ``split``.
            epochs: forwarded to ``train``.
            lr: forwarded to ``train``.

        Returns:
            self
        """
        X, y = data[0]

        pred = self.model(X.to(self.device))
        # Reshape the target like ``train`` does so MSELoss does not broadcast
        # (n, 1) against (n,) into an (n, n) error matrix.
        print('err of pre-erm: ', self.loss(pred, y.to(self.device).reshape(pred.shape)))

        envs = self.split(data)
        self.train(envs, epochs=epochs, lr=lr)

        return self

    def _make_optimizer(self, lr):
        """Build the fine-tuning optimizer for ``self.task``.

        Adam for ``'mnist'``; otherwise SGD with weight decay 1.0 for
        ``'income'``/``'insurance'`` and 0.0 for every other task.
        """
        if self.task == 'mnist':
            return torch.optim.Adam(self.model.parameters(), lr)
        weight_decay = 1. if self.task in ('income', 'insurance') else 0.
        return torch.optim.SGD([{'params': self.model.parameters(), 'lr': lr}], weight_decay=weight_decay)

    def train(self, envs, epochs=6000, lr=1e-3, renew=False, verbose=False):
        """Fine-tune ``self.model`` on the given environments.

        Each epoch minimises ``lam * mean_e(loss_e) + sum_e ||grad(loss_e) -
        grad(mean_e loss_e)||^2`` with full-batch updates.

        Args:
            envs: list of ``(x, y)`` pairs (tensors or array-likes). Empty
                environments are skipped; they carry no data and would yield
                a NaN loss and distort the penalty.
            epochs: number of optimisation steps.
            lr: learning rate.
            renew: unused; kept for interface compatibility.
            verbose: print the loss terms every 100 epochs.

        Returns:
            ``(None, weight)`` where ``weight`` is the model's weight ``.data``.

        Raises:
            ValueError: if every environment is empty.
        """
        print('envs: ')
        envs_torch = []
        for x, y in envs:
            print(x.shape, y.shape)
            if len(x) == 0:
                print('skipping empty environment')
                continue
            # Convert and move once, instead of on every epoch.
            x_t = torch.as_tensor(x, dtype=torch.float32).to(self.device)
            y_t = torch.as_tensor(y, dtype=torch.float32).to(self.device)
            envs_torch.append((x_t, y_t))
        if not envs_torch:
            raise ValueError('all environments are empty')
        envs = envs_torch

        print('fine-tuning...')

        opt = self._make_optimizer(lr)
        params = list(self.model.parameters())
        num_envs = len(envs)

        erm_curve = []
        penalty_curve = []

        for epoch in range(epochs):
            losses = []
            for x_e, y_e in envs:
                pred = self.model(x_e)
                losses.append(self.loss(pred, y_e.reshape(pred.shape)))

            # Per-environment gradients, kept differentiable for the penalty.
            gradients = [grad(loss_e, params, create_graph=True) for loss_e in losses]

            avg_loss = sum(losses) / num_envs
            avg_gradient = grad(avg_loss, params, create_graph=True)

            # Trace penalty: squared deviation of each environment's gradient
            # from the mean gradient, summed over environments and parameters.
            penalty_value = sum(
                (g_e - g_avg).pow(2).sum()
                for gradient in gradients
                for g_e, g_avg in zip(gradient, avg_gradient)
            )

            opt.zero_grad()
            (self.lam * avg_loss + penalty_value).backward()
            opt.step()

            if verbose and epoch % 100 == 0:
                print('erm error', self.lam * avg_loss.detach().cpu().numpy())
                print('grad penalty', penalty_value.detach().cpu().numpy())

            erm_curve.append(avg_loss.detach().cpu().numpy())
            penalty_curve.append(penalty_value.detach().cpu().numpy())

        print(' fine-tuning: ')
        # With epochs == 0 there is no curve to summarise.
        if erm_curve:
            print('erm: ', erm_curve[0], ' -> ', erm_curve[-1])
            print('penalty: ', penalty_curve[0], ' -> ', penalty_curve[-1])

        return None, self._model_weight().data

    def set_model(self, model):
        """Replace ``self.model`` (moved to ``self.device`` if set) and return self."""
        self.model = model
        if self.device is not None:
            self.model = self.model.to(self.device)
        return self

    def generate_mask(self, the_candidates):
        """Build a binary feature mask from the model's weight magnitudes.

        Absolute weights are normalised to sum to 1 and sorted ascending. For a
        threshold ``the``, the entries whose ascending cumulative mass is at
        least ``1 - the`` are selected. A summary is printed for every
        candidate.

        Args:
            the_candidates: non-empty iterable of thresholds.

        Returns:
            CPU float tensor of 0/1 (one entry per weight) for the *last*
            candidate.

        Raises:
            ValueError: if ``the_candidates`` is empty.
        """
        # Work on CPU so the index tensor and the CPU mask share a device.
        weights = self._model_weight().detach().abs().reshape(-1).cpu()
        weights = weights / weights.sum()
        sorted_w, idx = torch.sort(weights)

        pre_sum = torch.cumsum(sorted_w, dim=0)

        mask = None
        for the in the_candidates:
            important_idx = torch.where(pre_sum >= (1. - the))
            mask = torch.zeros(len(pre_sum))
            mask[idx[important_idx]] = 1.
            print('the: ', the)
            print('select {} from {}'.format(mask.sum(), len(mask)))

        if mask is None:
            raise ValueError('the_candidates must contain at least one threshold')
        return mask


        
    
        