from tqdm import tqdm
import torch
import torch.nn.functional as F
import os
import pickle
from sklearn.cluster import KMeans
import numpy as np
from training.data_utils import get_dataloaders
from utils.forget_scores import calculate_forget_scores


def get_aus_stats_from_fs_stats(fs_stats):
    """Convert recorded per-example stats into the AUS statistics dictionary.

    Args:
        fs_stats: Mapping from example id to a positional record. Entries whose
            key is a ``str`` are skipped. For the remaining entries the
            following positions are read: 0 (objective), 2 (margins),
            6 (original index, first element), 7 (gradient norm),
            8 (label, first element), 9 (activations, a sequence of tensors)
            and 10 (logits, a sequence of tensors).

    Returns:
        dict with the keys ``list_labels``, ``list_logits``, ``list_acts``,
        ``list_grad_norm``, ``list_obj``, ``list_margins`` and
        ``list_orig_idx``. Every entry except the labels and original indices
        is reduced with ``mean(dim=1)``.
    """
    # NOTE(review): dim=1 is presumably the "recorded epoch" axis, since the
    # caller describes these as averaged statistics -- confirm against the
    # code that writes fs_stats.
    records = [
        example_stats
        for example_id, example_stats in tqdm(fs_stats.items())
        if not isinstance(example_id, str)
    ]

    def _column(position):
        # Raw per-example values stored at ``position``.
        return [record[position] for record in records]

    def _first_of(position):
        # First element of the per-example sequence stored at ``position``.
        return [record[position][0] for record in records]

    def _stacked_mean(position):
        # Each example holds a sequence of tensors: stack per example, then
        # across examples, and reduce over dim=1.
        per_example = [torch.stack(record[position]) for record in records]
        return torch.stack(per_example).mean(dim=1).cpu()

    def _tensor_mean(position):
        # Each example holds a sequence of plain values.
        return torch.tensor(_column(position)).mean(dim=1).cpu()

    return {
        'list_labels': torch.tensor(_first_of(8)).cpu(),
        'list_logits': _stacked_mean(10),
        'list_acts': _stacked_mean(9),
        'list_grad_norm': _tensor_mean(7),
        'list_obj': _tensor_mean(0),
        'list_margins': _tensor_mean(2),
        'list_orig_idx': torch.tensor(_first_of(6)).cpu(),
    }


def get_aus_stats(args, model, criterion, train_loader):
    """Collect the per-example statistics used to score examples for AUS.

    Two modes are supported:

    * ``args.aus_avg_epochs`` truthy: statistics recorded during training are
      loaded from ``args.stats_dict_path`` and converted with
      ``get_aus_stats_from_fs_stats``. The model is not evaluated, and the
      returned dict has no ``list_conf_score``, ``list_pred`` or
      ``list_ll_grad`` entries.
    * otherwise: the model is switched to eval mode (and left there) and run
      once over ``train_loader`` under ``torch.no_grad()``.

    Args:
        args: Namespace providing ``aus_avg_epochs``, ``stats_dict_path``,
            ``num_classes`` and ``device``.
        model: Callable mapping a batch ``x`` to logits. Assumed to already
            live on ``args.device``.
        criterion: Loss called as ``criterion(logits, y)``.
            NOTE(review): the results are joined with ``torch.cat``, so this
            presumably has to return per-example losses (e.g.
            ``reduction='none'``) -- confirm with the caller.
        train_loader: Iterable yielding ``(x, y, orig_idx, _)`` batches.

    Returns:
        dict mapping ``list_*`` keys to CPU tensors concatenated over batches.

    Raises:
        AssertionError: in averaged mode, if the loaded dict's
            ``'get_fs_stats'`` entry is falsy.
    """
    if args.aus_avg_epochs:
        # Load precomputed stats
        # This will be slightly different due to model mode
        print('Using averaged statistics to perform AUS.')

        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # args.stats_dict_path at files written by a trusted training run.
        with open(args.stats_dict_path, 'rb') as f:
            stats_dict = pickle.load(f)
        assert stats_dict['get_fs_stats'], "Need to calculate fs stats during training"

        return get_aus_stats_from_fs_stats(stats_dict['stats'])

    model.eval()

    stat_names = (
        'list_labels',
        'list_logits',
        'list_acts',
        'list_conf_score',
        'list_pred',
        'list_ll_grad',
        'list_grad_norm',
        'list_obj',
        'list_margins',
        'list_orig_idx',
    )
    collected = {name: [] for name in stat_names}

    # One-hot lookup table. It is created on args.device (not a hard-coded
    # CUDA device) so it matches the device the labels are moved to below.
    id_mat = torch.eye(args.num_classes, device=args.device)

    with torch.no_grad():
        for x, y, orig_idx, _ in tqdm(train_loader, desc='Collecting aus info'):
            x, y = x.to(args.device), y.to(args.device)
            logits = model(x)
            acts = F.softmax(logits, dim=1)
            conf_score, pred = acts.max(1)

            # softmax(logits) - one_hot(y), and its per-example L2 norm.
            ll_grad = acts - id_mat[y.long()]
            grad_norm = ll_grad.norm(dim=1)
            obj = criterion(logits, y)

            # Margin: correct-class logit minus the highest other logit.
            rows = torch.arange(logits.size(0), device=logits.device)
            correct_logits = logits[rows, y]
            masked_logits = logits.clone()
            masked_logits[rows, y] = float('-inf')
            margins = correct_logits - masked_logits.max(dim=1).values

            collected['list_labels'].append(y)
            collected['list_logits'].append(logits)
            collected['list_acts'].append(acts)
            collected['list_conf_score'].append(conf_score)
            collected['list_pred'].append(pred)
            collected['list_ll_grad'].append(ll_grad)
            collected['list_grad_norm'].append(grad_norm)
            collected['list_obj'].append(obj)
            collected['list_margins'].append(margins)
            collected['list_orig_idx'].append(orig_idx)

    return {name: torch.cat(chunks).cpu() for name, chunks in collected.items()}

def calculate_score(args, stats_dict):
    """Pick the per-example score selected by ``args.aus_score``.

    Args:
        args: Namespace providing ``aus_score``. For ``'forget_score'`` it
            must also provide ``aus_epochs`` and ``local_run_path``; in that
            case ``args.forget_score_epochs`` and ``args.fs_postprocess`` are
            overwritten as a side effect.
        stats_dict: Statistics dictionary with ``list_*`` entries.

    Returns:
        A ``(scores, reversed)`` tuple. ``scores`` is a 1-D tensor of
        per-example scores. ``reversed`` is ``False`` for ``'grad_norm'`` and
        ``'forget_score'`` and ``True`` for every other score; the selection
        functions in this file use it to decide the direction of the
        comparison / ordering.

    Raises:
        NotImplementedError: if ``args.aus_score`` is not a known score.
        KeyError: if the entry required by the chosen score is missing from
            ``stats_dict``.
    """
    score_name = args.aus_score

    if score_name == 'grad_norm':
        return stats_dict['list_grad_norm'], False

    if score_name == 'cluster_size':
        # NOTE(review): none of the stats builders visible in this file add a
        # 'cluster_sizes' entry; the caller presumably has to supply it.
        return stats_dict['cluster_sizes'], True

    if score_name == 'forget_score':
        args.forget_score_epochs = [0, args.aus_epochs]
        args.fs_postprocess = False
        calculate_forget_scores(args)
        fs_path = os.path.join(args.local_run_path,
                            f'forget_scores_[{args.forget_score_epochs[0]}-{args.forget_score_epochs[1]}].pkl')
        # NOTE(security): pickle.load can execute arbitrary code; the file is
        # expected to be the one calculate_forget_scores just produced.
        with open(fs_path, 'rb') as f:
            fs_dict = pickle.load(f)

        # Order the forgetting counts by their index, then keep the counts.
        index_count_pairs = sorted(zip(fs_dict['indices'], fs_dict['forgetting counts']))
        return torch.tensor(index_count_pairs)[:, 1].float(), False

    if score_name == 'conf_score':
        return stats_dict['list_conf_score'], True

    if score_name == 'margin':
        return stats_dict['list_margins'], True

    if score_name == 'error':
        return stats_dict['list_obj'], True

    if score_name == 'acc':
        # Accuracy compares predictions to labels. The previous code compared
        # the loss values ('list_obj') to the labels, which is not an
        # accuracy. When no predictions were stored, fall back to the argmax
        # of the stored logits.
        if 'list_pred' in stats_dict:
            predictions = stats_dict['list_pred']
        else:
            predictions = stats_dict['list_logits'].argmax(dim=1)
        return (predictions == stats_dict['list_labels']).float(), True

    raise NotImplementedError(f'Auto-upsample method: {args.aus_score} is not implemented.')

def select_by_clustering(args, stats_dict, num_examples, aus_score_data, reversed):
    """Select examples by clustering each class and picking ranked clusters.

    For every class, the examples with that label are clustered with KMeans
    into ``args.aus_clusters`` clusters. The clusters are ranked by the mean
    of ``aus_score_data`` over their members (ascending, or descending when
    ``reversed`` is true) and the clusters at ranks
    ``args.aus_cluster_range`` (1-indexed, inclusive) are selected.

    Args:
        args: Namespace providing ``num_classes``, ``aus_clusters``, ``seed``,
            ``aus_cluster_by`` (``'activations'`` or ``'logits'``) and
            ``aus_cluster_range`` (a ``(start, end)`` pair).
        stats_dict: Statistics dictionary with ``list_labels`` and either
            ``list_acts`` or ``list_logits``.
        num_examples: Unused; kept so existing callers keep working.
        aus_score_data: 1-D tensor of per-example scores, aligned with the
            entries of ``stats_dict``.
        reversed: Whether clusters are ranked by descending mean score.

    Returns:
        Flat list of positions (into ``stats_dict`` order) of all examples in
        the selected clusters, grouped class by class.

    Raises:
        NotImplementedError: if ``args.aus_cluster_by`` is not supported.
    """
    # Validate the feature choice once, before any clustering work is done.
    if args.aus_cluster_by == 'activations':
        features = stats_dict['list_acts']
    elif args.aus_cluster_by == 'logits':
        features = stats_dict['list_logits']
    else:
        raise NotImplementedError('The metric you want to cluster by is not implemented.')

    list_labels = stats_dict['list_labels']
    first_rank, last_rank = args.aus_cluster_range  # 1-indexed, inclusive

    def _mean_score(members):
        # Mean score over a cluster; an empty cluster yields NaN.
        return aus_score_data[members].mean().item()

    upsample_indices = []
    for c in tqdm(range(args.num_classes), desc='Clustering for aus'):
        class_mask = list_labels == c
        class_positions = np.where(class_mask)[0]
        samples = features[class_mask]

        kmeans = KMeans(n_clusters=args.aus_clusters, n_init=10, random_state=args.seed)
        kmeans.fit(samples)
        assignments = kmeans.predict(samples)

        # Positions of the members of each cluster of this class.
        clusters = [
            list(class_positions[assignments == cluster])
            for cluster in range(args.aus_clusters)
        ]

        # Rank by mean score only, so ties never compare the index lists.
        ranked = sorted(clusters, key=_mean_score, reverse=reversed)
        for members in ranked[first_rank - 1:last_rank]:
            upsample_indices.extend(members)

    return upsample_indices

def select_by_threshold(args, aus_score_data, reversed):
    """Select the positions whose score passes ``args.aus_score_threshold``.

    Args:
        args: Namespace providing ``aus_score_threshold``.
        aus_score_data: Per-example scores.
        reversed: When true, scores at or below the threshold are selected;
            otherwise scores at or above it are selected.

    Returns:
        Array of selected positions along the first axis.
    """
    threshold = args.aus_score_threshold
    passes = aus_score_data <= threshold if reversed else aus_score_data >= threshold
    return np.where(passes)[0]

def _quantile_band_positions(scores, quantile_range):
    """Return positions of ``scores`` lying inside the given quantile band (inclusive)."""
    lower_threshold = np.quantile(scores, q=quantile_range[0])
    upper_threshold = np.quantile(scores, q=quantile_range[1])
    return np.where((scores >= lower_threshold) & (scores <= upper_threshold))[0]


def select_by_quantile(args, stats_dict, aus_score_data, reversed):
    """Select the positions whose score lies inside a quantile band.

    Args:
        args: Namespace providing ``aus_score_quantile_range`` (a
            ``(low, high)`` pair of quantiles), ``aus_quantile_per_class``
            and, when selecting per class, ``num_classes``.
        stats_dict: Statistics dictionary; ``list_labels`` is read when
            selecting per class.
        aus_score_data: 1-D per-example scores aligned with ``list_labels``.
        reversed: When true, the scores are negated before the quantiles are
            taken, so the band is measured from the high-score end.

    Returns:
        Array of selected positions. With ``aus_quantile_per_class`` the
        quantiles are computed separately for each class and the results are
        concatenated in class order. Classes without any example are skipped,
        and an empty array is returned when nothing is selected.
    """
    if reversed:
        aus_score_data = -aus_score_data

    quantile_range = args.aus_score_quantile_range

    if not args.aus_quantile_per_class:
        return _quantile_band_positions(aus_score_data, quantile_range)

    selected_per_class = []
    for c in range(args.num_classes):
        class_indices = np.where(stats_dict['list_labels'] == c)[0]
        if class_indices.size == 0:
            # np.quantile is undefined on an empty array; nothing to select.
            continue
        within_class = _quantile_band_positions(aus_score_data[class_indices], quantile_range)
        selected_per_class.append(class_indices[within_class])

    if not selected_per_class:
        # np.concatenate refuses an empty list of arrays.
        return np.empty(0, dtype=np.intp)

    return np.concatenate(selected_per_class)

def aus(args, model, criterion, train_loader):
    """Run auto-upsampling and return a train loader that uses the result.

    Steps: save the current model weights, gather per-example statistics,
    score the examples, select positions with ``args.aus_method``
    (``'clustering'``, ``'threshold'`` or ``'quantile'``), map them to
    original dataset indices, repeat them ``args.aus_weight - 1`` times,
    save them to ``aus_indices.pt`` under ``args.local_run_path`` and rebuild
    the data loaders with upsampling switched on.

    Side effects: writes a checkpoint and the index file, and sets
    ``args.us_idx_path`` and ``args.us``.

    Raises:
        NotImplementedError: if ``args.aus_method`` is not supported.
    """
    print(f'Auto-upsampling using: {args.aus_score}.')
    print(f'Clustering using: {args.aus_cluster_by}.')

    # Keep a snapshot of the weights from before the training set changes.
    checkpoint_file = os.path.join(args.ckpt_path, f'epochs={args.aus_epochs}_aus.pt')
    torch.save({'last': model.state_dict()}, checkpoint_file)
    print('\nSaved model before auto-upsampling.\n')

    stats_dict = get_aus_stats(args, model, criterion, train_loader)

    # Position in the collected statistics -> original dataset index.
    position_to_orig = dict(enumerate(stats_dict['list_orig_idx']))

    # NOTE(review): for 'forget_score' the scores are ordered by the indices
    # stored in the forget-score file, while the positions selected below are
    # translated through the loader order -- confirm the two orderings agree.
    aus_score_data, score_reversed = calculate_score(args, stats_dict)

    method = args.aus_method
    if method == 'clustering':
        selected = select_by_clustering(
            args, stats_dict, len(train_loader.dataset), aus_score_data, score_reversed
        )
    elif method == 'threshold':
        selected = select_by_threshold(args, aus_score_data, score_reversed)
    elif method == 'quantile':
        selected = select_by_quantile(args, stats_dict, aus_score_data, score_reversed)
    else:
        raise NotImplementedError(f'Auto upsample by {args.aus_method} is not implemented.')

    selected = [position_to_orig[position].item() for position in selected]

    # List repetition: each selected index appears (aus_weight - 1) times.
    selected = selected * (args.aus_weight - 1)
    if args.aus_shuffle_after:
        selected = np.random.permutation(selected)

    upsample_path = os.path.join(args.local_run_path, 'aus_indices.pt')
    torch.save(torch.tensor(selected), upsample_path)
    print('Upsampling indices saved.')

    # Turn on upsampling and point the data pipeline at the saved indices.
    args.us_idx_path = upsample_path
    args.us = True
    upsampled_loader, _, _, _ = get_dataloaders(args)

    print(f'Auto-upsampling complete. Added {len(selected)} examples.')

    return upsampled_loader