import os
import json
import torch
import pickle
import shutil
import open_clip
import pandas as pd
from tqdm import tqdm
from models import LogisticRegressionT
from sklearn.linear_model import LogisticRegression

import random
# Seed Python's global RNG at import time so the random.sample calls in this
# module (column subsampling in linear_features, misclassified-example picks in
# analyze_prediction) are repeatable for a given call sequence.
random.seed(42)


def load_classifier_list(model_path, question_type, input_dim=768, acc_threshold=0, number_of_features=768):
    """Load per-question PyTorch logistic-regression classifiers, ranked by validation accuracy.

    The question list is read from ``../data/<subfolder>/questions/<question_type>.txt``,
    where ``<subfolder>`` is the second-to-last ``/``-separated component of
    ``model_path`` (so ``model_path`` presumably looks like ``.../<subfolder>/<name>``
    — TODO confirm against callers). One question per non-blank line.

    Args:
        model_path: Directory with one sub-directory per question, each holding
            ``<q>.pt`` (a state dict for ``LogisticRegressionT``) and
            ``<q>_results.txt`` (comma-separated values whose last field is the
            validation accuracy).
        question_type: Name of the question file. ``"random"`` builds untrained
            ``LogisticRegressionT`` models with a placeholder accuracy of 0.5
            instead of loading anything from ``model_path``.
        input_dim: Input dimension passed to ``LogisticRegressionT``.
        acc_threshold: Classifiers with validation accuracy below this are dropped.
        number_of_features: Maximum number of classifiers to keep (highest accuracy first).

    Returns:
        Dict mapping question -> ``(model, val_acc)`` (a ``[model, 0.5]`` list for
        ``"random"``), ordered by descending accuracy and truncated to
        ``number_of_features`` entries.
    """
    subfolder = model_path.split("/")[-2]
    with open(f"../data/{subfolder}/questions/{question_type}.txt", "r") as f:
        # splitlines() also handles "\r\n" endings; blank lines are skipped so they
        # never become a phantom "" question.
        questions = [line.strip() for line in f.read().splitlines() if line.strip()]

    classifier_list = {}
    if question_type == "random":
        for q in questions:
            classifier_list[q] = [LogisticRegressionT(input_dim, 1), 0.5]
    else:
        for q in questions:
            model_full_path = f"{model_path}/{q}/{q}.pt"
            if not os.path.exists(model_full_path):
                print(f"Model for {q} does not exist")
                continue
            curr_model = LogisticRegressionT(input_dim, 1)
            # map_location="cpu" lets checkpoints saved from CUDA tensors load on a
            # CPU-only machine; load_state_dict copies the values into curr_model's
            # existing parameters.
            curr_model.load_state_dict(torch.load(model_full_path, map_location="cpu"))

            with open(f"{model_path}/{q}/{q}_results.txt", "r") as f:
                val_acc = float(f.read().strip().split(",")[-1].strip())

            if val_acc >= acc_threshold:
                classifier_list[q] = (curr_model, val_acc)

    # Stable sort by accuracy (descending), then keep at most number_of_features
    # entries (slicing past the end keeps everything).
    ranked = sorted(classifier_list.items(), key=lambda item: item[1][1], reverse=True)
    return dict(ranked[:number_of_features])


def load_classifier_list_sklearn(model_path, question_type, input_dim=768, acc_threshold=0, number_of_features=768):
    """Load per-question scikit-learn classifiers, ranked by validation accuracy.

    Same layout as ``load_classifier_list``, but each question directory holds a
    pickled model ``<q>.p`` instead of a PyTorch state dict. The question list is
    read from ``../data/<subfolder>/questions/<question_type>.txt`` where
    ``<subfolder>`` is the second-to-last ``/``-separated component of ``model_path``.

    Args:
        model_path: Directory with one sub-directory per question, each holding
            ``<q>.p`` and ``<q>_results.txt`` (comma-separated values whose last
            field is the validation accuracy).
        question_type: Name of the question file. ``"random"`` builds unfitted
            ``LogisticRegression(max_iter=1000)`` models with a placeholder
            accuracy of 0.5.
        input_dim: Unused; kept so the signature matches ``load_classifier_list``.
        acc_threshold: Classifiers with validation accuracy below this are dropped.
        number_of_features: Maximum number of classifiers to keep (highest accuracy first).

    Returns:
        Dict mapping question -> ``(model, val_acc)`` (a ``[model, 0.5]`` list for
        ``"random"``), ordered by descending accuracy and truncated to
        ``number_of_features`` entries.
    """
    subfolder = model_path.split("/")[-2]
    with open(f"../data/{subfolder}/questions/{question_type}.txt", "r") as f:
        # splitlines() also handles "\r\n" endings; blank lines are skipped so they
        # never become a phantom "" question.
        questions = [line.strip() for line in f.read().splitlines() if line.strip()]

    classifier_list = {}
    if question_type == "random":
        for q in questions:
            classifier_list[q] = [LogisticRegression(max_iter=1000), 0.5]
    else:
        for q in questions:
            model_full_path = f"{model_path}/{q}/{q}.p"
            if not os.path.exists(model_full_path):
                print(f"Model for {q} does not exist")
                continue
            # SECURITY: unpickling can execute arbitrary code — only load model
            # files produced by this project's own training runs.
            with open(model_full_path, "rb") as model_file:
                curr_model = pickle.load(model_file)

            with open(f"{model_path}/{q}/{q}_results.txt", "r") as f:
                val_acc = float(f.read().strip().split(",")[-1].strip())

            if val_acc >= acc_threshold:
                classifier_list[q] = (curr_model, val_acc)

    # Stable sort by accuracy (descending), then keep at most number_of_features
    # entries (slicing past the end keeps everything).
    ranked = sorted(classifier_list.items(), key=lambda item: item[1][1], reverse=True)
    return dict(ranked[:number_of_features])


def load_clip_model(model_name):
    """Build an open_clip model, tokenizer and preprocessing transform by short name.

    Known names map to an (architecture, pretrained) pair; ``pretrained=False``
    gives randomly initialised weights. Any other name yields
    ``(None, None, None)`` without touching open_clip.

    Returns:
        ``(clip_model, tokenizer, preprocess)`` — note the tokenizer sits in the
        middle, unlike open_clip's own ``(model, _, preprocess)`` return.
    """
    specs = {
        "whyxrayclip": ('ViT-L-14', "../data/mimic_cxr/clip_model/whyxrayclip.pt"),
        "whylesionclip": ("ViT-L-14", "../data/isic/clip_model/whylesionclip.pt"),
        "openclip": ("ViT-L-14", "laion2b_s32b_b82k"),
        "openclip_random": ("ViT-L-14", False),
        "convnext_random": ("convnext_large_d_320", False),
    }
    if model_name not in specs:
        return None, None, None

    architecture, weights = specs[model_name]
    model, _, transform = open_clip.create_model_and_transforms(architecture, pretrained=weights)
    text_tokenizer = open_clip.get_tokenizer(architecture)
    return model, text_tokenizer, transform


def linear_features(X_train, X_val, ood_features, number_of_features):
    """Select a random subset of raw feature columns, shared by all three splits.

    Draws ``min(number_of_features, X_train.shape[1])`` distinct column indices
    with the module-level ``random`` RNG (so the result depends on its state) and
    applies the same indices to train, val and OOD features.

    Returns:
        ``(df_train, df_val, df_ood)`` DataFrames holding the selected columns.
    """
    total_columns = X_train.shape[1]
    keep = min(number_of_features, total_columns)
    chosen = random.sample(range(0, total_columns), keep)
    return tuple(pd.DataFrame(split[:, chosen]) for split in (X_train, X_val, ood_features))


def binary_features(X_train_features, X_val_features, ood_features, classifier_list):
    """Score every split with each per-question PyTorch classifier.

    Each classifier is put into eval mode (this mutates the caller's model) and
    applied to the train, val and OOD inputs; its flattened outputs become one
    DataFrame column named after the question.

    Args:
        X_train_features, X_val_features, ood_features: Inputs the classifiers
            accept (presumably tensors on the models' device — not checked here).
        classifier_list: Dict mapping question -> sequence whose first element is
            the model (e.g. the result of ``load_classifier_list``).

    Returns:
        ``(df_train, df_val, df_ood)``: one column per question, one row per sample.
    """
    binary_logits_train = {}
    binary_logits_val = {}
    binary_logits_ood = {}

    # Inference only: no_grad avoids building an autograd graph over every split.
    with torch.no_grad():
        for question, entry in classifier_list.items():
            lr_model = entry[0]
            lr_model.eval()

            binary_logits_train[question] = lr_model(X_train_features).detach().cpu().numpy().flatten()
            binary_logits_val[question] = lr_model(X_val_features).detach().cpu().numpy().flatten()
            binary_logits_ood[question] = lr_model(ood_features).detach().cpu().numpy().flatten()

    df_train_log = pd.DataFrame.from_dict(binary_logits_train)
    df_val_log = pd.DataFrame.from_dict(binary_logits_val)
    df_ood_log = pd.DataFrame.from_dict(binary_logits_ood)

    return df_train_log, df_val_log, df_ood_log


def binary_features_sklearn(X_train_features, X_val_features, ood_features, classifier_list):
    """Score every split with each per-question scikit-learn classifier.

    For every question, the classifier's positive-class probability
    (``predict_proba(...)[:, 1]``) on the train, val and OOD inputs becomes one
    DataFrame column named after the question.

    Args:
        X_train_features, X_val_features, ood_features: Feature matrices accepted
            by ``predict_proba``.
        classifier_list: Dict mapping question -> sequence whose first element is
            the fitted model (e.g. the result of ``load_classifier_list_sklearn``).

    Returns:
        ``(df_train, df_val, df_ood)``: one column per question, one row per sample.
    """
    splits = (X_train_features, X_val_features, ood_features)
    columns_per_split = ({}, {}, {})

    for question, entry in classifier_list.items():
        classifier = entry[0]
        for columns, inputs in zip(columns_per_split, splits):
            columns[question] = classifier.predict_proba(inputs)[:, 1]

    return tuple(pd.DataFrame.from_dict(columns) for columns in columns_per_split)


def dot_product_features(X_train, X_val, ood_features, classifier_list, clip_model, tokenizer):
    """Project image features onto L2-normalised CLIP text embeddings of the concepts.

    The keys of ``classifier_list`` are used as text prompts. They are tokenised,
    encoded with ``clip_model.encode_text`` under ``no_grad`` and CUDA autocast,
    normalised to unit length (in place), converted to NumPy, and every split is
    multiplied by the transposed embedding matrix.

    The encoded text must end up as a CPU tensor, since ``.numpy()`` is called on
    it directly.

    Returns:
        ``(df_train, df_val, df_ood)``: one column per prompt (in key order),
        labelled 0..k-1, one row per sample.
    """
    prompts = [*classifier_list]
    with torch.no_grad(), torch.cuda.amp.autocast():
        tokens = tokenizer(prompts)
        embeddings = clip_model.encode_text(tokens)
        embeddings /= embeddings.norm(dim=-1, keepdim=True)

    concept_matrix = embeddings.numpy().T
    return tuple(pd.DataFrame(feats @ concept_matrix) for feats in (X_train, X_val, ood_features))


def pcbm_features(X_train, X_val, ood_features, classifier_list, clip_model, tokenizer, preprocess, number_of_features, num_linear_features=768):
    """Concatenate random raw-feature columns with CLIP concept-similarity columns.

    The first block comes from ``linear_features`` (``num_linear_features`` randomly
    chosen raw columns, capped at the feature width). The second comes from
    ``dot_product_features`` (one column per key of ``classifier_list``).

    ``preprocess`` and ``number_of_features`` are accepted for signature
    compatibility but are not used here.

    Returns:
        ``(df_train, df_val, df_ood)`` with columns labelled 0..N-1
        (linear block first, then the dot-product block).
    """
    # Called in this order so the module RNG is consumed exactly as before.
    lin_splits = linear_features(X_train, X_val, ood_features, num_linear_features)
    dot_splits = dot_product_features(X_train, X_val, ood_features, classifier_list, clip_model, tokenizer)

    # ignore_index=True renumbers the columns 0..N-1. Without it the linear block
    # (0..d-1) and the dot-product block (0..k-1) share integer labels, leaving
    # duplicate column names in the concatenated frames.
    return tuple(
        pd.concat([lin, dot], axis=1, ignore_index=True)
        for lin, dot in zip(lin_splits, dot_splits)
    )


def _load_split_rows(split_root, label):
    """Concatenate the rows of every ``torch.load``-able file in ``<split_root>/<label>``.

    Files are visited in sorted name order so the row order does not depend on
    ``os.listdir``, whose ordering is arbitrary.
    """
    label_dir = f"{split_root}/{label}"
    rows = []
    for file_name in sorted(os.listdir(label_dir)):
        # Each file is assumed to hold an iterable of per-sample feature vectors
        # (e.g. a list or a 2-D tensor) — TODO confirm the saved format.
        rows.extend(torch.load(f"{label_dir}/{file_name}", map_location="cpu"))
    return rows


def load_features(feature_path, label2index, shots, normalize, random_seed):
    """Load cached per-class features for the train, val and test (OOD) splits.

    Expects ``<feature_path>/{train,val,test}/<label>/<files>`` for every label in
    ``label2index``; rows from all files of a class are concatenated and labelled
    with ``label2index[label]``.

    Train and val rows are shuffled with ``random_state=random_seed``; test rows
    keep their load order. When ``shots != "all"``, ``int(shots)`` rows per class
    are sampled from train (pandas raises ``ValueError`` if a class has fewer)
    and the result is reshuffled.

    Args:
        feature_path: Root directory of the cached features.
        label2index: Mapping from class folder name to integer label.
        shots: ``"all"`` or a per-class sample count (int or numeric string).
        normalize: If true, L2-normalise each feature row in place.
        random_seed: Seed passed to every pandas sampling call.

    Returns:
        ``(X_train, y_train, X_val, y_val, ood_features, ood_label)`` where the
        ``X``/``ood_features`` are float tensors, ``y_train``/``y_val`` are
        NumPy arrays and ``ood_label`` is a plain list.
    """
    train_root = f"{feature_path}/train"
    val_root = f"{feature_path}/val"
    ood_root = f"{feature_path}/test"

    train_rows, val_rows, ood_rows = [], [], []
    train_labels, val_labels, ood_label = [], [], []

    for label in tqdm(list(label2index.keys())):
        class_index = label2index[label]

        new_train = _load_split_rows(train_root, label)
        train_rows.extend(new_train)
        train_labels.extend([class_index] * len(new_train))

        new_val = _load_split_rows(val_root, label)
        val_rows.extend(new_val)
        val_labels.extend([class_index] * len(new_val))

        new_ood = _load_split_rows(ood_root, label)
        ood_rows.extend(new_ood)
        ood_label.extend([class_index] * len(new_ood))

    df_train = pd.DataFrame(train_rows)
    df_train["labels"] = train_labels

    df_val = pd.DataFrame(val_rows)
    df_val["labels"] = val_labels

    df_ood = pd.DataFrame(ood_rows)
    df_ood["labels"] = ood_label

    df_train = df_train.sample(frac=1, random_state=random_seed).reset_index(drop=True)
    df_val = df_val.sample(frac=1, random_state=random_seed).reset_index(drop=True)

    if shots != "all":
        # Sample a fixed number of shots per class, then reshuffle with the same seed.
        df_train = df_train.groupby("labels").sample(n=int(shots), random_state=random_seed).reset_index(drop=True)
        df_train = df_train.sample(frac=1, random_state=random_seed).reset_index(drop=True)

    # "labels" is the last column; everything before it is the feature vector.
    X_train = torch.tensor(df_train.iloc[:, :-1].values).float()
    y_train = df_train["labels"].values

    X_val = torch.tensor(df_val.iloc[:, :-1].values).float()
    y_val = df_val["labels"].values

    ood_features = torch.tensor(df_ood.iloc[:, :-1].values).float()

    if normalize:
        print("Normalizing the features")
        X_train /= X_train.norm(dim=-1, keepdim=True)
        X_val /= X_val.norm(dim=-1, keepdim=True)
        ood_features /= ood_features.norm(dim=-1, keepdim=True)

    return X_train, y_train, X_val, y_val, ood_features, ood_label


def get_prior_matrix(modality, class_names, questions):
    """Build a (len(class_names), len(questions)) prior matrix from annotated answers.

    Reads ``../data/binary_<modality>/questions/class2question2answer.json`` — a
    nested mapping ``class -> question -> answer`` — and encodes each answer as
    ``"yes" -> 1.0`` and ``"no" -> -1.0``. ``"unknown"`` and any other answer
    leave the entry at 0.0.

    Raises:
        KeyError: if a requested class or question is missing from the JSON file.
    """
    prior = torch.zeros(len(class_names), len(questions))
    with open(f"../data/binary_{modality}/questions/class2question2answer.json", "r") as f:
        class2question2answer = json.load(f)

    for i, c in enumerate(class_names):
        for j, q in enumerate(questions):
            answer = class2question2answer[c][q]
            if answer == "yes":
                prior[i, j] = 1.0
            elif answer == "no":
                prior[i, j] = -1.0
            # "unknown" (and anything unrecognised) stays at the initial 0.0.

    return prior


def compute_prior_loss(model):
    """Mean absolute gap between tanh-squashed classifier weights and ``model.prior``.

    ``model.linear.weight`` (2-D) is passed through ``tanh`` so its values lie in
    [-1, 1], compared element-wise with ``model.prior``, and the summed L1 gap is
    divided by the number of weight entries (rows * columns).

    Returns:
        A scalar tensor (differentiable with respect to the weights).
    """
    weight = model.linear.weight
    entry_count = weight.shape[0] * weight.shape[1]
    gap = torch.abs(torch.tanh(weight) - model.prior)
    return torch.sum(gap) / entry_count


def map_weights(model, class_names, questions):
    """Return ``{class: {question: weight}}`` read from ``model.linear.weight``.

    Row ``i`` of the weight matrix belongs to ``class_names[i]`` and column ``j``
    to ``questions[j]``. Each inner dict is ordered by the *signed* weight,
    largest first (not by absolute value). The model is switched to eval mode as
    a side effect.
    """
    model.eval()
    weight_matrix = model.linear.weight

    per_class = {}
    for row, class_name in enumerate(class_names):
        per_class[class_name] = {}
        for col, question in enumerate(questions):
            per_class[class_name][question] = weight_matrix[row, col].item()

    ordered = {}
    for class_name, question_weights in per_class.items():
        ranked = sorted(question_weights.items(), key=lambda pair: pair[1], reverse=True)
        ordered[class_name] = dict(ranked)
    return ordered


def analyze_prediction(dataset_name, model, X, y, class_names, questions):
    """Collect per-question explanations for a sample of misclassified test images.

    For each true class, up to 5 misclassified indices are drawn with the
    module-level ``random`` RNG. Each drawn index gets a record with:
    the image path, the true and predicted class names, the per-question feature
    values ("grounding"), the weights of the true and predicted classes, and the
    per-question contributions (feature * weight) to both classes. Every question
    is listed in each section, sorted in descending order.

    Each selected image is copied (once) into ``../data/predictions/images/``,
    created if missing, and its ``image_path`` is rewritten to ``images/<file>``.

    Args:
        dataset_name: Selects the image directory and the split file
            ``../data/datasets/<dataset_name>/splits/class2images_test.p``.
        model: Callable on ``X`` with a ``linear`` layer whose weight is indexed
            as ``[class, question]`` below.
        X: Per-sample question features; row ``i`` is assumed to correspond to
            ``all_images[i]`` (images concatenated in ``class_names`` order) —
            TODO confirm the caller keeps that alignment.
        y: True class indices, indexable by position.
        class_names: Class names indexed by class id.
        questions: Question strings, one per feature column.

    Returns:
        Dict mapping sample index -> explanation record.
    """
    if "ISIC" in dataset_name: image_dir = "../data/datasets/isic/images/"
    elif "PAD" in dataset_name: image_dir = "../data/datasets/PAD-UFES-20/images/"
    elif dataset_name == "HAM10000": image_dir = "../data/datasets/HAM10000/images/"
    elif dataset_name == "Melanoma": image_dir = "../data/datasets/Melanoma/"
    elif dataset_name == "UWaterloo": image_dir = "../data/datasets/UWaterloo/"
    elif dataset_name == "BCN20000": image_dir = "../data/datasets/isic/images/"
    else: image_dir = "../data/datasets/"

    def by_value_desc(mapping):
        # Stable sort of a {key: number} dict by value, largest first.
        return dict(sorted(mapping.items(), key=lambda item: item[1], reverse=True))

    model.eval()
    weight = model.linear.weight

    # Trusted local split file; unpickling can execute arbitrary code, so never
    # point this at untrusted data.
    with open(f"../data/datasets/{dataset_name}/splits/class2images_test.p", "rb") as f:
        class2images = pickle.load(f)
    all_images = [f"{image_dir}{img}" for ll in class_names for img in class2images[ll]]

    with torch.no_grad():
        y_pred = torch.argmax(model(X), dim=1)

    # Group misclassified indices by true class (dict order = first occurrence).
    class2wrongly_predicted_indices = {}
    for i in range(len(y)):
        if y[i] != y_pred[i]:
            class2wrongly_predicted_indices.setdefault(y[i].item(), []).append(i)

    wrongly_predicted_indices = []
    for indices in class2wrongly_predicted_indices.values():
        wrongly_predicted_indices.extend(random.sample(indices, min(5, len(indices))))

    prediction_info = {}
    for i in wrongly_predicted_indices:
        true_c, pred_c = y[i], y_pred[i]
        features = X[i]
        prediction_info[i] = {"image_path": all_images[i], "true_label": class_names[true_c], "predicted_label": class_names[pred_c]}

        # Question -> raw feature value, highest first (all questions, not just top-5).
        grounding = by_value_desc({questions[j]: features[j].item() for j in range(len(questions))})
        prediction_info[i]["grounding"] = {k: round(v, 3) for k, v in grounding.items()}

        # Weights are rounded before sorting, matching the values that are reported.
        prediction_info[i]["weight_for_true"] = by_value_desc({questions[j]: round(weight[true_c, j].item(), 3) for j in range(len(questions))})
        prediction_info[i]["weight_for_pred"] = by_value_desc({questions[j]: round(weight[pred_c, j].item(), 3) for j in range(len(questions))})

        # Element-wise feature * weight: each question's contribution to a class score.
        contribution_true = features * weight[true_c]
        contribution_pred = features * weight[pred_c]
        order_true = torch.argsort(contribution_true, descending=True)
        order_pred = torch.argsort(contribution_pred, descending=True)

        prediction_info[i]["contribution_for_true"] = {questions[j]: {"contribution": round(contribution_true[j].item(), 3), "weight": round(weight[true_c, j].item(), 3), "grounding": round(grounding[questions[j]], 3)} for j in order_true}
        prediction_info[i]["contribution_for_pred"] = {questions[j]: {"contribution": round(contribution_pred[j].item(), 3), "weight": round(weight[pred_c, j].item(), 3), "grounding": round(grounding[questions[j]], 3)} for j in order_pred}

    if prediction_info:
        dest_dir = "../data/predictions/images"
        # shutil.copy fails if the destination directory does not exist.
        os.makedirs(dest_dir, exist_ok=True)
        for info in prediction_info.values():
            file_name = info["image_path"].split("/")[-1]
            dest_path = f"{dest_dir}/{file_name}"
            if not os.path.exists(dest_path):
                shutil.copy(info["image_path"], dest_path)
            info["image_path"] = f"images/{file_name}"

    return prediction_info


def get_diversity_score(concepts, device="cuda:0"):
    """Return a diversity score for a list of concept strings.

    Concepts are embedded with the ``all-mpnet-base-v2`` SentenceTransformer on
    ``device``, and the pairwise cosine distances (1 - cosine similarity) are
    computed. The diagonal is then set to 1 and each row is averaged. The mean of
    those row averages is returned as a Python float.

    NOTE(review): setting the diagonal to 1 suits a per-row *minimum* (it keeps
    each concept's ~0 self-distance from winning). With the per-row *mean* used
    here, it instead adds roughly 1/len(concepts) to the score. This is kept
    unchanged so existing scores stay comparable — confirm this is intended.

    Args:
        concepts: Sequence of concept strings.
        device: Torch device for the SentenceTransformer (defaults to the
            previously hard-coded ``"cuda:0"``).
    """
    from sentence_transformers import SentenceTransformer, util
    sbert_model = SentenceTransformer('all-mpnet-base-v2', device=device)

    sentence_embeddings = sbert_model.encode(concepts, convert_to_tensor=True)

    cosine_distance = 1 - util.pytorch_cos_sim(sentence_embeddings, sentence_embeddings)
    cosine_distance.fill_diagonal_(1)

    row_means = cosine_distance.mean(dim=1)
    return row_means.mean().item()