import math
import sys

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np


def generate_hued_plot(df,
                       metric='ratio_boxes_cleared',
                       hue='seed',
                       title='',
                       interval=None,
                       path=None
                       ):
    """Draw a seaborn line plot of ``metric`` against ``'epoch'``, one line per ``hue`` value.

    Args:
        df: Long-format frame containing ``'epoch'``, ``metric`` and ``hue`` columns.
        metric: Column plotted on the y axis.
        hue: Column used to split the data into separately coloured lines.
        title: Plot title.
        interval: Forwarded unchanged to seaborn's ``ci`` argument.
        path: If given, the figure is saved there; otherwise it is displayed.
    """
    sns.set_context('paper')
    axes = sns.lineplot(x='epoch', y=metric, data=df, hue=hue, ci=interval)
    axes.set(title=title)
    plt.legend(loc='lower right')

    if path is None:
        plt.show()
    else:
        plt.savefig(path)
        print('Saving plot to', path)
    # Reset the current figure so the next plot starts from a blank canvas.
    plt.clf()


def hued_plot_custom_ci(df,
                        metric='ratio_boxes_cleared',
                        hue='Method',
                        title='',
                        path=None
                        ):
    """Plot the per-epoch mean of ``metric`` for each ``hue`` group with a manual 95% CI band.

    Within each hue group, all rows sharing an epoch are aggregated into a mean and a
    confidence half-width ``1.96 * std / sqrt(n)`` (population std; ``n`` is the number of
    non-missing values). Each group is drawn as a main line, two faint CI edge lines and a
    filled CI band, with parameters chosen to resemble seaborn's ``lineplot``.

    Args:
        df: Long-format frame with ``'epoch'``, ``metric`` and ``hue`` columns.
        metric: Column to aggregate; also used as the y-axis label.
        hue: Column whose distinct values each get their own colour and legend entry.
        title: Figure title.
        path: If given, the figure is saved to this path; otherwise it is shown.
    """
    sns.set_context('paper')
    colors = sns.color_palette()
    # Params tweaked to look the same as seaborn lineplot
    alphas = {'agg': 1, 'ci_edge': 0.05, 'ci_fill': 0.2}
    linewidths = {'agg': 1.5, 'ci_edge': 0.5}

    def ci95(values):
        # The mean and np.std(Series) both skip NaN, so the sample size must skip NaN too;
        # with no observations at all the interval is undefined.
        n = values.count()
        if n == 0:
            return np.nan
        return np.std(values) / math.sqrt(n) * 1.96

    fig, ax = plt.subplots()

    for hue_index, hue_value in enumerate(df[hue].unique()):
        hue_df = df.loc[df[hue] == hue_value]

        agg_df = hue_df.groupby('epoch').agg(
            **{metric: pd.NamedAgg(column=metric, aggfunc='mean'),
               'ci': pd.NamedAgg(column=metric, aggfunc=ci95)}  # 95% CI half-width
        ).reset_index()

        # Cycle the palette so more groups than palette colours do not raise IndexError.
        color = colors[hue_index % len(colors)]
        mean = agg_df[metric]
        lower = mean - agg_df['ci']
        upper = mean + agg_df['ci']

        ax.plot(agg_df['epoch'], mean, color=color,
                alpha=alphas['agg'], linewidth=linewidths['agg'], label=hue_value)  # Main line
        ax.plot(agg_df['epoch'], lower, color=color,
                alpha=alphas['ci_edge'], linewidth=linewidths['ci_edge'])  # Lower CI
        ax.plot(agg_df['epoch'], upper, color=color,
                alpha=alphas['ci_edge'], linewidth=linewidths['ci_edge'])  # Upper CI
        ax.fill_between(agg_df['epoch'], lower, upper,
                        alpha=alphas['ci_fill'], color=color)  # CI range

    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric)
    ax.legend(loc='lower right')

    if path is not None:
        fig.savefig(path)
        print('Saving plot to', path)
    else:
        plt.show()
    # Close (not just clear) the figure: plt.subplots() opens a new figure on every call,
    # so merely clearing it would leave one open figure per call behind.
    plt.close(fig)

def generate_line_plot(df,
                       metric='ratio_boxes_cleared'
                       ):
    """Show a single seaborn line plot of ``metric`` against ``'epoch'`` over all rows of ``df``.

    Args:
        df: Frame containing ``'epoch'`` and ``metric`` columns.
        metric: Column plotted on the y axis.
    """
    sns.set_context('paper')
    # No ``hue`` is assigned, so a palette has nothing to map and is not passed.
    sns.lineplot(x='epoch', y=metric, data=df)
    plt.show()
    # Clear the figure like the other plotting helpers, so a later plot does not
    # draw on top of this one (e.g. with a non-interactive backend).
    plt.clf()


def get_max_results(df: pd.DataFrame):
    """Return a frame where every cell holds the largest value seen so far in its column."""
    running_best = df.cummax(axis=0)
    return running_best


def get_min_results(df: pd.DataFrame):
    return df.cummin()


def get_sorted_results(df: pd.DataFrame,
                       hue_name='seed',
                       metric='ratio_boxes_cleared'
                       ):
    """Reshape a wide per-run results frame into long format.

    Every column whose name contains ``metric`` (excluding names containing ``MAX`` or
    ``MIN``) is treated as one run. The run's identifier is the second space-separated
    token of the column name. For each row of ``df`` and each such column, in column order,
    one output row ``(epoch, <run id>, <value>)`` is emitted.

    Args:
        df: Wide frame with an ``'epoch'`` column plus one column per run.
        hue_name: Name of the output column holding the run identifier.
        metric: Substring selecting run columns; also the name of the output value column.

    Returns:
        A new DataFrame with columns ``['epoch', hue_name, metric]`` and a default
        RangeIndex; ``'epoch'`` stays a regular column.
    """
    # The set of run columns is the same for every row, so select it once up front.
    run_columns = [key for key in df.columns
                   if metric in key and "MAX" not in key and "MIN" not in key]

    d = {'epoch': [], hue_name: [], metric: []}
    for _, row in df.iterrows():
        for key in run_columns:
            d['epoch'].append(row['epoch'])
            d[hue_name].append(key.split(" ")[1])
            d[metric].append(row[key])

    return pd.DataFrame(d)


def get_mean_var_values_for_df(df: pd.DataFrame, method='Unique IDs', metric='Reward', run_identifier_name='seed',
                               lower_is_better=None):
    """Summarise, across runs, the best value each run of one method achieved.

    For every run (distinct value of ``run_identifier_name``) among the rows whose
    ``'Method'`` equals ``method``, the best value of ``metric`` over all that run's rows is
    taken (minimum when lower is better, otherwise maximum). These per-run bests are then
    averaged.

    Args:
        df: Long-format frame with ``'Method'``, ``run_identifier_name`` and ``metric`` columns.
        method: Value of the ``'Method'`` column to summarise.
        metric: Column holding the performance values.
        run_identifier_name: Column identifying individual runs.
        lower_is_better: Whether smaller metric values are better. ``None`` (default) keeps the
            historical rule that only ``'Steps Taken'`` is lower-is-better.

    Returns:
        ``(mean, ci)``: the mean of the per-run bests and the half-width of its 95% confidence
        interval, ``1.96 * std / sqrt(n_runs)`` (population std). The second value is not a
        variance despite the historical local name. Both are NaN when ``method`` has no rows.
    """
    if lower_is_better is None:
        lower_is_better = metric == 'Steps Taken'

    # Filter by method
    method_df = df.loc[df['Method'] == method]

    # Best value reached by each run
    run_performances = []
    for run_identifier in method_df[run_identifier_name].unique():
        run_values = method_df.loc[method_df[run_identifier_name] == run_identifier, metric]
        run_performances.append(run_values.min() if lower_is_better else run_values.max())

    if not run_performances:
        # No runs for this method: return NaN explicitly rather than hitting numpy's
        # empty-mean and divide-by-zero warnings.
        return float('nan'), float('nan')

    # Aggregate between runs
    run_performances = np.array(run_performances)
    mean = np.mean(run_performances)
    std_err = np.std(run_performances) / math.sqrt(len(run_performances))  # std error
    var = 1.96 * std_err  # 95% CI half-width

    return mean, var


def get_formatted_string_for_table(val, sig_figures=5):
    """Format a number compactly for the LaTeX results table.

    The value's ``str()`` is cut to ``sig_figures`` characters. These are characters, not
    true significant figures: they include any sign and the decimal point, and digits are
    truncated, not rounded. The integer part is never cut. Trailing fractional zeroes and a
    dangling decimal point are then removed. Strings without a decimal point (e.g. integers,
    ``nan``, ``inf``) or in scientific notation are returned unchanged.

    >>> get_formatted_string_for_table(0.87654321)
    '0.876'
    >>> get_formatted_string_for_table(10.0)
    '10'
    """
    val_str = str(val)
    point_index = val_str.find('.')

    # Without a decimal point, or with an exponent, truncating or stripping zeroes would
    # change the number's magnitude (e.g. '100' -> '1', '1.5e-05' -> '1.5e-').
    if point_index == -1 or 'e' in val_str or 'E' in val_str:
        return val_str

    # Truncate, but never into the integer part (123456.7 must not become '12345').
    val_str = val_str[:max(sig_figures, point_index)]

    # Only a fractional part may lose its trailing zeroes, then its dangling point.
    if '.' in val_str:
        val_str = val_str.rstrip('0').rstrip('.')

    return val_str


if __name__ == '__main__':
    # Script entry point. For every environment / model / metric combination this loads one CSV
    # per RNI mode from '<environment>/<model>/<metric>/<mode>.csv' and merges them into one
    # long-format frame tagged with a 'Method' column. It then optionally writes a PDF plot and/or
    # prints LaTeX table cells, and finally writes the merged frame to
    # '<environment>/<model>/<metric>/combined.csv'.

    # Feature switches
    generate_plots = True  # write one PDF plot per environment/model/metric
    output_table_stats = False  # print LaTeX table rows (per-method mean \pm 95% CI)
    use_reduced_metrics = False  # use reduced_metrics_per_env instead of metrics_per_env
    environments = ['TrafficJunction-Easy', 'PredatorPrey', 'TrafficJunction-Medium', 'BoxPushing', 'DroneScatter-Stochastic', 'DroneScatter-Greedy']

    # All
    metrics_per_env = {'TrafficJunction-Easy': ['success', 'total_reward'],
                       'BoxPushing': ['ratio_boxes_cleared', 'total_reward'],
                       'DroneScatter-Stochastic': ['pairwise_distance', 'steps_taken', 'total_reward'],
                       'DroneScatter-Greedy': ['pairwise_distance', 'steps_taken', 'total_reward'],
                       'PredatorPrey': ['success', 'total_reward'],
                       'TrafficJunction-Medium': ['success', 'total_reward']}
    # Reduced
    reduced_metrics_per_env = {'TrafficJunction-Easy': ['success'],
                               'BoxPushing': ['ratio_boxes_cleared'],
                               'DroneScatter-Stochastic': ['pairwise_distance', 'steps_taken'],
                               'DroneScatter-Greedy': ['pairwise_distance', 'steps_taken'],
                               'PredatorPrey': ['success'],
                               'TrafficJunction-Medium': ['success']}
    if use_reduced_metrics:
        metrics_per_env = reduced_metrics_per_env
    # Display names: used to rename the metric columns and as plot/table labels
    nice_metric_names = {'success': 'Success', 'total_reward': 'Reward', 'ratio_boxes_cleared': 'Ratio Cleared',
                         'pairwise_distance': 'Pairwise Distance', 'steps_taken': 'Steps Taken'}

    for environment in environments:
        if output_table_stats:
            print('\nShowing results for environment:', environment)
        metrics = metrics_per_env[environment]

        if environment in {'DroneScatter-Stochastic'}:  # Do not include DGN, since it cannot be stochastic
            models = ['commnet', 'ic3net', 'magic', 'tarmac', 'tarmac_ic3net']
            model_names = ['CommNet', 'IC3Net', 'MAGIC', 'TarMAC', 'T-IC3Net']
        else:
            models = ['commnet', 'dgn', 'ic3net', 'magic', 'tarmac', 'tarmac_ic3net']
            model_names = ['CommNet', 'DGN', 'IC3Net', 'MAGIC', 'TarMAC', 'T-IC3Net']

        for model_index, model in enumerate(models):
            for metric_index, metric in enumerate(metrics):
                file_dir = environment + '/' + model + '/' + metric + '/'

                if environment in {'DroneScatter-Stochastic', 'DroneScatter-Greedy'}:  # Only 0.75 RNI for DS env
                    rni_modes = ['0', '1', '075']
                    rni_mode_names = ['Baseline', 'Unique IDs', '0.75 RNI']
                else:
                    rni_modes = ['0', '1', '075', '025']
                    rni_mode_names = ['Baseline', 'Unique IDs', '0.75 RNI', '0.25 RNI']

                # One long-format frame per RNI mode; concatenated below
                agg_dfs = []

                for i, rni_mode in enumerate(rni_modes):
                    input_filename = file_dir + rni_mode + '.csv'

                    df = pd.read_csv(input_filename)

                    if environment in {'BoxPushing'}:
                        max_df = get_max_results(df)  # For each epoch, report max result achieved so far
                        max_df = get_sorted_results(max_df, 'seed', metric)
                        max_df['Method'] = rni_mode_names[i]
                        agg_dfs.append(max_df)
                    elif environment in {} and metric in {'steps_taken'}:
                        # NOTE(review): `{}` is an empty dict, so the condition above is always False
                        # and this branch never runs; 'steps_taken' falls through to the plain branch.
                        # Presumably disabled on purpose (put environment names in the set to enable).
                        min_df = get_min_results(df)  # For each epoch, report min result achieved so far
                        min_df = get_sorted_results(min_df, 'seed', metric)
                        min_df['Method'] = rni_mode_names[i]
                        agg_dfs.append(min_df)
                    else:
                        df = get_sorted_results(df, 'seed', metric)
                        df['Method'] = rni_mode_names[i]
                        agg_dfs.append(df)

                combined_df = pd.concat(agg_dfs)
                # The concatenated parts each keep their own index; reset_index gives a fresh one and
                # keeps the old positions as an 'index' column (which also ends up in combined.csv).
                combined_df.reset_index(inplace=True)
                combined_df['epoch'] = combined_df['epoch'].astype('int')

                # Switch to nice metric names
                combined_df.rename(columns=nice_metric_names, inplace=True)

                if generate_plots:
                    # Seaborn plot
                    # generate_hued_plot(combined_df, metric=nice_metric_names[metric], hue='Method',
                    #                    path=environment + "/" + environment + "_" + model + "_" + metric + ".pdf",
                    #                    title=environment + ': ' + model_names[model_index], interval='sd')
                    # Plot with manual CI
                    hued_plot_custom_ci(combined_df, metric=nice_metric_names[metric], hue='Method',
                                        path=environment + "/" + environment + "_" + model + "_" + metric + ".pdf",
                                        title=environment + ': ' + model_names[model_index])

                if output_table_stats:
                    # Per RNI mode: mean over runs of each run's best value, and its 95% CI half-width
                    mean_vals = []
                    cis = []

                    for i, rni_mode_name in enumerate(rni_mode_names):
                        mean, var = get_mean_var_values_for_df(df=combined_df, method=rni_mode_name, metric=nice_metric_names[metric])
                        mean_vals.append(mean)
                        cis.append(var)

                    # Best mean across modes: lowest for steps_taken, highest for every other metric
                    best_mean_val = min(mean_vals) if metric in {'steps_taken'} else max(mean_vals)

                    # Row layout: model name (first metric only) & metric name & one cell per RNI mode
                    if metric_index == 0:
                        print(model_names[model_index], end=' ')
                    print('&', nice_metric_names[metric], end=' & ')

                    for i, mean_val in enumerate(mean_vals):
                        ci = cis[i]

                        # Get formatted strings of mean and ci
                        mean_num_figures = 5
                        ci_num_figures = 4
                        mean_val_str = get_formatted_string_for_table(mean_val, mean_num_figures)
                        ci_str = get_formatted_string_for_table(ci, ci_num_figures)
                        max_mean_val_str = get_formatted_string_for_table(best_mean_val, mean_num_figures)

                        # Bold every cell whose formatted mean equals the formatted best mean (so ties are all bolded)
                        if mean_val_str == max_mean_val_str:
                            print('\\pmb{$', mean_val_str, '\\pm', ci_str, '$}', end='')
                        else:
                            print('$', mean_val_str, '\\pm', ci_str, '$', end='')

                        if i != len(mean_vals) - 1:
                            print(' && ', end='')
                    print(' \\\\')

                # Persist the merged long-format frame next to the per-mode input CSVs
                output_filename = file_dir + 'combined.csv'
                combined_df.to_csv(output_filename)
