#!/usr/bin/python3

import pandas
import seaborn
import matplotlib.pyplot as plt

# CSV read by the script section at the bottom of this file; its per-column
# variances are used to normalize the discriminator error.
ORIGINAL_FILENAME = "csv_files/data_factors.csv"
# NOTE(review): not referenced by any code visible in this file.
CLASS_LABEL = "shapeIsHeart"

# Error CSVs consumed by the plotting section at the bottom of this file.
PREDICTION_ERROR_FILENAME = "csv_files/prediction_error.csv"
DISCRIMINATOR_ERROR_FILENAME = "csv_files/discriminator_error.csv"
# Path prefix only: get_reconstruction_filename() appends a feature name and
# ".csv" to it.
RECONSTRUCTION_ERROR_FILENAME_STUB = "csv_files/reconstruction_error_"

# Feature identifiers: used to build the reconstruction-error file names and
# as column names of the per-feature MSE frame.
FEATURES = ["shapeIsHeart", "scale", "orientation", "xPos", "yPos"]
# X tick labels applied positionally by the plotting helpers; presumably
# parallel to FEATURES (same order, shorter display names).
FEATURE_NAMES = ["shape", "scale", "orient.", "x pos.", "y pos."]

# Font size for axis labels and tick labels.
FONTSIZE = 20
# Font size for figure titles.
TITLESIZE = 30

def boxplot_error(error_df, fig_title, outfile):
    """Draw a seaborn box plot of ``error_df`` and save it to ``outfile``.

    Args:
        error_df: Data handed to ``seaborn.boxplot`` via its ``data`` argument.
        fig_title: Title text for the figure.
        outfile: Destination passed to ``Figure.savefig``.

    The x tick labels are always replaced by ``FEATURE_NAMES``, whatever the
    column names of ``error_df`` are. The current pyplot figure is cleared
    before returning so the next plot starts from an empty figure.
    """
    ax = seaborn.boxplot(data=error_df)
    ax.set_title(fig_title, fontsize=TITLESIZE)
    x_label = ax.set_xlabel("Feature", fontsize=FONTSIZE)
    ax.set_ylabel("Error", fontsize=FONTSIZE)
    ax.set_xticklabels(FEATURE_NAMES, fontsize=FONTSIZE)
    # The x label is passed as an extra artist so the tight bounding box
    # takes it into account when the figure is written out.
    ax.figure.savefig(
        outfile,
        bbox_inches='tight',
        bbox_extra_artists=(x_label,),
    )
    plt.clf()

def barplot_error(error_df, fig_title, outfile):
    """Draw a seaborn bar plot of ``error_df`` and save it to ``outfile``.

    Args:
        error_df: Data handed to ``seaborn.barplot`` via its ``data`` argument.
        fig_title: Title text for the figure.
        outfile: Destination passed to ``Figure.savefig``.

    The x tick labels are always replaced by ``FEATURE_NAMES``, whatever the
    column names of ``error_df`` are. The current pyplot figure is cleared
    before returning so the next plot starts from an empty figure.
    """
    errorplot = seaborn.barplot(data=error_df)
    label = plt.xlabel("Feature", fontsize=FONTSIZE)
    plt.ylabel("Error", fontsize=FONTSIZE)
    plt.title(fig_title, fontsize=TITLESIZE)
    errorplot.set_xticklabels(FEATURE_NAMES, fontsize=FONTSIZE)
    # The x label is passed as an extra artist so the tight bounding box
    # takes it into account when the figure is written out.
    errorplot.figure.savefig(outfile, bbox_extra_artists=(label,), bbox_inches='tight')
    plt.clf()

def get_reconstruction_filename(stub, featurename):
    """Return the CSV file name formed as ``stub`` + ``featurename`` + ".csv"."""
    return "".join((stub, featurename, ".csv"))

def avg_perinstance(filename_stub, features):
    """Build a frame of row-wise mean errors, one column per feature.

    For each name in ``features`` the CSV named by
    ``get_reconstruction_filename(filename_stub, name)`` is read and the mean
    across its columns (``axis=1``) is stored under that feature's name.

    Args:
        filename_stub: Path prefix of the per-feature CSV files.
        features: Iterable of feature names.

    Returns:
        A ``pandas.DataFrame`` with one column per feature; empty when
        ``features`` is empty.
    """
    per_instance = pandas.DataFrame()
    for name in features:
        errors = pandas.read_csv(get_reconstruction_filename(filename_stub, name))
        per_instance[name] = errors.mean(axis=1)
    return per_instance

def combine_error_percol(filename_stub, features):
    """Read every per-feature error CSV and concatenate them into one frame.

    Args:
        filename_stub: Path prefix of the per-feature CSV files.
        features: Iterable of feature names; each one is turned into a file
            name with ``get_reconstruction_filename``.

    Returns:
        The result of ``pandas.concat`` (default arguments) applied to the
        frames in the order of ``features``.

    Raises:
        ValueError: Raised by ``pandas.concat`` when ``features`` is empty.
    """
    frames = [
        pandas.read_csv(get_reconstruction_filename(filename_stub, feature))
        for feature in features
    ]
    return pandas.concat(frames)

# mse(x,xhat) for each protected feature
def error_per_feature(filename_stub, features):
    """Compute one mean squared error per feature from its error CSV.

    For each name in ``features`` the CSV named by
    ``get_reconstruction_filename(filename_stub, name)`` is read, every entry
    is squared, and the mean over all entries of the file is taken.
    The files presumably hold raw reconstruction errors (x - xhat).

    Args:
        filename_stub: Path prefix of the per-feature CSV files.
        features: Feature names; also used as the column labels of the result.

    Returns:
        A single-row ``pandas.DataFrame`` whose columns are ``features`` and
        whose values are the per-feature MSEs.
    """
    def _file_mse(name):
        errors = pandas.read_csv(get_reconstruction_filename(filename_stub, name))
        # Mean over every cell of the file, not per column.
        return (errors * errors).values.mean()

    mse_row = pandas.DataFrame([_file_mse(name) for name in features]).transpose()
    mse_row.columns = features
    return mse_row

# sqrt[mse(p,phat)/var(p)]
def normalize(df, original_data):
    """Return the normalized root mean squared error for each column of ``df``.

    Computes ``sqrt(mean(df[c] ** 2) / var(original_data[c]))`` for every
    column ``c`` of ``df``.

    Args:
        df: Frame of errors, one column per feature.
        original_data: Frame whose column variances (``DataFrame.var``
            defaults) are used as the normalizer.

    Returns:
        A ``pandas.Series`` indexed by ``df``'s columns, in ``df``'s order.
        A column of ``df`` that is missing from ``original_data`` yields NaN.
    """
    # Only take variances for the columns that are actually being normalized.
    # Dividing by the variance of *all* original columns aligns on the union
    # of labels, which adds NaN entries for unrelated columns, can reorder the
    # result, and can fail outright if original_data has a non-numeric column.
    var = original_data.reindex(columns=df.columns).var()
    mse = (df * df).mean()
    normalized = (mse / var) ** .5
    # transpose() is kept so the returned object is unchanged for callers.
    return normalized.transpose()

# (p-phat)^2/var(p)
def normalize_dist(df, original_data):
    """Return the squared errors in ``df`` divided by per-column variance.

    Computes ``df[c] ** 2 / var(original_data[c])`` element-wise for every
    column ``c`` of ``df``.

    Args:
        df: Frame of errors, one column per feature.
        original_data: Frame whose column variances (``DataFrame.var``
            defaults) are used as the normalizer.

    Returns:
        A ``pandas.DataFrame`` with the same columns, in the same order, as
        ``df``. A column of ``df`` that is missing from ``original_data``
        comes back as NaN.
    """
    # Diagnostic output retained from the original implementation.
    print(original_data.columns)
    # Only take variances for the columns that are actually being normalized.
    # Dividing by the variance of *all* original columns aligns on the union
    # of labels, which adds all-NaN columns for unrelated features, can
    # reorder the columns (breaking positional tick labels downstream), and
    # can fail outright if original_data has a non-numeric column.
    var = original_data.reindex(columns=df.columns).var()
    squared = df * df
    return squared / var

# Script section: runs on execution (and on import) and writes four figures.
# NOTE(review): nothing here creates the figures/ directory — presumably it
# must already exist before the script is run.

# Original data; only its per-column variances are used below.
original_data_df = pandas.read_csv(ORIGINAL_FILENAME)

# Prediction error: box plot of the CSV as read.
prediction_error_df = pandas.read_csv(PREDICTION_ERROR_FILENAME)
boxplot_error(prediction_error_df, "Prediction Error", "figures/prediction_error.png")

# Reconstruction error: one MSE per feature, read from the per-feature CSVs.
reconstruction_error_df = error_per_feature(RECONSTRUCTION_ERROR_FILENAME_STUB, FEATURES)
barplot_error(reconstruction_error_df, "Reconstruction Error", "figures/reconstruction_error.png")

# Discriminator error: the bar plot uses the squared error divided by the
# variance of the original data; the box plot uses the raw CSV values.
discriminator_error_df = pandas.read_csv(DISCRIMINATOR_ERROR_FILENAME)
normalized = normalize_dist(discriminator_error_df, original_data_df)
barplot_error(normalized, "Disentanglement Error", "figures/discriminator_error.png")
boxplot_error(discriminator_error_df, "Disentanglement Error", "figures/discriminator_box_error.png")

