import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

NCOLS_L = 6  # number of columns used for the stand-alone exported legend


# LaTeX packages made available to every usetex-rendered text element.
_LATEX_PREAMBLE = "\n".join(
    (
        r"\usepackage[OT1]{fontenc}",
        r"\usepackage[utf8]{inputenc}",
        r"\usepackage{amsmath,amssymb,amsfonts,mathrsfs,bm}",
    )
)

# Global matplotlib style shared by every figure produced by this module.
_RC_OVERRIDES = {
    # Text rendering: all text goes through (pdf)LaTeX.
    "text.latex.preamble": _LATEX_PREAMBLE,
    "pgf.texsystem": "pdflatex",
    "pgf.rcfonts": False,
    "text.usetex": True,
    # Uniform, fairly large font sizes (legend slightly smaller).
    "font.size": 16,
    "axes.titlesize": 16,
    "axes.labelsize": 16,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "legend.fontsize": 14,
    "legend.framealpha": 1,
    # Output: high-resolution files with minimal padding.
    "savefig.dpi": 300,
    "savefig.pad_inches": 0.01,
    # Layout and line/marker defaults.
    "figure.figsize": (5, 3.5),
    "axes.xmargin": 0,
    "lines.markersize": 5,
    "lines.linewidth": 1.5,
}

mpl.rcParams.update(_RC_OVERRIDES)


def export_legend(legend, filename="legend.png", expand=(-5, -5, 5, 5)):
    """Save only the region covered by ``legend`` to its own image file.

    Args:
        legend: A legend that is already attached to a figure
            (``legend.figure`` is used for drawing and saving).
        filename: Output path handed to ``Figure.savefig``.
        expand: Four offsets ``(x0, y0, x1, y1)`` in display pixels added to
            the legend's extents. The default pads the legend by 5 px on
            every side. A tuple is used so the default is immutable.
    """
    fig = legend.figure
    # The legend's on-screen extent is only valid after the figure is drawn.
    fig.canvas.draw()
    bbox = legend.get_window_extent()
    padding = np.asarray(expand, dtype=float)
    bbox = bbox.from_extents(*(bbox.extents + padding))
    # savefig's bbox_inches expects inches; convert from display pixels.
    bbox = bbox.transformed(fig.dpi_scale_trans.inverted())
    fig.savefig(filename, dpi="figure", bbox_inches=bbox)


def _scatter_with_arrows(ax, x_vals, y_vals, label):
    """Scatter one series and link consecutive points with same-coloured arrows."""
    res = ax.scatter(
        x_vals,
        y_vals,
        s=75,
        label=label,
        alpha=0.9,
        edgecolors="none",
    )
    # Arrows connect point i to point i+1 in row order. With fewer than two
    # points there is nothing to connect, so skip the (empty) quiver call.
    if len(x_vals) > 1:
        ax.quiver(
            x_vals[:-1],
            y_vals[:-1],
            x_vals[1:] - x_vals[:-1],
            y_vals[1:] - y_vals[:-1],
            scale_units="xy",
            angles="xy",
            scale=1,
            linewidth=0.05,
            # Reuse the scatter colour so points and arrows read as one series.
            color=res.get_facecolor()[0],
        )
    return res


def generate_plot(
    df_mean,
    df_std,
    str_metric_x,
    str_metric_y,
    str_filename,
    str_xlabel,
    str_ylabel,
    str_model,
    model_names,
    xlim,
    ylim,
    fn_legend=None,
):
    """Plot metric ``str_metric_y`` against ``str_metric_x`` for each model.

    Each series is drawn as a scatter plot. Consecutive points (in the row
    order of ``df_mean``) are connected with arrows of the same colour.

    Args:
        df_mean: DataFrame with a ``"model.name"`` column plus the metric
            columns. When a model maps to a dict in ``model_names``, it also
            needs the ``str_model`` column.
        df_std: Currently unused; kept for interface compatibility.
        str_metric_x, str_metric_y: Column names plotted on the x / y axis.
        str_filename: Output path for the plot.
        str_xlabel, str_ylabel: Axis labels.
        str_model: Column used to split a model's rows into sub-series when
            ``model_names[model]`` is a dict.
        model_names: Maps a ``"model.name"`` value either to a legend label,
            or to a dict ``{str_model value: legend label}`` for one series
            per sub-value.
        xlim, ylim: Axis limits.
        fn_legend: If given, the legend is additionally exported as a
            separate image at this path.
    """
    fig, ax = plt.subplots()
    for model, model_name_vis in model_names.items():
        model_rows = df_mean.loc[df_mean["model.name"] == model]
        if isinstance(model_name_vis, dict):
            for agg, agg_name in model_name_vis.items():
                agg_rows = model_rows.loc[model_rows[str_model] == agg]
                _scatter_with_arrows(
                    ax,
                    agg_rows[str_metric_x].values,
                    agg_rows[str_metric_y].values,
                    agg_name,
                )
        else:
            _scatter_with_arrows(
                ax,
                model_rows[str_metric_x].values,
                model_rows[str_metric_y].values,
                model_name_vis,
            )

    ax.grid(True)
    ax.set_ylim(ylim)
    ax.set_xlim(xlim)
    ax.set_ylabel(str_ylabel)
    ax.set_xlabel(str_xlabel)
    print(str_filename)
    fig.savefig(
        str_filename,
        bbox_inches="tight",
    )
    if fn_legend is not None:
        # The legend is added only after the main plot has been saved, so it
        # appears solely in the separately exported legend image.
        handles, labels = ax.get_legend_handles_labels()
        legend = fig.legend(
            handles,
            labels,
            bbox_to_anchor=(0.5, 1.05),
            ncol=NCOLS_L,
            loc="lower center",
        )
        export_legend(legend, fn_legend)
    plt.close(fig)


def generate_plot_all_attributes(
    df_mean,
    df_std,
    str_filename,
    str_xlabel,
    str_ylabel,
    attribute_names,
    model_names,
    fn_legend=None,
):
    """Draw a grouped bar chart of per-attribute mean values for each model.

    There is one group of bars per entry in ``attribute_names`` and one bar
    per series inside each group. Bar width and offsets adapt to the number
    of series, so the groups stay centred on their ticks. For six series the
    layout matches the previous hard-coded one.

    Args:
        df_mean: DataFrame with a ``"model.name"`` column, the attribute
            columns, and (for dict entries in ``model_names``) a
            ``"model.aggregation"`` column.
        df_std: Currently unused; kept for interface compatibility.
        str_filename: Output path for the plot.
        str_xlabel, str_ylabel: Axis labels.
        attribute_names: Attribute columns to plot, in x-axis order.
        model_names: Maps a ``"model.name"`` value either to a legend label,
            or to a dict ``{"model.aggregation" value: legend label}`` for
            one bar series per aggregation.
        fn_legend: If given, the legend is additionally exported as a
            separate image at this path.
    """
    # Flatten model_names into (label, rows) pairs in plotting order.
    series = []
    for model, model_name_vis in model_names.items():
        model_rows = df_mean.loc[df_mean["model.name"] == model]
        if isinstance(model_name_vis, dict):
            for agg, agg_name in model_name_vis.items():
                agg_rows = model_rows.loc[model_rows["model.aggregation"] == agg]
                series.append((agg_name, agg_rows))
        else:
            series.append((model_name_vis, model_rows))

    n_bars = len(series)
    x_idxs = np.arange(len(attribute_names))
    # Leave one bar-width of gap between neighbouring attribute groups.
    b_w = 1 / (n_bars + 1.0)
    # Offset that centres the group of n_bars bars on each tick.
    center = (n_bars - 1) / 2.0

    fig, ax = plt.subplots(1, 1, figsize=(20, 5))
    for m, (label, rows) in enumerate(series):
        vals_mean = rows[attribute_names].values.flatten()
        ax.bar(x_idxs + (m - center) * b_w, vals_mean, b_w, label=label)

    # Fix tick positions before assigning labels so each label lands on its
    # own attribute group.
    ax.set_xticks(x_idxs)
    ax.set_xticklabels(attribute_names, rotation=90)
    ax.set_ylim([0.0, 1.0])
    ax.set_ylabel(str_ylabel)
    ax.set_xlabel(str_xlabel)
    fig.savefig(
        str_filename,
        bbox_inches="tight",
    )
    if fn_legend is not None:
        # Same stand-alone legend export as in generate_plot.
        handles, labels = ax.get_legend_handles_labels()
        legend = fig.legend(
            handles,
            labels,
            bbox_to_anchor=(0.5, 1.05),
            ncol=NCOLS_L,
            loc="lower center",
        )
        export_legend(legend, fn_legend)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
