import torch
import numpy as np
import matplotlib.pyplot as plt
import csv
import pandas as pd
import os
import decimal

# Private decimal context used only for float -> string conversion; 20 significant
# digits comfortably exceed the at most 17 that a float's repr() can produce.
ctx = decimal.Context(prec=20)


def float_to_str(f):
    """
    Render ``f`` as a plain positional decimal string, never in scientific notation.

    ``repr(f)`` yields the shortest text that round-trips the float (e.g. ``'1e-05'``);
    that text is parsed under ``ctx`` and re-emitted with the ``'f'`` (fixed-point)
    presentation type, so ``float_to_str(0.00001) == '0.00001'``.
    """
    as_decimal = ctx.create_decimal(repr(f))
    return f"{as_decimal:f}"


# plots of accuracy vs disparate impact (CV score): one curve per method and
# one point per threshold t
methods1 = ['langevin_lang', 'langevin_lang_free', 'svgd_lang']
labels1 = ["Primal-Dual + Langevin", "Control + Langevin", "Primal-Dual + SVGD"]

methods2 = ['mied', 'mied_coin']
labels2 = ["MIED", "Coin MIED"]

methods = methods1 + methods2
labels = labels1 + labels2

t_vals = [0.01, 0.005, 0.002, 0.001, 0.0001, 0.00001, 0.000001]

# methods1: one run directory per (method, t). Each .npy file holds a metric
# trajectory; its last entry is taken as the final value.
acc_methods1 = np.zeros((len(methods1), len(t_vals)))
cv_methods1 = np.zeros((len(methods1), len(t_vals)))

for i, method in enumerate(methods1):
    # Only these two runs carry the extra "1_1000.0_" segment in their run name
    # (presumably extra hyper-parameter values -- confirm against the training script).
    run_prefix = method + "_"
    if method in ("langevin_lang", "svgd_lang"):
        run_prefix += "1_1000.0_"
    for j, t in enumerate(t_vals):
        # float_to_str keeps e.g. 1e-05 as "0.00001" so it matches the directory name.
        f_root = os.path.join(
            "results", "fairness_bnn", method,
            run_prefix + "50_" + float_to_str(t) + "_1",
        )
        fname_acc = os.path.join(f_root, "acc_bnn_ct.npy")
        acc_methods1[i, j] = np.load(fname_acc)[-1]
        fname_cv = os.path.join(f_root, "cv_bnn_ct.npy")
        cv_methods1[i, j] = np.load(fname_cv)[-1]

# methods2: one CSV of final metrics per method.
# NOTE(review): row j is assumed to correspond to t_vals[j] -- confirm the CSV row order.
acc_methods2 = np.zeros((len(methods2), len(t_vals)))
cv_methods2 = np.zeros((len(methods2), len(t_vals)))

for i, method in enumerate(methods2):
    f_name = os.path.join("results", "fairness_bnn", method, "final_metrics", method + ".csv")
    df = pd.read_csv(f_name, sep=',', header='infer')
    if len(df) < len(t_vals):
        raise ValueError(
            f"{f_name} has {len(df)} rows, expected at least {len(t_vals)} (one per t value)"
        )
    # Positional access: rows are matched to t_vals by order, not by index label.
    acc_methods2[i, :] = df["acc_bnn"].iloc[:len(t_vals)].to_numpy()
    cv_methods2[i, :] = df["cv_bnn"].iloc[:len(t_vals)].to_numpy()

acc_methods = np.vstack((acc_methods1, acc_methods2))
cv_methods = np.vstack((cv_methods1, cv_methods2))

# Draw on a fresh figure so this plot never shares axes with anything drawn earlier.
plt.figure()
for acc, cv, label in zip(acc_methods, cv_methods, labels):
    # Line style only in the fmt string: the old ".-" marker was overridden by
    # marker="o" anyway and triggered matplotlib's "redundantly defined" warning.
    plt.plot(acc, cv, "-", marker="o", mfc='none', label=label)
plt.gca().invert_xaxis()
plt.legend(prop={'size': 15})
plt.locator_params(axis='x', nbins=6)
plt.xlabel("Accuracy", fontsize=20)
plt.ylabel("CV Score", fontsize=20)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.savefig("results/fairness_bnn/acc_vs_cv.pdf",  bbox_inches='tight', dpi=300)
plt.show()


# plot of riesz energy / constraint vs no. of iterations

# MIED
fname_root = os.path.join("results", "fairness_bnn", "mied", "metrics_vs_iter")
fname_energy = os.path.join(fname_root, "riesz_energy.csv")
fname_constraint = os.path.join(fname_root, "constraint.csv")
df_mied_energy = pd.read_csv(fname_energy, sep=',', header='infer')
df_mied_constraint = pd.read_csv(fname_constraint, sep=',', header='infer')

# Coin MIED
fname_root = os.path.join("results", "fairness_bnn", "mied_coin", "metrics_vs_iter")
fname_energy = os.path.join(fname_root, "riesz_energy.csv")
fname_constraint = os.path.join(fname_root, "constraint.csv")
df_coin_mied_energy = pd.read_csv(fname_energy, sep=',', header='infer')
df_coin_mied_constraint = pd.read_csv(fname_constraint, sep=',', header='infer')

# x-axis: 10, 20, ..., 2000 (200 points). Assumes one CSV row per logged
# iteration at this spacing -- confirm against the logging setup.
iters = np.arange(10, 2010, 10)

# One entry per threshold t: (legend tag, colour, Coin MIED column prefix,
# MIED column prefix). Prefixes are copied verbatim from the CSV headers,
# whose capitalisation and separators are not consistent.
runs_vs_iter = [
    ("1e-2", "C0", "CoinMIED_t_0_01", "MIED_t_0_01_Adam"),
    ("1e-3", "C1", "coinMIED_t_0_001", "MIED_t_0_001_Adam"),
    ("1e-4", "C2", "coinMIED_t_0.0001", "MIED_t_0_0001_Adam"),
    ("1e-6", "C3", "coinMIED_t_0_000001", "MIED_t_0_000001_Adam"),
]

# energy
# Fresh figure: with a non-interactive backend plt.show() does not close the
# previous figure, so these curves would otherwise land on the earlier plot.
plt.figure()
# All Coin MIED curves first, then all MIED curves, so the 2-column legend
# lists one method per column.
for tag, color, coin_col, _ in runs_vs_iter:
    plt.plot(iters, df_coin_mied_energy[coin_col + " - Riesz energy"],
             color=color, label=f"Coin MIED (t={tag})")
for tag, color, _, mied_col in runs_vs_iter:
    plt.plot(iters, df_mied_energy[mied_col + " - Riesz energy"], "--",
             color=color, label=f"MIED (t={tag})")
plt.legend(ncol=2, prop={'size': 13})
plt.ylim(-700,4200)
plt.xlabel("Iterations", fontsize=20)
plt.ylabel("Riesz Energy", fontsize=20)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.locator_params(axis='x', nbins=8)
plt.savefig("results/fairness_bnn/energy_vs_iter.pdf",  bbox_inches='tight', dpi=300)
plt.show()

# constraint
from matplotlib.ticker import FuncFormatter

# One entry per threshold t: (legend tag, colour, Coin MIED column prefix,
# MIED column prefix). Prefixes are copied verbatim from the CSV headers,
# whose capitalisation and separators are not consistent.
runs_vs_iter = [
    ("1e-2", "C0", "CoinMIED_t_0_01", "MIED_t_0_01_Adam"),
    ("1e-3", "C1", "coinMIED_t_0_001", "MIED_t_0_001_Adam"),
    ("1e-4", "C2", "coinMIED_t_0.0001", "MIED_t_0_0001_Adam"),
    ("1e-6", "C3", "coinMIED_t_0_000001", "MIED_t_0_000001_Adam"),
]

# Fresh figure: with a non-interactive backend plt.show() does not close the
# previous figure, so these curves would otherwise land on the earlier plot.
plt.figure()
# All Coin MIED curves first, then all MIED curves, so the 2-column legend
# lists one method per column.
for tag, color, coin_col, _ in runs_vs_iter:
    plt.plot(iters, df_coin_mied_constraint[coin_col + " - constraint_mean"],
             color=color, label=f"Coin MIED (t={tag})")
for tag, color, _, mied_col in runs_vs_iter:
    plt.plot(iters, df_mied_constraint[mied_col + " - constraint_mean"], "--",
             color=color, label=f"MIED (t={tag})")
plt.legend(ncol=2, prop={'size': 13})
plt.xlabel("Iterations", fontsize=20)
plt.ylabel("Constraint", fontsize=20)
plt.locator_params(axis='x', nbins=8)
plt.locator_params(axis='y', nbins=8)


def sci_format(x, lim):
    """
    Tick-label callback for FuncFormatter: scientific notation with one decimal.

    x:   tick value, e.g. 0.00123 -> '1.2e-03'.
    lim: tick position supplied by FuncFormatter (unused).
    """
    return '{:.1e}'.format(x)


major_formatter = FuncFormatter(sci_format)
plt.gca().yaxis.set_major_formatter(major_formatter)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.savefig("results/fairness_bnn/constraint_vs_iter.pdf", bbox_inches='tight', dpi=300)
plt.show()
