import json
import copy
import torch
import torch.nn as nn
from torch import optim
import pandas as pd
from argparse import ArgumentParser
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from models import MultiClassLogisticRegression, PosthocHybridCBM
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import Accuracy
from utils import *
import random
# Seed Python's `random` module and PyTorch's global RNG at import time so that
# training (batch shuffling, weight init) is repeatable across runs.
# NOTE(review): numpy is not seeded here; presumably load_features handles its own
# sampling through its `random_seed` argument — confirm.
random.seed(42)
torch.manual_seed(42)

def get_results_sklearn(df_train_log, y_train, df_val_log, y_val, df_ood_log, y_ood):
    """Fit a scikit-learn logistic regression and score it in- and out-of-distribution.

    Args:
        df_train_log, y_train: training features and labels.
        df_val_log, y_val: in-distribution evaluation split.
        df_ood_log, y_ood: out-of-distribution evaluation split.

    Returns:
        (val_acc, ood_acc, gap, average_acc): the two accuracies as percentages rounded
        to two decimals, then their absolute difference and their mean. Both of the
        last two are computed from the already-rounded accuracies and rounded again.
    """
    classifier = LogisticRegression(max_iter=1000, verbose=1, n_jobs=16)
    classifier.fit(df_train_log, y_train)

    def percent_correct(features, labels):
        # Fraction of correct predictions expressed as a percentage, two decimals.
        return round(accuracy_score(labels, classifier.predict(features)) * 100, 2)

    in_dist_acc = percent_correct(df_val_log, y_val)
    out_dist_acc = percent_correct(df_ood_log, y_ood)
    return (
        in_dist_acc,
        out_dist_acc,
        round(abs(in_dist_acc - out_dist_acc), 2),
        round((in_dist_acc + out_dist_acc) / 2, 2),
    )


def train_model_torch(model, train_dataloader, val_dataloader, optimizer, criterion, num_epochs, prior_weight=1.0):
    """Train ``model`` and return the checkpoint with the best validation accuracy.

    Args:
        model: torch module trained in place.
        train_dataloader: yields ``(features, labels)`` training batches.
        val_dataloader: yields ``(features, labels)`` batches used for model selection.
        optimizer: optimizer already bound to ``model.parameters()``.
        criterion: loss called as ``criterion(outputs, labels.long())``.
        num_epochs: number of passes over ``train_dataloader``.
        prior_weight: multiplier on ``compute_prior_loss(model)``. It is only used when
            the model has a truthy ``apply_prior`` attribute.

    Returns:
        A deep copy of ``model`` taken at the epoch with the highest validation accuracy
        (the earliest such epoch on ties). If no epoch runs, returns a deep copy of the
        untrained model instead of ``None``.
    """
    # Not every model class is guaranteed to define ``apply_prior``: the pcbm model is
    # constructed without it. Treat a missing attribute as "no prior" instead of failing.
    use_prior = bool(getattr(model, "apply_prior", False))

    best_val_acc = -float("inf")
    best_model = None
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for X_batch, y_batch in train_dataloader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch.long())

            if use_prior:
                loss = loss + prior_weight * compute_prior_loss(model)

            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # evaluate_model_torch switches the model to eval mode; model.train() above resets it.
        val_acc = evaluate_model_torch(model, val_dataloader)
        mean_train_loss = train_loss / max(len(train_dataloader), 1)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {mean_train_loss}, Val Acc: {val_acc}")

        # Strict ">" keeps the earliest epoch among ties.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = copy.deepcopy(model)

    if best_model is None:
        best_model = copy.deepcopy(model)
    return best_model


def evaluate_model_torch(model, dataloader):
    """Return the top-1 accuracy of ``model`` on ``dataloader`` as a percentage.

    The model is switched to eval mode and left there. Gradients are not tracked.

    Args:
        model: torch module whose output has one score per class along dim 1.
        dataloader: iterable of ``(features, labels)`` batches with class-index labels.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            predicted = model(X_batch).argmax(dim=1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()

    if total == 0:
        raise ValueError("cannot compute accuracy on an empty dataloader")
    return 100 * correct / total


def get_results_torch(mode, dataset_name, modality, label2index, classifier_list, df_train_log, y_train, df_val_log, y_val, df_ood_log, y_ood, add_prior, batch_size, learning_rate, num_epochs):
    """Train a torch classifier on precomputed features and score it in- and out-of-distribution.

    Args:
        mode: "pcbm" trains a ``PosthocHybridCBM``. Any other value trains a
            ``MultiClassLogisticRegression`` built with the concept prior matrix.
        dataset_name: accepted for interface compatibility; not used.
        modality: forwarded to ``get_prior_matrix``.
        label2index: mapping whose keys are the class names.
        classifier_list: mapping whose keys are the concept questions.
        df_train_log, df_val_log, df_ood_log: feature DataFrames (``.values`` is used).
        y_train, y_val, y_ood: class-index labels for the matching split.
        add_prior: forwarded as ``apply_prior`` to the logistic-regression model.
        batch_size, learning_rate, num_epochs: training hyper-parameters.

    Returns:
        (val_acc, ood_acc, gap, average_acc, best_model). The accuracies are percentages
        rounded to two decimals, like ``get_results_sklearn``. ``gap`` is
        |val_acc - ood_acc| and ``average_acc`` is their mean.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def to_device_tensors(features_df, labels):
        # Each split is moved to the device once; the DataLoader then only slices it.
        return torch.tensor(features_df.values).float().to(device), torch.tensor(labels).to(device)

    X_train_torch, y_train_torch = to_device_tensors(df_train_log, y_train)
    X_val_torch, y_val_torch = to_device_tensors(df_val_log, y_val)
    X_ood_torch, y_ood_torch = to_device_tensors(df_ood_log, y_ood)

    train_loader = DataLoader(TensorDataset(X_train_torch, y_train_torch), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val_torch, y_val_torch), batch_size=batch_size, shuffle=False)
    ood_loader = DataLoader(TensorDataset(X_ood_torch, y_ood_torch), batch_size=batch_size, shuffle=False)

    class_names = list(label2index.keys())
    questions = list(classifier_list.keys())
    print(class_names)

    # CrossEntropyLoss uses the labels directly as output indices, so the head needs at
    # least max_label + 1 outputs. It should also cover every class in label2index,
    # which is what the prior matrix is built from. Counting the unique training labels
    # under-sizes the head whenever a class is missing from a (few-shot) training split.
    num_classes = max(len(class_names), int(y_train_torch.max().item()) + 1)

    if mode == "pcbm":
        model = PosthocHybridCBM(n_concepts=len(questions),
                                 n_classes=num_classes,
                                 n_image_features=X_train_torch.shape[1] - len(questions))
    else:
        # The prior is only consumed by the logistic-regression model.
        prior = get_prior_matrix(modality, class_names, questions)
        model = MultiClassLogisticRegression(num_features=X_train_torch.shape[1],
                                             num_classes=num_classes,
                                             prior=prior,
                                             apply_prior=add_prior)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_model = train_model_torch(model, train_loader, val_loader, optimizer, criterion, num_epochs)
    if best_model is None:
        # No epoch ran (num_epochs <= 0): fall back to the untrained model.
        best_model = model

    val_acc = round(evaluate_model_torch(best_model, val_loader), 2)
    ood_acc = round(evaluate_model_torch(best_model, ood_loader), 2)
    average_acc = round((val_acc + ood_acc) / 2, 2)
    gap = round(abs(val_acc - ood_acc), 2)

    return val_acc, ood_acc, gap, average_acc, best_model


def run_classification(modality, dataset_name, model_name, classifier_list, mode, shots, clip_model, tokenizer, number_of_features, normalize, implementaion, add_prior, random_seed=42, preprocess=None):
    """Load one dataset, build bottleneck features for ``mode``, then train and evaluate a classifier.

    Args:
        modality: forwarded to the torch implementation.
        dataset_name, model_name: select the files under ``../data/features/{model_name}/``.
        classifier_list: concept classifiers keyed by question.
        mode: one of "binary", "linear_probe", "dot_product", "pcbm".
        shots, normalize, random_seed: forwarded to ``load_features``.
        clip_model, tokenizer: used by the "dot_product" and "pcbm" feature builders.
        number_of_features: forwarded to the "linear_probe" and "pcbm" feature builders.
        implementaion: "sklearn" or "torch". The misspelt name is kept for compatibility.
        add_prior: forwarded to the torch implementation.
        preprocess: passed to ``pcbm_features``. When omitted, the module-level
            ``preprocess`` defined by the script entry point is used.

    Returns:
        (val_acc, ood_acc, gap, average_acc, number_of_features_actual).

    Raises:
        ValueError: for an unknown ``mode`` or ``implementaion``, or for "pcbm" mode
            when no ``preprocess`` is available.
    """
    # Validate up front so a typo fails before the (potentially expensive) feature load.
    if mode not in ("binary", "linear_probe", "dot_product", "pcbm"):
        raise ValueError(f"Unknown mode {mode!r}")
    if implementaion not in ("sklearn", "torch"):
        raise ValueError(f"Unknown implementation {implementaion!r}")
    if mode == "pcbm" and preprocess is None:
        # The script entry point stores ``preprocess`` as a module global. Fall back to it
        # so existing callers keep working, but fail clearly when it does not exist.
        preprocess = globals().get("preprocess")
        if preprocess is None:
            raise ValueError("mode 'pcbm' requires a `preprocess` transform")

    # Load the features
    label2index = torch.load(f"../data/features/{model_name}/{dataset_name}_label.pt")
    X_train, y_train, X_val, y_val, X_ood, y_ood = load_features(f"../data/features/{model_name}/{dataset_name}", label2index, shots, normalize, random_seed)

    if mode == "binary":
        df_train_log, df_val_log, df_ood_log = binary_features_sklearn(X_train, X_val, X_ood, classifier_list)
    elif mode == "linear_probe":
        df_train_log, df_val_log, df_ood_log = linear_features(X_train, X_val, X_ood, number_of_features)
    elif mode == "dot_product":
        df_train_log, df_val_log, df_ood_log = dot_product_features(X_train, X_val, X_ood, classifier_list, clip_model, tokenizer)
    else:  # "pcbm"
        df_train_log, df_val_log, df_ood_log = pcbm_features(X_train, X_val, X_ood, classifier_list, clip_model, tokenizer, preprocess, number_of_features)

    print("Train size: ", df_train_log.shape, "Test size: ", df_val_log.shape, "OOD size: ", df_ood_log.shape)

    if implementaion == "sklearn":
        val_acc, ood_acc, gap, average_acc = get_results_sklearn(df_train_log, y_train, df_val_log, y_val, df_ood_log, y_ood)
    else:  # "torch"
        val_acc, ood_acc, gap, average_acc, _ = get_results_torch(mode, dataset_name, modality, label2index, classifier_list, df_train_log, y_train, df_val_log, y_val, df_ood_log, y_ood, add_prior, batch_size=64, learning_rate=0.001, num_epochs=200)

    print(f"Dataset: {dataset_name}, Mode: {mode}", f"Shots: {shots}", f"Model: {model_name}")
    print(f"Ind Acc: {val_acc}, OOD Acc: {ood_acc}, Gap: {gap}, Average: {average_acc}")
    number_of_features_actual = df_train_log.shape[1]

    return val_acc, ood_acc, gap, average_acc, number_of_features_actual


def ablate_bottleneck_size(modality, dataset_lists, model_name, classifier_list, mode, shots, clip_model, tokenizer, number_of_features, normalize, implementaion, binary_model_path, add_prior, save_suffix, question_type=None, num_runs=10):
    """Sweep the bottleneck size (15, 30, ..., 150) and save per-size, per-run accuracies to JSON.

    For "linear_probe" the size is the number of raw features kept, and the full
    ``classifier_list`` is passed every time. For every other mode, each run draws a
    random subset of that many concept classifiers, using ``random.seed(run)`` so the
    draws are reproducible.

    The JSON file is rewritten after every dataset, so a partial sweep is kept on disk.
    ``number_of_features``, ``binary_model_path`` and ``save_suffix`` are accepted for
    signature compatibility but are not used.

    Args:
        question_type: used in the output file name for non-linear-probe modes. It
            defaults to the module-level ``question_type`` set by the script entry point.
        num_runs: number of repetitions per bottleneck size.

    Raises:
        ValueError: if a non-linear-probe sweep has no ``question_type`` available.
    """
    import os

    if question_type is None:
        question_type = globals().get("question_type")

    output_dir = "../data/bottleneck_size"
    if mode == "linear_probe":
        output_path = f"{output_dir}/bottleneck_size_{modality}_{mode}_{shots}.json"
    else:
        if question_type is None:
            raise ValueError("ablate_bottleneck_size needs `question_type` for its output file name")
        output_path = f"{output_dir}/bottleneck_size_{modality}_{mode}_{question_type}_{shots}.json"
    os.makedirs(output_dir, exist_ok=True)

    results_dict = {}
    for number_of_classifiers in range(15, 165, 15):
        if mode != "linear_probe" and number_of_classifiers > len(classifier_list):
            # random.sample cannot draw more items than exist; larger sizes are impossible too.
            print(f"Stopping sweep: only {len(classifier_list)} classifiers available, cannot sample {number_of_classifiers}")
            break
        print("number of classifiers: ", number_of_classifiers)
        results_dict[f"{number_of_classifiers}"] = {i: {} for i in range(num_runs)}
        for run in range(num_runs):
            if mode == "linear_probe":
                curr_classifier_list = classifier_list
            else:
                # Reseed per run so each run draws the same subset on every sweep.
                random.seed(run)
                curr_classifier_names = random.sample(list(classifier_list.keys()), number_of_classifiers)
                curr_classifier_list = {k: classifier_list[k] for k in curr_classifier_names}
            for dataset_name in dataset_lists:
                ind_acc, out_acc, gap, avg, _ = run_classification(modality, dataset_name, model_name, curr_classifier_list, mode, shots, clip_model, tokenizer, number_of_classifiers, normalize, implementaion, add_prior)
                results_dict[f"{number_of_classifiers}"][run][dataset_name] = {"ind_acc": ind_acc, "out_acc": out_acc, "gap": gap, "avg": avg}
                with open(output_path, "w") as f:
                    json.dump(results_dict, f, indent=4)


def few_shot(dataset_lists, model_name, classifier_list, mode, shots, clip_model, tokenizer, number_of_features, normalize, modality=None, implementation=None, add_prior=None, question_type=None, num_runs=5):
    """Evaluate every dataset over a sweep of shot budgets and save the results to JSON.

    The sweep always covers 1, 2, 4, 8, 16, 32 and "all" shots, so the ``shots``
    argument is accepted for signature compatibility but ignored. Each run ``r`` uses
    ``random_seed=r`` when loading features.

    Args:
        modality, implementation, add_prior, question_type: when omitted, these fall back
            to the module-level values set by the script entry point. ``implementation``
            then defaults to "torch" and ``add_prior`` to False if no such value exists.
        num_runs: number of repetitions per shot budget.

    Raises:
        ValueError: if ``modality`` or ``question_type`` cannot be determined.
    """
    import os

    module_globals = globals()
    if modality is None:
        modality = module_globals.get("modality")
    if implementation is None:
        implementation = module_globals.get("implementation", "torch")
    if add_prior is None:
        add_prior = module_globals.get("add_prior", False)
    if question_type is None:
        question_type = module_globals.get("question_type")
    if modality is None or question_type is None:
        raise ValueError("few_shot needs `modality` and `question_type` when not run as a script")

    results_dict = {}
    for shot_setting in ["1", "2", "4", "8", "16", "32", "all"]:
        print("Number of shots: ", shot_setting)
        results_dict[shot_setting] = {i: {} for i in range(num_runs)}
        for run in range(num_runs):
            for dataset_name in dataset_lists:
                ind_acc, out_acc, gap, avg, _ = run_classification(modality, dataset_name, model_name, classifier_list, mode, shot_setting, clip_model, tokenizer, number_of_features, normalize, implementation, add_prior, random_seed=run)
                results_dict[shot_setting][run][dataset_name] = {"ind_acc": ind_acc, "out_acc": out_acc, "gap": gap, "avg": avg}

    output_path = f"../data/fewshot_results/{mode}_{question_type}.json"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(results_dict, f)


def run_all_datasets(modality, dataset_lists, model_name, classifier_list, mode, shots, clip_model, tokenizer, number_of_features, normalize, implementaion, binary_model_path, add_prior, save_suffix, question_type=None):
    """Evaluate every dataset once and save all metrics as a single-row CSV.

    The CSV has one column per ``<dataset>_<metric>`` pair. It is written under
    ``../data/cbm_results_{implementaion}_{save_suffix}/``, and the file name encodes the
    run configuration. The feature count in the name comes from the last dataset
    evaluated.

    Args:
        binary_model_path: only its last path component is used in the file name.
        question_type: used in the file name. It defaults to the module-level
            ``question_type`` set by the script entry point.

    Raises:
        ValueError: if ``dataset_lists`` is empty or no ``question_type`` is available.
    """
    import os

    if not dataset_lists:
        raise ValueError("dataset_lists is empty; nothing to evaluate")
    if question_type is None:
        question_type = globals().get("question_type")
    if question_type is None:
        # Fail before the expensive runs rather than when building the file name.
        raise ValueError("run_all_datasets needs `question_type` for its output file name")

    results_dict = {}
    for dataset_name in dataset_lists:
        print("Number of classifiers: ", len(classifier_list))
        ind_acc, out_acc, gap, avg, number_of_features_actual = run_classification(modality, dataset_name, model_name, classifier_list, mode, shots, clip_model, tokenizer, number_of_features, normalize, implementaion, add_prior)
        results_dict[dataset_name] = {"ind_acc": ind_acc, "out_acc": out_acc, "gap": gap, "avg": avg}

    # One row per dataset, one column per metric.
    csv_df = pd.DataFrame.from_dict(results_dict).T
    # Flatten row-major into a single row; column names follow the same order.
    reshaped_df = pd.DataFrame(csv_df.values.flatten()).T
    reshaped_df.columns = [f"{row_label}_{col_label}" for row_label in csv_df.index for col_label in csv_df.columns]

    # normpath strips a trailing slash that would otherwise leave the last component empty.
    binary_model_name = os.path.basename(os.path.normpath(binary_model_path))
    prior_suffix = "_prior" if add_prior else ""
    file_name = f"../data/cbm_results_{implementaion}_{save_suffix}/{modality}_{mode}_{model_name}_{question_type}_{shots}_{number_of_features_actual}_{binary_model_name}{prior_suffix}.csv"

    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    reshaped_df.to_csv(file_name)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, default="binary", help="binary, linear_probe, dot_product, pcbm")
    parser.add_argument("--question_type", type=str, default="xray_prompt_200", help="findings, doctor, gpt4")
    parser.add_argument("--shots", type=str, default="all", help="all, 1, 2, 4, 8, 16, 32, 64")
    parser.add_argument("--model_name", type=str, default="whyxrayclip", help="whyxrayclip, whylesionclip")
    parser.add_argument("--acc_threshold", type=float, default=0, help="accuracy threshold for loading classifiers")
    parser.add_argument("--number_of_features", type=int, default=768, help="number of features to select for linear probe")
    parser.add_argument("--eval_type", type=str, default="normal", help="normal, ablate_bottleneck_size, fewshot")
    parser.add_argument("--normalize", type=str, default="True", help="normalize the features")
    parser.add_argument("--modality", type=str, default="xray", help="xray, natural, skin")
    parser.add_argument("--binary_model_path", type=str, default="../data/binary_xray/models_whyxrayclip_t5_1000_sklearn", help="path to the binary models")
    parser.add_argument("--input_dim", type=int, default=768, help="input dimension for the binary classifiers")
    parser.add_argument("--implementation", type=str, default="torch", help="sklearn, torch")
    parser.add_argument("--add_prior", type=str, default="False", help="add prior to the model")
    parser.add_argument("--save_suffix", type=str, default="", help="add suffix to the save folder")

    args = parser.parse_args()

    # Datasets evaluated for each modality. An unknown modality is rejected here,
    # before the expensive model loads below.
    datasets_by_modality = {
        "xray": ["NIH-gender", "NIH-age", "NIH-pos", "CheXpert-race", "NIH-CheXpert", "pneumonia", "COVID-QU", "NIH-CXR", "open-i", "vindr-cxr"],
        "natural": ["STL-10", "imagenet10", "CIFAR10", "food", "flower"],
        "skin": ["ISIC-gender", "ISIC-age", "ISIC-site", "ISIC-color", "ISIC-hospital", "HAM10000", "BCN20000", "PAD-UFES-20", "Melanoma", "UWaterloo"],
    }
    if args.modality not in datasets_by_modality:
        parser.error(f"unknown --modality {args.modality!r}; expected one of {sorted(datasets_by_modality)}")

    # The module-level names below are also read as globals by some helper functions
    # (e.g. question_type, preprocess), so they stay at module scope.
    false_strings = {"false", "0", "no"}

    mode = args.mode
    question_type = args.question_type
    shots = args.shots
    acc_threshold = args.acc_threshold
    model_name = args.model_name
    number_of_features = args.number_of_features
    eval_type = args.eval_type
    # Any false-like spelling ("False", "false", "0", "no") disables the flag.
    normalize = args.normalize.strip().lower() not in false_strings
    modality = args.modality
    binary_model_path = args.binary_model_path
    input_dim = args.input_dim
    implementation = args.implementation
    add_prior = args.add_prior.strip().lower() not in false_strings
    save_suffix = args.save_suffix
    dataset_lists = datasets_by_modality[modality]

    # Load clip model
    clip_model, tokenizer, preprocess = load_clip_model(model_name)

    # Load classifiers
    classifier_list = load_classifier_list_sklearn(binary_model_path, question_type, input_dim, acc_threshold, number_of_features)
    binary_accuracies = [classifier[1] for classifier in classifier_list.values()]
    # Avoid a ZeroDivisionError when the accuracy threshold filters out every classifier.
    mean_binary_acc = round(sum(binary_accuracies) / len(binary_accuracies), 5) if binary_accuracies else float("nan")
    print(f"Number of classifiers: {len(classifier_list)}, Mean Acc: {mean_binary_acc}")

    diversity = get_diversity_score(list(classifier_list.keys()))
    print(question_type)
    print(f"Diversity: {diversity}")

    if eval_type == "ablate_bottleneck_size":
        ablate_bottleneck_size(modality, dataset_lists, model_name, classifier_list, mode, shots, clip_model, tokenizer, number_of_features, normalize, implementation, binary_model_path, add_prior, save_suffix)

    elif eval_type == "fewshot":
        few_shot(dataset_lists, model_name, classifier_list, mode, shots, clip_model, tokenizer, number_of_features, normalize)

    else:
        run_all_datasets(modality, dataset_lists, model_name, classifier_list, mode, shots, clip_model, tokenizer, number_of_features, normalize, implementation, binary_model_path, add_prior, save_suffix)