


import os
import sys
import numpy as np
import torch
import math
import argparse
import pickle
import logging

from sklearn.linear_model import LinearRegression, Ridge
from torch._C import device

from utils.tools import combine_envs


class ERM():
    """Empirical-risk-minimisation baseline around a small torch model.

    The wrapped model is either a bias-free ``torch.nn.Linear`` or the
    ``MLP`` defined in this file.  The loss is ``MSELoss`` for regression and
    ``BCEWithLogitsLoss`` for classification.  Training is full-batch over
    the pooled environments.

    Attributes set by ``__init__``: ``type``, ``loss``, ``task``, ``out_dim``,
    ``model``, ``arch``, ``mask`` (initially ``None``) and ``device``.
    """

    def __init__(self, input_dim, output_dim, type, arch='linear', task='simulation', device="cpu"):
        """Build the model and pick the loss.

        Args:
            input_dim: Number of input features (used by the linear arch).
            output_dim: Number of model outputs (used by the linear arch and
                by ``score`` to reshape predictions).
            type: ``'regression'`` or ``'classification'``.
            arch: ``'linear'`` or ``'mlp'``.
            task: Name selecting the optimiser in ``train``; one of
                ``'simulation'``, ``'house'``, ``'income'``, ``'insurance'``.
            device: Torch device the model and data are moved to.

        Raises:
            Exception: If ``type`` or ``arch`` is not one of the supported
                choices.
        """
        if type not in ('regression', 'classification'):
            raise Exception('choices of types: [regression, classification]')
        self.type = type

        if type == 'regression':
            self.loss = torch.nn.MSELoss()
        else:
            # The model emits raw logits; the sigmoid lives inside the loss.
            self.loss = torch.nn.BCEWithLogitsLoss()

        self.task = task
        self.out_dim = output_dim

        self.model = None
        if arch == 'linear':
            self.model = torch.nn.Linear(input_dim, output_dim, bias=False).to(device)
        elif arch == 'mlp':
            self.model = MLP().to(device)
        else:
            raise Exception('Not Implemented...')
        self.arch = arch

        self.mask = None
        self.device = device

    def _make_optimizer(self, lr):
        """Return the optimiser associated with ``self.task``.

        Raises:
            ValueError: If ``self.task`` has no optimiser configured.
        """
        param_groups = [{'params': self.model.parameters(), 'lr': lr}]
        if self.task == 'simulation':
            return torch.optim.Adam(param_groups)
        if self.task in ('house', 'income', 'insurance'):
            return torch.optim.SGD(param_groups, weight_decay=1.)
        raise ValueError(
            "unknown task %r; choices of tasks: "
            "[simulation, house, income, insurance]" % (self.task,)
        )

    def train(self, envs, epochs, lr=1e-2, renew=False, verbose=False):
        """Fit the model with full-batch gradient steps on the pooled data.

        Args:
            envs: Iterable of ``(x, y)`` pairs, one per environment.  Each
                element is converted with ``torch.Tensor``.
            epochs: Number of optimiser steps.
            lr: Learning rate handed to the optimiser.
            renew: Accepted for interface compatibility; currently unused.
            verbose: If true, print the loss every 100 epochs.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            ValueError: If ``self.task`` has no optimiser configured.
        """
        envs = [(torch.Tensor(x), torch.Tensor(y)) for x, y in envs]

        # Project helper from utils.tools; presumably pools the
        # per-environment tensors into one (X, y) pair -- TODO confirm.
        X, y = combine_envs(envs)

        opt = self._make_optimizer(lr)

        X = X.to(self.device)
        y = y.to(self.device)

        for epoch_i in range(epochs):
            opt.zero_grad()

            pred = self.model(X)
            # Align the targets with the prediction layout before the loss.
            loss = self.loss(pred, y.reshape(pred.shape))

            if verbose and epoch_i % 100 == 0:
                print(loss.item())

            loss.backward()
            opt.step()

        return self

    def predict(self, X):
        """Return the raw model output for ``X``.

        ``self.mask`` is not applied here and gradients are not disabled.
        """
        X = torch.Tensor(X).to(self.device)
        return self.model(X)

    def set_model(self, model):
        """Replace the wrapped model and return ``self``."""
        self.model = model
        return self

    def set_mask(self, mask):
        """Set the feature mask multiplied into ``X`` by ``score``; return ``self``."""
        self.mask = mask
        return self

    def score(self, X, y):
        """Evaluate the model on ``(X, y)``.

        Returns:
            Accuracy for classification (logit > 0 counts as class 1 when
            ``out_dim == 1``, otherwise arg-max against arg-max of ``y``), or
            mean squared error for regression, as a Python float.

        Raises:
            Exception: If ``self.type`` is not a supported problem type.
        """
        X = torch.Tensor(X).to(self.device)
        y = torch.Tensor(y).to(self.device)

        if self.mask is not None:
            X = X * self.mask

        with torch.no_grad():
            pred = self.model(X).detach()
            pred = pred.reshape(-1, self.out_dim)

        if self.type == 'classification':
            if self.out_dim == 1:
                labels = (pred > 0).to(y.dtype)
                # Reshape y so a 1-D target is compared element-wise instead
                # of being broadcast against the (n, 1) predictions.
                hits = labels.eq(y.reshape(labels.shape))
            else:
                hits = torch.argmax(pred, dim=1).eq(torch.argmax(y, dim=1))
            score = hits.float().mean().item()
        elif self.type == 'regression':
            score = ((pred - y.reshape(pred.shape)) ** 2).mean().item()
        else:
            raise Exception('Not Implemented')
        return score

def Regression(X, y):
    """Fit an intercept-free scikit-learn ``Ridge`` model on ``(X, y)``.

    ``y`` is flattened to one dimension before fitting.  The fitted
    estimator is returned.
    """
    flat_targets = y.reshape(-1)
    ridge = Ridge(fit_intercept=False)
    ridge.fit(X, flat_targets)
    return ridge




class MLP(torch.nn.Module):
    """Three-layer perceptron split into a feature extractor and a head.

    ``fea_module`` is ``Linear -> ReLU -> Linear`` and ``clf`` is
    ``ReLU -> Linear``.  Weights use Xavier-uniform initialisation and
    biases start at zero.

    Args:
        input_dim: Flattened input width (default ``2 * 14 * 14``).
        hidden_dim: Width of both hidden layers (default 256).
        output_dim: Number of outputs of the head (default 1).
    """

    def __init__(self, input_dim=2 * 14 * 14, hidden_dim=256, output_dim=1):
        super().__init__()

        lin1 = torch.nn.Linear(input_dim, hidden_dim)
        lin2 = torch.nn.Linear(hidden_dim, hidden_dim)
        lin3 = torch.nn.Linear(hidden_dim, output_dim)
        for lin in (lin1, lin2, lin3):
            torch.nn.init.xavier_uniform_(lin.weight)
            torch.nn.init.zeros_(lin.bias)
        # Submodule names and positions are kept so state_dict keys stay
        # fea_module.{0,2}.* and clf.1.*.
        self.fea_module = torch.nn.Sequential(lin1, torch.nn.ReLU(True), lin2)
        self.clf = torch.nn.Sequential(torch.nn.ReLU(True), lin3)

    def forward(self, input, need_fea=False):
        """Run the network on a batch.

        Args:
            input: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened to the first layer's input width.
            need_fea: If true, also return the feature tensor.

        Returns:
            ``out`` or, when ``need_fea`` is true, ``(fea, out)``.  Because
            ``clf`` starts with an in-place ReLU, the returned ``fea`` has
            already been rectified.
        """
        # Read the width from the first layer rather than a stored attribute,
        # and use reshape so non-contiguous inputs (e.g. permuted images)
        # are accepted; view would raise on them.
        in_features = self.fea_module[0].in_features
        input = input.reshape(input.shape[0], in_features)
        fea = self.fea_module(input)
        out = self.clf(fea)
        if need_fea:
            return fea, out
        return out