import numpy as np
# from src.evaluation.evaluation_pipeline.evaluate_method import *
from src.evaluation.evaluation_pipeline.evaluate_realizations import *
from src.evaluation.aux.load_results import *


import matplotlib.pyplot as plt
import os
import argparse
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import config as config

my_pal = config.COLOR

# This file plots the figures in Appendix I.1.2 (query strategies ablation comparison).

parser = argparse.ArgumentParser(description='task2 for plotting figure in Appendix I.1.2 Query strategies ablation comparison ')
parser.add_argument('-t', default="time", type=str,
                    help='please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45')

args = parser.parse_args()
# Timestamp suffix embedded in the results-folder names below.
time = args.t

if time == "time":
    # "time" is the placeholder default, i.e. -t was not supplied. Exit with a
    # non-zero status (SystemExit with a string prints it to stderr and returns
    # status 1) rather than a bare exit(), which reported success.
    raise SystemExit("please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45")

# Display names for each method, taken from config; used for plot labels,
# palette lookups, and matching renamed method columns.
n_RS = config.n_RS
n_Oracle = config.n_CAMS_best_policy
n_QBC = config.n_qbc
n_IWAL = config.n_iwal
n_MP = config.n_mp
n_CQBC = config.n_contextual_qbc
n_CIWAL = config.n_contextual_iwal
n_CAMS = config.n_CAMS_identity
n_test = config.n_CAMS_test

# Legend labels used in this figure for the query-strategy variants.
n_entropy = "entropy"
n_variance = "variance"
n_random = "random"


def rename_method_list(methods):
    """Map raw method identifiers to their display names.

    Args:
        methods: Iterable of raw method ids as stored in the result files,
            e.g. ``"rs"``, ``"qbc"``, ``"CAMS_identity"``.

    Returns:
        A list of display names (the module-level ``n_*`` names), in the same
        order as ``methods``.

    Raises:
        ValueError: If an id has no known display name. Previously this printed
            "error" and called exit(), ending the script with a success status.
    """
    display_names = {
        "rs": n_RS,
        "qbc": n_QBC,
        "iwal": n_IWAL,
        "mp": n_MP,
        "contextual_qbc": n_CQBC,
        "contextual_iwal": n_CIWAL,
        "CAMS_best_policy": n_Oracle,
        "CAMS_identity": n_CAMS,
        "CAMS_test": n_test,
    }

    renamed = []
    for item in methods:
        try:
            renamed.append(display_names[item])
        except KeyError:
            raise ValueError("unknown method id: {!r}".format(item)) from None
    return renamed


def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value``.

    Ties go to the earliest element. An empty ``array`` raises ``ValueError``
    from numpy's argmin.
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[np.argmin(distances)]


def _snap_to_budget_grid(actual_budgets, box_methods, budget_grid, max_methods, max_budgets):
    """Map each realization's actual query cost to an x-position on the budget grid.

    Each cost is rounded to the nearest value among the grid budgets plus its
    method's maximum actual cost. If that nearest value is the method's "top bar"
    (the last grid budget, in grid order, not exceeding the method's maximum
    cost), the method's maximum cost is used instead.
    """
    # Per method: last grid budget (in grid order) <= its maximum cost, else 0.
    top_bars = []
    for max_cost in max_budgets:
        top_bar = 0
        for grid_value in budget_grid:
            if max_cost >= grid_value:
                top_bar = grid_value
        top_bars.append(top_bar)

    snapped = []
    for cost, method in zip(actual_budgets, box_methods):
        k = max_methods.index(method)  # look the method up once per realization
        candidates = np.concatenate((budget_grid, [max_budgets[k]]))
        nearest = find_nearest(candidates, cost)
        snapped.append(max_budgets[k] if nearest == top_bars[k] else nearest)
    return snapped


def _average_snapped_budgets(df, budget_grid):
    """Replace on-grid ``budget_fixed`` values with the mean actual ``budget``.

    Rows are grouped by (method, snapped grid value). ``df`` is modified in place.
    Masks are built from a snapshot of the column, so a value that was already
    averaged is never matched again.
    """
    # Float column so the means can be stored without an incompatible-dtype upcast.
    df["budget_fixed"] = df["budget_fixed"].astype(float)
    snapped = df["budget_fixed"].copy()
    # Iterate the display names actually present in the frame. The earlier
    # version iterated raw ids against renamed names and so never matched.
    for method in pd.unique(df["method"]):
        for grid_value in budget_grid:
            mask = (df["method"] == method) & (snapped == grid_value)
            if mask.any():
                df.loc[mask, "budget_fixed"] = df.loc[mask, "budget"].mean()


def organize_plot(dataset_name, folder_name, budget, my_pal=my_pal):
    """Plot cumulative loss against query cost for every query strategy.

    Loads ``data.npz`` and ``eval_results.npz`` from
    ``<cwd>/resources/results/<folder_name>/``. Each realization's actual query
    cost is snapped onto the experiment's budget grid. The function then saves
    into ``./task2/`` a shaded line plot (PNG and PDF) and a separate legend
    figure (PNG and PDF).

    Args:
        dataset_name: Label used as the prefix of the plot file names.
        folder_name: Results folder name under ``resources/results``.
        budget: Unused by this figure; kept for call compatibility.
        my_pal: Mapping from legend label to line colour. Every legend label
            plotted below must be a key.
    """
    results_dir = os.path.join(os.getcwd(), "resources", "results", folder_name)

    # Context managers close the .npz archives once the arrays are read.
    with np.load(os.path.join(results_dir, "data.npz")) as data:
        budget_raw = data["budgets"]
    with np.load(os.path.join(results_dir, "eval_results.npz")) as eval_results:
        box_budget_actual = eval_results["box_budget_actual"]
        box_cumulative_loss = eval_results["box_cumulative_loss"]
        box_method = rename_method_list(eval_results["box_method"])
        max_method = rename_method_list(eval_results["max_method"])
        max_budget_actual = eval_results["max_budget_actual"]

    budget_fixed = _snap_to_budget_grid(box_budget_actual, box_method, budget_raw,
                                        max_method, max_budget_actual)
    box_df_shading = pd.DataFrame({"budget": box_budget_actual, "budget_fixed": budget_fixed,
                                   "c_regret": box_cumulative_loss, "method": box_method})
    _average_snapped_budgets(box_df_shading, budget_raw)

    # Identical (x, method, loss) rows are collapsed before plotting.
    shade_df = (box_df_shading.filter(["budget_fixed", "method", "c_regret"], axis=1)
                .drop_duplicates()
                .reset_index(drop=True))

    # (legend label, method display name, line width). Colours are looked up by
    # legend label. Order is kept so the highlighted entropy line is drawn last.
    series = [
        (n_RS, n_RS, 1),
        (n_variance, n_Oracle, 1),
        (n_QBC, n_QBC, 1),
        (n_IWAL, n_IWAL, 1),
        (n_MP, n_MP, 1),
        (n_CQBC, n_CQBC, 1),
        (n_CIWAL, n_CIWAL, 1),
        (n_random, n_test, 1),
        (n_entropy, n_CAMS, 4),
    ]

    out_dir = "./task2/"
    os.makedirs(out_dir, exist_ok=True)

    plot_fig = plt.figure(figsize=(10, 10), dpi=300)
    for label, method, width in series:
        sns.lineplot(x="budget_fixed", y="c_regret", label=label,
                     data=shade_df[shade_df["method"] == method],
                     color=my_pal[label], ci=63, linewidth=width)
    ax = plt.gca()

    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.xlabel("Query cost", fontsize=30)
    plt.ylabel("", fontsize=30)
    # The legend is written as its own figure below, so keep it off this plot.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    plot_fig.savefig(out_dir + dataset_name + "_task2_query_strategies_shade_line.png", bbox_inches='tight',
                     pad_inches=0.01)
    plot_fig.savefig(out_dir + dataset_name + "_task2_query_strategies_shade_line.pdf", bbox_inches='tight',
                     pad_inches=0.01)

    # Stand-alone legend built from the plotted lines' handles and labels.
    handles, labels = ax.get_legend_handles_labels()
    legend_fig = plt.figure(figsize=(10, 10), dpi=300)
    legend_fig.legend(handles, labels, ncol=8, loc='center')
    legend_fig.savefig(out_dir + 'legend_query_strategy.png', bbox_inches='tight', pad_inches=0)
    legend_fig.savefig(out_dir + 'legend_query_strategy.pdf', bbox_inches='tight', pad_inches=0)

    # Release both figures; this function runs once per dataset.
    plt.close(legend_fig)
    plt.close(plot_fig)


# One entry per dataset: (dataset label, budget, results-folder prefix,
# results-folder suffix). The run timestamp `time` goes between prefix and suffix.
for dataset_name, budget, folder_prefix, folder_suffix in (
        ("VERTEBRAL", 80, "VERTEBRAL_contextual_streamsize80_numreals200",
         "_which_methods00000011100_policy[0]"),
        ("DRIFT", 200, "drift_contextual_streamsize3000_numreals300",
         "_which_methods00000011100_policy[1]"),
        ("HIV", 100, "HIV_contextual_streamsize4000_numreals200",
         "_which_methods00000011100_policy[0]"),
        ("CIFAR10", 200, "cifar_contextual_streamsize10000_numreals10",
         "_which_methods00000011100_policy[11]"),
):
    folder_name = folder_prefix + time + folder_suffix
    organize_plot(dataset_name, folder_name, budget)