import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from argparse import ArgumentParser
import argparse

import matplotlib.pyplot as plt
import json

# Global matplotlib font sizing: every text element below is set to BIGGER_SIZE
# so the saved figure stays legible.
SMALLER_SIZE = 8   # not referenced in the visible part of this script
MEDIUM_SIZE = 10   # not referenced in the visible part of this script
BIGGER_SIZE = 15

# (rc group, keyword) pairs, applied in order via plt.rc(group, keyword=BIGGER_SIZE).
_FONT_SIZE_SETTINGS = (
    ('font', 'size'),         # default text size
    ('axes', 'titlesize'),    # axes title
    ('axes', 'labelsize'),    # x and y axis labels
    ('xtick', 'labelsize'),   # x tick labels
    ('ytick', 'labelsize'),   # y tick labels
    ('legend', 'fontsize'),   # legend entries
    ('figure', 'titlesize'),  # figure title
)
for _rc_group, _rc_key in _FONT_SIZE_SETTINGS:
    plt.rc(_rc_group, **{_rc_key: BIGGER_SIZE})

# Command-line interface: where to write the figure and what to call it.
# Both options are required: the script later joins them into the output path,
# so a missing value would only surface as an opaque TypeError.
parser = argparse.ArgumentParser(description='Plot some plots')
parser.add_argument('--outdir', type=str, required=True,
                    help='dir path for file name')
parser.add_argument('--plot_file_name', type=str, required=True,
                    help='plot file name')
args = parser.parse_args()

# Create the output directory (and any missing parents). exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(args.outdir, exist_ok=True)

def plot_seed_band(label, color, path_template, seeds=(1, 2, 3)):
    """Plot the mean ``test_acc`` curve over several seeds with a min/max band.

    Each seed's training log is read, its ``test_acc`` column extracted, and
    the columns are stacked so that the mean, minimum and maximum can be taken
    across seeds for every log row. The mean is drawn as a line and the
    min-max range as a translucent band on the current matplotlib axes.

    Args:
        label: Legend label for the mean curve.
        color: Matplotlib colour used for both the curve and the band.
        path_template: Path of a ``train_log.txt`` file containing a
            ``{seed}`` placeholder.
        seeds: Seed values substituted into ``path_template``.

    Returns:
        1-D numpy array holding the across-seed mean of ``test_acc`` for each
        log row.

    Raises:
        ValueError: From ``np.vstack`` if the logs do not all have the same
            number of rows.
    """
    per_seed_acc = [
        # The " \t " separator is multi-character, so pandas treats it as a
        # regex; naming the python engine explicitly avoids the fallback
        # ParserWarning emitted when the default C engine is requested.
        pd.read_csv(path_template.format(seed=seed), sep=" \t ", engine="python")['test_acc'].values
        for seed in seeds
    ]
    stacked = np.vstack(per_seed_acc)  # shape: (number of seeds, number of log rows)
    mean_curve = stacked.mean(axis=0)
    rows = range(stacked.shape[1])

    plt.plot(rows, mean_curve, label=label, color=color)
    plt.fill_between(rows, stacked.min(axis=0), stacked.max(axis=0), color=color, alpha=0.5)
    return mean_curve


# Log locations for each method; "{seed}" is filled in with 1, 2 and 3.
VANILLA_LOG_TEMPLATE = "logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/num_queries_1/mask_model_None/pad_size_8/num_image_locations_random/background_nature/mask_init_None/budget_split_None/first_query_budget_frac_1.0/first_query_with_mask_False/linf_pert_0.00902/average_queries_False/mask_output_None/seed_{seed}/vanilla/train_log.txt"
STATIC_MASK_LOG_TEMPLATE = "logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/num_queries_1/mask_model_None/pad_size_8/num_image_locations_random/background_nature/mask_init_identity/budget_split_None/first_query_budget_frac_1.0/first_query_with_mask_True/linf_pert_0.00902/average_queries_False/mask_output_None/seed_{seed}/static_learnt_mask_1_query/train_log.txt"
ADAPTIVE_LOG_TEMPLATE = "logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/num_queries_2/mask_model_modified_resnet/pad_size_8/num_image_locations_random/background_nature/mask_init_random/budget_split_learnt/first_query_budget_frac_0.5/first_query_with_mask_False/linf_pert_0.00902/average_queries_True/mask_output_sigmoid/seed_{seed}/fqbf_lr_0.0001_mom_0.9_wd_0_scheduler_off_t_scheduler_on/train_log.txt"

# One mean curve plus min/max band per method.
vanilla_mean = plot_seed_band('vanilla', 'blue', VANILLA_LOG_TEMPLATE)
static_mask_mean = plot_seed_band('static_mask', 'orange', STATIC_MASK_LOG_TEMPLATE)
adaptive_mean = plot_seed_band('adaptive', 'green', ADAPTIVE_LOG_TEMPLATE)

# set other configs of plot
# Tick every 10 rows, spanning the length of the adaptive run.
plt.xticks(np.arange(0, len(adaptive_mean) + 1, 10))
plt.xlabel("epochs")
plt.legend()
plt.savefig(os.path.join(args.outdir, args.plot_file_name))
plt.close()
