import os
import numpy as np
from scipy.optimize import linear_sum_assignment


def cluster_acc(y_true, y_pred, return_ind=False):
    """Compute clustering accuracy under the best one-to-one cluster/class matching.

    A square contingency matrix ``w`` is built where ``w[p, t]`` counts the
    samples with predicted cluster id ``p`` and ground-truth label ``t``.
    ``linear_sum_assignment`` then picks the one-to-one pairing of cluster ids
    and labels that maximises the number of matched samples.

    Args:
        y_true: Array of ground-truth labels (cast to ``int``).
        y_pred: Array of predicted cluster ids with the same number of
            elements as ``y_true`` (cast to ``int``).
        return_ind: If true, also return the assignment and the matrix.

    Returns:
        The accuracy (matched samples / total samples).  When ``return_ind``
        is true, a tuple ``(accuracy, ind, w)`` where ``ind`` is a ``(D, 2)``
        array of ``(cluster id, label)`` pairs and ``w`` is the contingency
        matrix.
    """
    y_true = np.asarray(y_true).astype(int).ravel()
    # Predictions are cast too: float-typed ids would make D a float and
    # could neither size nor index the contingency matrix.
    y_pred = np.asarray(y_pred).astype(int).ravel()
    assert y_pred.size == y_true.size

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)
    # Unbuffered in-place add, so repeated (pred, true) pairs are all counted.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost; flip counts so it maximises matches.
    ind = np.vstack(linear_sum_assignment(w.max() - w)).T
    acc = w[ind[:, 0], ind[:, 1]].sum() * 1.0 / y_pred.size

    if return_ind:
        return acc, ind, w
    return acc


def cluster_acc_old_only(y_true, y_pred):
    """Return ``cluster_acc`` for the given labels and predictions.

    Both arrays are cast to ``int`` before being handed to ``cluster_acc``.
    """
    labels = y_true.astype(int)
    clusters = y_pred.astype(int)
    return cluster_acc(labels, clusters)


# GCD metric (arxiv v1)
def cluster_acc_v1(y_true, y_pred, mask):
    """Compute (total, old, new) accuracies with separate matchings per subset.

    ``cluster_acc`` is evaluated independently on the samples selected by
    ``mask`` ("old") and on the remaining samples ("new").  The total is the
    average of the two, weighted by the fraction of samples in each subset.

    Args:
        y_true: Array of ground-truth labels (cast to ``int``).
        y_pred: Array of predicted cluster ids (cast to ``int``).
        mask: Array marking the "old" samples (cast to ``bool``).

    Returns:
        Tuple ``(total_acc, old_acc, new_acc)``.
    """
    is_old = mask.astype(bool)
    is_new = ~is_old
    labels = y_true.astype(int)
    clusters = y_pred.astype(int)

    old_acc = cluster_acc(labels[is_old], clusters[is_old])
    new_acc = cluster_acc(labels[is_new], clusters[is_new])

    # Fraction of samples that belong to the "old" subset.
    old_fraction = is_old.mean()
    total_acc = old_fraction * old_acc + (1 - old_fraction) * new_acc

    return total_acc, old_acc, new_acc


# Official GCD metric (CVPR22, arxiv v2)
def cluster_acc_v2(y_true, y_pred, mask):
    """Compute (total, old, new) accuracies from one global cluster/class matching.

    A single contingency matrix is built over all samples and one optimal
    one-to-one assignment between cluster ids and labels is computed with
    ``linear_sum_assignment``.  The "old" and "new" accuracies are then read
    from that shared assignment, restricted to the ground-truth classes that
    occur among the samples selected by ``mask`` ("old") and by ``~mask``
    ("new") respectively.

    Args:
        y_true: Array of ground-truth labels (cast to ``int``).
        y_pred: Array of predicted cluster ids with the same number of
            elements as ``y_true`` (cast to ``int``).
        mask: Array marking the "old" samples.  It is cast to ``bool`` so that
            a 0/1 integer mask is not misread as integer indices.

    Returns:
        Tuple ``(total_acc, old_acc, new_acc)``.  A subset accuracy is NaN when
        no sample falls in that subset.
    """
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred = np.asarray(y_pred).astype(int).ravel()
    # Without the bool cast, an integer mask would trigger fancy indexing and
    # `~mask` would be a bitwise NOT (-1 / -2), selecting the wrong samples.
    mask = np.asarray(mask).astype(bool).ravel()

    assert y_pred.size == y_true.size

    old_classes_gt = np.unique(y_true[mask])
    new_classes_gt = np.unique(y_true[~mask])

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)
    # Unbuffered in-place add, so repeated (pred, true) pairs are all counted.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost; flip counts so it maximises matches.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    total_acc = w[row_ind, col_ind].sum() * 1.0 / y_pred.size

    # w is square, so every label column receives exactly one cluster row.
    cluster_for_class = np.empty(D, dtype=int)
    cluster_for_class[col_ind] = row_ind

    old_acc = _class_subset_acc(w, cluster_for_class, old_classes_gt)
    new_acc = _class_subset_acc(w, cluster_for_class, new_classes_gt)

    return total_acc, old_acc, new_acc


def _class_subset_acc(w, cluster_for_class, classes):
    """Return matched / total sample counts over the given label columns of ``w``."""
    total = w[:, classes].sum()
    if total == 0:
        # No samples in this subset: the accuracy is undefined.
        return np.nan
    return w[cluster_for_class[classes], classes].sum() / total


# ORCA metrics
def orca_accuracy(output, target):
    """Return the fraction of positions where ``output`` equals ``target``.

    Args:
        output: Predicted labels (array-like).
        target: Reference labels (array-like) of the same shape.

    Returns:
        Number of element-wise matches divided by ``len(target)``.
    """
    # Coerce to arrays so `==` is element-wise; on plain lists it would
    # collapse to a single bool and the count would be 0 or 1.
    output = np.asarray(output)
    target = np.asarray(target)

    num_correct = np.sum(output == target)
    return num_correct / len(target)

def orca_cluster_acc(y_pred, y_true):
    """Compute clustering accuracy under the best one-to-one cluster/class matching.

    Note the argument order: predictions first, ground truth second.

    Args:
        y_pred: Array of predicted cluster ids (cast to ``np.int64``).
        y_true: Array of ground-truth labels with the same number of elements
            as ``y_pred`` (cast to ``np.int64``).

    Returns:
        Number of samples matched by the optimal assignment divided by the
        total number of samples.
    """
    y_true = np.asarray(y_true).astype(np.int64).ravel()
    # Predictions are cast too: float-typed ids would make D a float and
    # could neither size nor index the contingency matrix.
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()
    assert y_pred.size == y_true.size

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Unbuffered in-place add, so repeated (pred, true) pairs are all counted.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost; flip counts so it maximises matches.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)

    return w[row_ind, col_ind].sum() / y_pred.size

def orca_all_old_new_ACCs(unlab_preds, unlab_gt_labs, seen_mask):
    """Compute the (all, old, new) accuracies of the ORCA-style evaluation.

    * all: ``orca_cluster_acc`` over every sample.
    * old: ``orca_accuracy`` (direct label equality, no matching) over the
      samples selected by ``seen_mask``.
    * new: ``orca_cluster_acc`` over the remaining samples.

    Args:
        unlab_preds: Array of predicted labels / cluster ids.
        unlab_gt_labs: Array of ground-truth labels.
        seen_mask: Array marking the "old" samples (cast to ``bool``).

    Returns:
        Tuple ``(orca_all_acc, orca_old_acc, orca_new_acc)``.
    """
    # Cast to bool so a 0/1 integer mask is not treated as integer indices and
    # `~` is a logical (not bitwise) negation.
    seen_mask = np.asarray(seen_mask).astype(bool)
    novel_mask = ~seen_mask

    orca_all_acc = orca_cluster_acc(unlab_preds, unlab_gt_labs)
    orca_old_acc = orca_accuracy(unlab_preds[seen_mask], unlab_gt_labs[seen_mask])
    orca_new_acc = orca_cluster_acc(unlab_preds[novel_mask], unlab_gt_labs[novel_mask])

    return orca_all_acc, orca_old_acc, orca_new_acc

def partitioning_eval(unlab_gt_labs, unlab_preds, seen_mask, dset_name, path_k_strat):
    """Evaluate a partitioning with three metric families, print and save results.

    Computes the (All, Old, New) accuracies with ``cluster_acc_v1``,
    ``cluster_acc_v2`` and ``orca_all_old_new_ACCs``, prints them as
    percentages rounded to one decimal, and saves the ``cluster_acc_v2``
    triple to ``util/params_estim/<path_k_strat>/<dset_name>/scores/ACCs_v2.npy``
    (relative to the current working directory).

    Args:
        unlab_gt_labs: Array of ground-truth labels.
        unlab_preds: Array of predicted labels / cluster ids.
        seen_mask: Array marking the "old" samples.
        dset_name: Dataset name, used as a directory component of the output path.
        path_k_strat: Directory component placed before ``dset_name`` in the
            output path.

    Returns:
        Always ``1``.
    """
    # Normalise the mask once: the metrics below index with `mask` / `~mask`,
    # which only partitions the samples correctly for a boolean array.
    seen_mask = np.asarray(seen_mask).astype(bool)

    # GCD metrics (used in GCD arxiv v1)
    accs_v1 = cluster_acc_v1(unlab_gt_labs, unlab_preds, seen_mask)

    # GCD metrics (used in GCD CVPR-22 and GCD arxiv v2)
    accs_v2 = cluster_acc_v2(unlab_gt_labs, unlab_preds, seen_mask)

    # ORCA metrics (used in ORCA ICLR-22)
    orca_accs = orca_all_old_new_ACCs(unlab_preds, unlab_gt_labs, seen_mask)

    print("Classes:                      (All) & (Old) & (New)")

    report_rows = (
        ("PIM ACC (v1):                 ", accs_v1),
        ("PIM ACC (ORCA metric):        ", orca_accs),
        ("PIM ACC (Official GCD metric):", accs_v2),
    )
    for label, accs in report_rows:
        all_pct, old_pct, new_pct = (np.round(100. * acc, 1) for acc in accs[:3])
        print(label, all_pct, "& ", old_pct, "& ", new_pct)

    path_accs_v2 = 'util/params_estim/' + path_k_strat + '/' + dset_name + '/scores'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path_accs_v2, exist_ok=True)

    accs_v2_file_name = path_accs_v2 + '/ACCs_v2.npy'
    with open(accs_v2_file_name, 'wb') as f:
        np.save(f, accs_v2)
    return 1