import os
from tqdm import tqdm
from statistics import mean
import glob
import json
import pickle
import pandas as pd
from scipy.stats import pearsonr
from configs.constants import SYNTHETIC_DATASETS, DATASET_TO_NAME, MODELS
import pdb
import numpy as np
import torch

#######################################
#### INITIAL PROCESSING OF RESULT JSONS
#######################################

def get_ds_name(ds, name_map=None):
    """Return the display name for dataset identifier ``ds``.

    Plain identifiers are looked up directly in the name map. Identifiers of
    the form ``"<root>_cf_<stem>"`` are mapped to
    ``"<display name of root>, CLIP <stem>"``. Identifiers that are not
    covered by the map are returned unchanged.

    Args:
        ds: Dataset identifier string.
        name_map: Optional mapping from identifier to display name; defaults
            to ``DATASET_TO_NAME``.

    Returns:
        The display name, or ``ds`` itself when no mapping applies.
    """
    if name_map is None:
        name_map = DATASET_TO_NAME
    # Split on the first "_cf_" only; for ids without it, name_root == ds.
    name_root, sep, name_stem = ds.partition("_cf_")
    # The membership test must use the same key as the lookup below
    # (previously it tested the full id but indexed by the root).
    if name_root not in name_map:
        return ds
    if sep:
        return name_map[name_root] + f", CLIP {name_stem}"
    return name_map[ds]

def dict_to_list(dct, path=None):
    """Flatten a nested dict into leaf records.

    Each non-dict leaf becomes ``{"path": [key1, key2, ...], "value": leaf}``,
    in the dicts' iteration order. A non-dict input yields a single record
    whose path is ``path``.

    Args:
        dct: A (possibly nested) dict, or any leaf value.
        path: Key prefix to prepend to every produced path; defaults to an
            empty list.

    Returns:
        A list of ``{"path": list, "value": object}`` dicts; empty for ``{}``.
    """
    # A fresh list per call: the old ``path=[]`` default was returned inside
    # the result for non-dict input, so mutating it leaked into later calls.
    if path is None:
        path = []
    if not isinstance(dct, dict):
        return [{"path": path, "value": dct}]
    result = []
    for key, value in dct.items():
        result.extend(dict_to_list(value, path + [key]))
    return result

def get_dgm_results(dataset, dgm_root):
    """Load the metrics text file stored for ``dataset`` under ``dgm_root``.

    Candidate files match ``<dgm_root>/*<dataset>*/*.txt``; ``done.txt`` files
    are ignored. If ``dataset`` contains "clip", only candidates whose path
    below ``dgm_root`` contains "clip" are kept; otherwise only those without
    it. Exactly one candidate must remain.

    Each non-empty line of the file is parsed as ``"<metric>: <value>"``.

    Args:
        dataset: Dataset identifier used in the glob pattern.
        dgm_root: Root directory to search (may be ``""`` for the CWD).

    Returns:
        Dict mapping metric name to float value.

    Raises:
        FileNotFoundError: if no candidate file is found.
        ValueError: if several candidates match, or a line is malformed.
    """
    candidates = [
        x for x in glob.glob(os.path.join(dgm_root, f"*{dataset}*/*.txt"))
        if os.path.basename(x) != "done.txt"
    ]
    want_clip = "clip" in dataset
    # Test only the part below dgm_root, so a root path that happens to
    # contain "clip" does not flip the filter.
    root = dgm_root or os.curdir
    candidates = [
        x for x in candidates
        if ("clip" in os.path.relpath(x, root)) == want_clip
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No DGM results file for dataset {dataset!r} under {dgm_root!r}")
    if len(candidates) > 1:
        raise ValueError(
            f"Ambiguous DGM results for dataset {dataset!r}: {sorted(candidates)}")

    path = candidates[0]
    with open(path, "r") as f:
        dgm_output = f.read()

    metrics = {}
    for line in dgm_output.splitlines():
        line = line.strip()
        if not line:
            continue
        key, sep, value = line.partition(": ")
        if not sep:
            raise ValueError(f"Malformed metric line {line!r} in {path}")
        metrics[key] = float(value)
    return metrics

# Column labels for a flattened result-JSON leaf, keyed by the length of its
# key path; the i-th key of the path is stored under the i-th label.
_RESULT_PATH_ORDERS = {
    3: ("model", "mix", "mode"),
    4: ("model", "mix", "mode", "eval_ds"),
    5: ("model", "mix", "mode", "acc_type", "eval_ds"),
    6: ("model", "mix", "mode", "acc_type", "eval_ds", "class"),
}


def _parse_path_key(item):
    """Convert a JSON key to a Python literal when it is one ("0.5" -> 0.5).

    ``ast.literal_eval`` replaces the former ``eval``: keys that happen to name
    builtins or module globals (e.g. "all", "max", "mean", "json") now stay
    strings, and expressions embedded in result files are never executed.
    """
    import ast  # local import keeps the module's import block unchanged

    try:
        return ast.literal_eval(item)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return item


def _collect_results(root, tag, allowed_depths, skipped_acc_types,
                     dgm, dgm_root, models):
    """Flatten result JSONs under ``root`` whose path contains ``tag`` into rows.

    Every leaf of every matching JSON becomes one row. Rows contain:
    - the key-path labels from ``_RESULT_PATH_ORDERS``, with keys parsed by
      ``_parse_path_key``;
    - ``acc`` (the leaf value) and ``dataset`` (the display name of the id
      taken from the ``results_<id>.json`` file name);
    - ``clip_thresh``;
    - when ``dgm`` is true, the metrics from ``get_dgm_results``.

    Leaves whose path contains a "metadata" or "best_param" key, or whose
    first key is not in ``models``, are skipped.

    Args:
        root: Directory searched recursively for ``*.json``.
        tag: Substring a JSON path must contain (e.g. "knn").
        allowed_depths: Permitted key-path lengths for this result type.
        skipped_acc_types: ``acc_type`` values whose rows are discarded.
        dgm, dgm_root: Whether / where to load per-dataset DGM metrics.
        models: Model names to keep.

    Returns:
        ``(avg_df, per_class_df)``. Rows with ``acc_type == "per_class"`` go to
        ``per_class_df``; all other kept rows go to ``avg_df``.

    Raises:
        ValueError: if a leaf's key-path length is not in ``allowed_depths``.
    """
    json_paths = glob.glob(os.path.join(root, "**/*.json"), recursive=True)
    result_paths = [p for p in json_paths if tag in p]
    if dgm:
        datasets = ['_'.join(ds.split("/")) for ds in SYNTHETIC_DATASETS]
        dgm_results = {ds: get_dgm_results(ds, dgm_root) for ds in datasets}

    rows_avg = []
    rows_per_class = []
    for p in tqdm(result_paths):
        dataset = p.split("/")[-1].split("results_")[1].split(".json")[0]
        folder_name = p.split("/")[-2]
        # clip_thresh is the (string) suffix after "_cf_" in the parent folder,
        # or 0.0 when the path has no "_cf_".
        # NOTE(review): "_cf_" is tested on the whole path but read from the
        # folder name only — confirm this is intended.
        clip_thresh = folder_name.split("_cf_")[-1] if "_cf_" in p else 0.0

        with open(p, "r") as f:
            results = json.load(f)

        for leaf in dict_to_list(results):
            key_path = leaf["path"]
            if 'metadata' in key_path or 'best_param' in key_path:
                continue
            # Depth is validated before the model filter, as before.
            if len(key_path) not in allowed_depths:
                raise ValueError(f"Unknown format: {key_path}")
            if key_path[0] not in models:
                continue

            labels = _RESULT_PATH_ORDERS[len(key_path)]
            row = {label: _parse_path_key(item)
                   for label, item in zip(labels, key_path)}
            row["acc"] = leaf["value"]
            row["dataset"] = get_ds_name(dataset)
            row["clip_thresh"] = clip_thresh
            if dgm:
                row.update(dgm_results[dataset])

            acc_type = row.get("acc_type")
            if acc_type == "per_class":
                rows_per_class.append(row)
            elif acc_type in skipped_acc_types:
                continue
            else:
                rows_avg.append(row)

    return pd.DataFrame(rows_avg), pd.DataFrame(rows_per_class)


def get_knn_results(root, dgm=True, dgm_root="", models=MODELS):
    """Load kNN result JSONs (paths containing "knn") under ``root``.

    Key paths of length 3, 5 or 6 are accepted; rows with ``acc_type == "cf"``
    are dropped.

    Returns:
        ``(avg_df, per_class_df)`` as described in ``_collect_results``.
    """
    return _collect_results(root, "knn", (3, 5, 6), ("cf",),
                            dgm, dgm_root, models)


def get_lgbfs_results(root, dgm=True, dgm_root="", models=MODELS):
    """Load LGBFS result JSONs (paths containing "lgbfs") under ``root``.

    Key paths of length 3, 4, 5 or 6 are accepted; no ``acc_type`` is dropped.

    Returns:
        ``(avg_df, per_class_df)`` as described in ``_collect_results``.
    """
    return _collect_results(root, "lgbfs", (3, 4, 5, 6), (),
                            dgm, dgm_root, models)



###########################################
#### UTILITIES FOR OPENING & FILTERING CSVS
##########################################
def filter_max_k(df):
    """Return the ``synthetic_test`` rows evaluated at the validation-best k.

    For each (dataset, model, mix) group, the ``synthetic_val`` row with the
    highest ``acc`` picks a k. Rows of ``df`` sharing that group and k are
    joined in, only ``synthetic_test`` rows are kept, and the helper
    ``*_val`` columns added by the join are removed again.
    """
    keys = ['dataset', 'model', 'mix', 'k']
    group_keys = ["dataset", "model", "mix"]

    val_rows = df[df['mode'] == "synthetic_val"]
    best_labels = val_rows.groupby(group_keys)['acc'].idxmax()
    best_val = df.loc[best_labels]
    # Suffix every non-key column so the merge below cannot clash with df's.
    suffixed = {}
    for col in best_val.columns:
        if col not in keys:
            suffixed[col] = f"{col}_val"
    best_val = best_val.rename(columns=suffixed)

    merged = df.merge(best_val, on=keys, how='inner')
    test_rows = merged[merged['mode'] == "synthetic_test"]
    helper_cols = [
        col for col in test_rows.columns
        if col in best_val.columns and col not in keys
    ]
    return test_rows.drop(columns=helper_cols)

def filter_max_k_pc(df, ref_df):
    """Keep rows of ``df`` whose (model, mix, dataset, k) also appear in ``ref_df``.

    ``ref_df``'s other columns are suffixed with ``_val`` for the inner merge
    and dropped afterwards. Duplicate rows, e.g. from repeated keys in
    ``ref_df``, are then collapsed.
    """
    join_keys = ['model', 'mix', 'dataset', 'k']
    renames = {col: f"{col}_val" for col in ref_df.columns if col not in join_keys}
    ref_renamed = ref_df.rename(columns=renames)

    joined = df.merge(ref_renamed, on=join_keys, how='inner')
    extra_cols = [
        col for col in joined.columns
        if col in ref_renamed.columns and col not in join_keys
    ]
    trimmed = joined.drop(columns=extra_cols)
    return trimmed.drop_duplicates()


def get_global(dataset_name):
    """Return the text before the first comma, cut at the first underscore.

    Non-string entries (such as NaN cells of a DataFrame column) yield NaN.
    """
    try:
        return dataset_name.split(",")[0].split("_")[0]
    except (AttributeError, TypeError):
        # Previously a bare ``except:``, which also swallowed
        # KeyboardInterrupt/SystemExit; only non-str input can fail here.
        return np.nan

def get_cfg(dataset_name):
    """Parse the float following ``"cfg "`` in ``dataset_name``.

    Returns NaN when the marker is absent, when the following text is not a
    float, or when ``dataset_name`` is not a string.
    """
    try:
        # [1] is the text between the first "cfg " and any later one.
        return float(dataset_name.split("cfg ")[1])
    except (AttributeError, IndexError, TypeError, ValueError):
        return np.nan


def initial_processing(df):
    """Add ``global``/``cfg`` columns to ``df`` in place and split it by ``mix``.

    The new columns are derived from ``df['dataset']`` via ``get_global`` and
    ``get_cfg``.

    Returns:
        ``(synthetic, real)``: rows with ``mix != 0`` and rows with ``mix == 0``.
    """
    names = df['dataset']
    df['global'] = names.apply(get_global)
    df['cfg'] = names.apply(get_cfg)

    synthetic_mask = df['mix'] != 0.00
    real_mask = df['mix'] == 0.00
    return df[synthetic_mask], df[real_mask]
    

def filter_knn(knn_df, knn_df_pc=None, pc=False):
    """Split kNN result frames into synthetic (``mix != 0``) and real (``mix == 0``) parts.

    Each frame is passed through ``initial_processing``, which also adds
    ``global``/``cfg`` columns in place. Best-k selection (``filter_max_k``)
    is not applied here.

    Args:
        knn_df: Averaged kNN results.
        knn_df_pc: Per-class kNN results; required when ``pc`` is true.
        pc: Whether to also split ``knn_df_pc``.

    Returns:
        ``(synthetic, real)``, or ``(synthetic, real, synthetic_pc, real_pc)``
        when ``pc`` is true.

    Raises:
        ValueError: if ``pc`` is true but ``knn_df_pc`` is None.
    """
    # Validate first so a missing per-class frame fails with a clear message
    # instead of a TypeError deep inside initial_processing.
    if pc and knn_df_pc is None:
        raise ValueError("pc=True requires knn_df_pc")

    synthetic_knn, real_knn = initial_processing(knn_df)
    if not pc:
        return synthetic_knn, real_knn

    # The former "Filtering for best k" print was removed: no filtering
    # happens here, so the message was misleading.
    synthetic_knn_pc, real_knn_pc = initial_processing(knn_df_pc)
    return synthetic_knn, real_knn, synthetic_knn_pc, real_knn_pc

def filter_lgbfs(lgbfs_df, lgbfs_df_pc=None, pc=False):
    """Split LGBFS result frames into synthetic (``mix != 0``) and real (``mix == 0``) parts.

    Each frame is passed through ``initial_processing``, which also adds
    ``global``/``cfg`` columns in place.

    Args:
        lgbfs_df: Averaged LGBFS results.
        lgbfs_df_pc: Per-class LGBFS results; required when ``pc`` is true.
        pc: Whether to also split ``lgbfs_df_pc``.

    Returns:
        ``(synthetic, real)``, or ``(synthetic, real, synthetic_pc, real_pc)``
        when ``pc`` is true.

    Raises:
        ValueError: if ``pc`` is true but ``lgbfs_df_pc`` is None.
    """
    if pc and lgbfs_df_pc is None:
        raise ValueError("pc=True requires lgbfs_df_pc")

    synthetic_lgbfs, real_lgbfs = initial_processing(lgbfs_df)
    if not pc:
        return synthetic_lgbfs, real_lgbfs

    synthetic_lgbfs_pc, real_lgbfs_pc = initial_processing(lgbfs_df_pc)
    return synthetic_lgbfs, real_lgbfs, synthetic_lgbfs_pc, real_lgbfs_pc










