import numpy as np
import torch

from sklearn.ensemble import RandomForestClassifier
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score,normalized_mutual_info_score

from scipy.optimize import linear_sum_assignment as linear_assignment

from munkres import Munkres

from Utils.metrics import reconstruction_loss

def compare_result(new_mean,new_std,old_mean,old_std,higher_flag):
    """Rank a new (mean, std) result against a previous one.

    The means are compared first; ``higher_flag`` selects whether a larger
    (truthy) or a smaller (falsy) mean counts as an improvement. When the
    means are equal, the smaller standard deviation wins regardless of
    ``higher_flag``.

    Returns:
        0 if the new result is worse, 1 if it is better, 2 if both the mean
        and the standard deviation are equal.
    """
    if higher_flag:
        mean_improved = new_mean > old_mean
    else:
        mean_improved = new_mean < old_mean

    if mean_improved:
        return 1            # better mean
    if new_mean != old_mean:
        return 0            # worse mean (or not comparable)

    # Equal means: fall back to the spread.
    if new_std < old_std:
        return 1
    return 2 if new_std == old_std else 0

        
def check_best(best_res_mean,best_res_std,iter_res,higher_flag):
    """Decide whether a batch of runs beats the best result recorded so far.

    ``iter_res`` is converted to an array and averaged along axis 0, giving
    one mean and one standard deviation per metric column. The first metric
    is compared with ``compare_result``; only when it ties exactly, and a
    second metric exists, is the second metric used as a tie-breaker.

    Args:
        best_res_mean: Indexable holding the best mean per metric.
        best_res_std: Indexable holding the matching standard deviations.
        iter_res: Per-run results (one row per run, one column per metric).
        higher_flag: Forwarded to ``compare_result`` for both metrics.

    Returns:
        True if the new runs are strictly better, False otherwise.
    """
    runs = np.array(iter_res)
    new_mean = np.mean(runs, axis=0)
    new_std = np.std(runs, axis=0)

    primary = compare_result(
        new_mean[0], new_std[0], best_res_mean[0], best_res_std[0], higher_flag
    )
    if primary == 1:
        return True
    if primary != 2 or len(new_mean) <= 1:
        # Either a loss on the first metric, or a tie with nothing to break it.
        return False

    secondary = compare_result(
        new_mean[1], new_std[1], best_res_mean[1], best_res_std[1], higher_flag
    )
    return secondary == 1

class ConstructNet(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        n_feature: Size of each input sample.
        n_hidden: Width of the hidden layer.
        n_output: Size of each output sample.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        # Sub-modules are registered in the order hidden, predict, relu.
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the output-layer activations for ``x`` (no final non-linearity)."""
        return self.predict(self.relu(self.hidden(x)))

def init_weights(m):
    """Re-draw the weights of a Linear / Conv2d module from N(0, 0.01**2).

    Meant to be passed to ``Module.apply``. Only modules whose type is exactly
    ``torch.nn.Linear`` or ``torch.nn.Conv2d`` are touched (subclasses are not
    matched), and only their ``weight``; biases keep their current values.
    """
    if type(m) not in (torch.nn.Linear, torch.nn.Conv2d):
        return
    torch.nn.init.normal_(m.weight, std=0.01)

def cluster_assign(y_true, y_pred):
    """Relabel predicted cluster ids so that they line up with the true ids.

    A contingency table counting (predicted id, true id) co-occurrences is
    built, and the one-to-one matching with the largest total overlap is
    found with ``linear_assignment``. Every predicted id is then replaced by
    the true id it was matched to.

    Args:
        y_true: numpy array of ground-truth labels; cast to int64 and used as
            table indices, so values are expected to be non-negative.
        y_pred: numpy array of predicted cluster ids with the same number of
            elements as ``y_true`` (expected 1-D, non-negative).

    Returns:
        Array with the shape and dtype of ``y_pred`` holding the relabelled
        predictions.
    """
    y_true = y_true.astype(np.int64)
    # Integer copy used only for indexing; the result keeps y_pred's dtype.
    pred_idx = y_pred.astype(np.int64)
    assert pred_idx.size == y_true.size, "y_pred and y_true must have the same size"
    if pred_idx.size == 0:
        return np.zeros_like(y_pred)

    n_labels = max(pred_idx.max(), y_true.max()) + 1
    # overlap[p, t] = number of samples predicted as p whose true label is t.
    overlap = np.zeros((n_labels, n_labels), dtype=np.int64)
    np.add.at(overlap, (pred_idx, y_true), 1)

    # The solver minimises cost, so flip the counts to maximise overlap.
    row_ind, col_ind = linear_assignment(overlap.max() - overlap)

    # Rows index predicted ids and columns index true ids, so the relabelling
    # is row -> column (predicted id -> matched true id).
    mapping = np.zeros(n_labels, dtype=np.int64)
    mapping[row_ind] = col_ind
    return mapping[pred_idx].astype(y_pred.dtype)
    

def run_baseline(X_train,X_test,target_train,target_test,baseline_model,iter,save_flag = False,device = torch.device('cuda:0')):
    """Train and evaluate one baseline model and return its metric(s) as a list.

    Args:
        X_train: Training features. Not used by 'KMEANS'.
        X_test: Test features.
        target_train: Training targets. Class labels for 'RF' (1-D, a single
            column, or a 2-D multi-column array that is reduced with argmax
            over axis 1); the data to reconstruct for 'NN'. Not used by
            'KMEANS'.
        target_test: Test targets, same conventions as ``target_train``.
        baseline_model: 'RF' (random-forest classification accuracy),
            'KMEANS' (clustering accuracy after optimal label matching) or
            'NN' (reconstruction RMSE of a small MLP).
        iter: Seed used as ``random_state`` for the random forest and for
            KMeans, so each iteration is reproducible. Not used by 'NN'.
        save_flag: 'NN' only. When true, a list
            ``[train_gt, train_pred, test_gt, test_pred]`` of numpy arrays is
            appended to the result.
        device: torch device on which the 'NN' baseline is run.

    Returns:
        A list whose first element is the metric (accuracy for 'RF' and
        'KMEANS', RMSE for 'NN').

    Raises:
        ValueError: If ``baseline_model`` is not one of the supported names.
    """
    if baseline_model == 'RF':                  # for Classification
        return _run_random_forest(X_train, X_test, target_train, target_test, iter)
    if baseline_model == 'KMEANS':              # for Clustering
        return _run_kmeans(X_test, target_test, iter)
    if baseline_model == 'NN':                  # for Reconstruction
        return _run_reconstruction_net(
            X_train, X_test, target_train, target_test, save_flag, device
        )
    raise ValueError(
        "Unknown baseline_model {!r}; expected 'RF', 'KMEANS' or 'NN'".format(baseline_model)
    )


def _as_class_labels(target):
    """Collapse a target array to a 1-D vector of class labels."""
    target = np.asarray(target)
    if target.ndim > 1:
        if target.shape[1] != 1:
            # Multi-column targets: the class is the arg-max column.
            return np.argmax(target, axis=1)
        # A single column already holds the labels; just flatten it.
        return target.ravel()
    return target


def _run_random_forest(X_train, X_test, target_train, target_test, seed):
    """Fit a 1000-tree random forest and return ``[test accuracy]``."""
    target_train = _as_class_labels(target_train)
    target_test = _as_class_labels(target_test)
    clf = RandomForestClassifier(n_estimators=1000, random_state=seed)
    clf.fit(X_train, target_train)
    pred = clf.predict(X_test)
    return [accuracy_score(target_test, pred),]


def _run_kmeans(X_test, target_test, seed):
    """Cluster ``X_test`` with KMeans and return ``[best-match accuracy]``."""
    # Only the test targets matter here, so only their shape is inspected.
    target_test = _as_class_labels(target_test)
    true_labels = np.unique(target_test)

    model = KMeans(n_clusters=len(true_labels), random_state=seed)
    model.fit(X_test)
    pred = model.labels_
    pred_labels = np.unique(pred)
    assert len(true_labels) == len(pred_labels)

    # cost[i, j] = number of samples with true label i assigned to cluster j.
    cost = np.array(
        [[np.sum((target_test == t) & (pred == p)) for p in pred_labels]
         for t in true_labels]
    )

    # Match the two labelings with the Munkres algorithm; it minimises cost,
    # so the overlap counts are negated to maximise agreement.
    indexes = Munkres().compute((-cost).tolist())

    # Rewrite every cluster id as the true label it was matched to.
    new_predict = np.zeros_like(target_test)
    for row, col in indexes:
        new_predict[pred == pred_labels[col]] = true_labels[row]

    return [accuracy_score(target_test, new_predict),]


def _run_reconstruction_net(X_train, X_test, target_train, target_test, save_flag, device):
    """Train ConstructNet to map features to targets and return ``[test RMSE]``."""
    X_train = torch.tensor(X_train).to(device)
    X_test = torch.tensor(X_test).to(device)
    X_train_ori = torch.tensor(target_train).to(device)
    X_test_ori = torch.tensor(target_test).to(device)

    trainnum = X_train.shape[0]
    testnum = X_test.shape[0]
    d = X_train_ori.shape[1]
    k = X_train.shape[1]
    hdim = k                                    # hidden width equals the input width

    net = ConstructNet(k, hdim, d)
    net.apply(init_weights)
    net = net.to(device)

    trainer = torch.optim.Adam(net.parameters(), 1e-3)
    for _ in range(2000):
        trainer.zero_grad()
        prediction = net(X_train)
        loss = reconstruction_loss(prediction, X_train_ori)
        recon_loss = loss.sum() / trainnum
        # NOTE(review): retain_graph=True looks unnecessary because the graph
        # is rebuilt by the forward pass of every iteration - confirm against
        # reconstruction_loss before removing it.
        recon_loss.backward(retain_graph=True)
        trainer.step()

    with torch.no_grad():
        pred = net(X_test)
        loss = reconstruction_loss(pred, X_test_ori)
        rmse = torch.sqrt(loss.sum() / (testnum * d))

    res = [rmse.item(),]
    if save_flag:
        # `prediction` is the training output of the last loop iteration,
        # i.e. computed before the final optimizer step.
        saved_data = [
            X_train_ori.detach().cpu().numpy(),
            prediction.detach().cpu().numpy(),
            X_test_ori.detach().cpu().numpy(),
            pred.detach().cpu().numpy(),
        ]
        res.append(saved_data)
    return res
        