import pickle
import torch
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import io

# Helps load pickled data onto the CPU when it was originally stored on GPUs
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that redirects torch storage deserialization to the CPU.

    When the pickle stream requests ``torch.storage._load_from_bytes``, a
    replacement callable is returned that runs ``torch.load`` with
    ``map_location='cpu'``. Every other global is resolved by the stock
    ``pickle.Unpickler``.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, substituting the CPU-mapped storage loader."""
        wants_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')
        if not wants_storage_loader:
            return super().find_class(module, name)

        def load_on_cpu(b):
            # Deserialize the raw bytes with torch.load, mapping them to the CPU.
            return torch.load(io.BytesIO(b), map_location='cpu')

        return load_on_cpu


def fetch_data(file_path):
    """Load pickled gradient samples and return a log10 gradient-norm metric.

    The file at ``file_path`` must unpickle (through ``CPU_Unpickler``) to a
    pair ``(x_grad_samples, y_grad_samples)`` whose elements support
    ``.cpu().numpy()``. For every index ``i`` of the x samples the value
    ``||x_i||^2 + ||y_i||^2`` is computed, and the base-10 logarithm of those
    values is returned as a NumPy array.
    """
    with open(file_path, 'rb') as handle:
        # CPU_Unpickler maps torch storages to the CPU while loading.
        x_grad_samples, y_grad_samples = CPU_Unpickler(handle).load()

        # Convert every sample to a NumPy array (x first, then y).
        x_arrays = [tensor.cpu().numpy() for tensor in x_grad_samples]
        y_arrays = [tensor.cpu().numpy() for tensor in y_grad_samples]

    # The y sample is looked up by index, so a shorter y list raises IndexError.
    squared_norms = [
        np.linalg.norm(x_arr)**2 + np.linalg.norm(y_arrays[idx])**2
        for idx, x_arr in enumerate(x_arrays)
    ]

    # log10(v) expressed as (1 / ln 10) * ln(v).
    inv_ln10 = 1./np.log(10)
    return inv_ln10 * np.log(np.array(squared_norms))


def gather_data():
    """Load every result file and arrange the metrics for plotting.

    Returns:
        An object-dtype NumPy array of shape ``(2, 3, 3)`` indexed as
        ``[row][dataset][algorithm]``. Each element is the 1-D array returned
        by ``fetch_data`` (or a concatenation of several such arrays).
        Row 0 holds the ``TWENTY_`` result files, row 1 the remaining runs.
        Datasets are ordered a9a, gisette, sido0; algorithms are ordered
        SAPD, SMDAVR, SSAGDA.
    """
    src_dir = './result_data/'
    datasets = ['a9a/', 'gisette/', 'sido0/']
    baselines = ['SAPD', 'SMDAVR']
    suffix = '_all_grad_theta=0.80'

    # SSAGDA result files, one per dataset (same order as `datasets`).
    smagda_partial = [
        'TWENTY_SSAGDA_ALL_grad_sample_simtime=200_maxepoch=20_epochnum=30000_tau1=0.1_tau2=0.002_beta=0.0001_p=160_b=1028.pkl',
        'TWENTY_SSAGDA_ALL_grad_sample_simtime=200_maxepoch=20_epochnum=6000_tau1=0.001_tau2=0.0002_beta=1e-05_p=160_b=256.pkl',
        'TWENTY_SSAGDA_ALL_grad_sample_simtime=200_maxepoch=20_epochnum=12678_tau1=0.001_tau2=0.0002_beta=1e-05_p=160_b=1028.pkl',
    ]
    smagda_full = [
        'SSAGDA_ALL_grad_sample_simtime=200_maxepoch=250_epochnum=30000_tau1=0.1_tau2=0.002_beta=0.0001_p=160_b=1028.pkl',
        'SSAGDA_ALL_grad_sample_simtime=200_maxepoch=550_epochnum=6000_tau1=0.001_tau2=0.0002_beta=1e-05_p=160_b=256.pkl',
        'SSAGDA_ALL_grad_sample_simtime=200_maxepoch=250_epochnum=12678_tau1=0.001_tau2=0.0002_beta=1e-05_p=160_b=1028.pkl',
    ]

    def run_path(dataset, algorithm, run):
        # Path of one numbered baseline run, e.g. .../a9a/SAPD_1_all_grad_theta=0.80
        return src_dir + dataset + algorithm + '_' + str(run) + suffix

    # ---- Row 0: the TWENTY_ files, one file per (dataset, algorithm). ----
    data1 = []
    for dataset, smagda_file in zip(datasets, smagda_partial):
        paths = [src_dir + dataset + 'TWENTY_' + algorithm + suffix
                 for algorithm in baselines]
        paths.append(src_dir + dataset + smagda_file)
        data1.append([fetch_data(path) for path in paths])

    # ---- Row 1: the remaining runs; baselines may span several files. ----
    a9a, gisette, sido0 = datasets
    sapd, smdavr = baselines
    data2 = []

    # a9a. The SAPD runs are read from the a9a directory; they were previously
    # read from the gisette directory, which duplicated the gisette SAPD data.
    a9a_sapd = np.concatenate(
        [fetch_data(run_path(a9a, sapd, run)) for run in range(1, 5)])
    a9a_smdavr = np.concatenate(
        [fetch_data(run_path(a9a, smdavr, run)) for run in range(1, 5)])
    a9a_smagda = fetch_data(src_dir + a9a + smagda_full[0])
    data2.append([a9a_sapd, a9a_smdavr, a9a_smagda])

    # gisette. SMDAVR uses runs 3 and 4 plus the separate SMDAVR_100 file.
    gisette_sapd = np.concatenate(
        [fetch_data(run_path(gisette, sapd, run)) for run in range(1, 5)])
    gisette_smdavr = np.concatenate(
        [fetch_data(run_path(gisette, smdavr, run)) for run in (3, 4)]
        + [fetch_data(src_dir + gisette + 'SMDAVR_100' + suffix)])
    gisette_smagda = fetch_data(src_dir + gisette + smagda_full[1])
    data2.append([gisette_sapd, gisette_smdavr, gisette_smagda])

    # sido0. A single file per baseline.
    sido0_row = [fetch_data(src_dir + sido0 + algorithm + suffix)
                 for algorithm in baselines]
    sido0_row.append(fetch_data(src_dir + sido0 + smagda_full[2]))
    data2.append(sido0_row)

    # The per-cell arrays generally differ in length, so np.array([...]) would
    # build a ragged array (a ValueError on NumPy >= 1.24). Fill an object
    # array cell by cell instead; indexing [row][dataset][algorithm] still
    # yields the 1-D metric array.
    rows = [data1, data2]
    all_data = np.empty((len(rows), len(datasets), len(baselines) + 1), dtype=object)
    for ii, row in enumerate(rows):
        for jj, per_dataset in enumerate(row):
            for kk, metric in enumerate(per_dataset):
                all_data[ii, jj, kk] = metric
    return all_data


def plot_histograms(all_data, directory='./', name='histogram_dro'):
    """Draw a 2x3 grid of overlaid, normalized histograms and save it as a PDF.

    Args:
        all_data: Container indexed as ``all_data[row][column][algorithm]``
            whose leaves are 1-D arrays of values to histogram. Rows and
            columns follow the subplot grid; the algorithm axis follows the
            order SAPD+, SMDAVR, smAGDA.
        directory: Folder the PDF is written to.
        name: File name of the PDF, without the ``.pdf`` extension.

    The figure is written to ``{directory}/{name}.pdf`` and then shown.
    """
    titles = [['a9a - epoch=20', 'gisette - epoch=20', 'sido0 - epoch=20'],
              ['a9a - epoch=250', 'gisette - epoch=550', 'sido0 - epoch=250']]
    labels = ['SAPD+', 'SMDAVR', 'smAGDA']
    colors = ['blue', 'green', 'red']
    # Per-column x-axis limits and number of bin edges.
    xlims = [(-3.5, 3.5), (-3.5, 3.5), (-3.5, 2)]
    bin_size = [100, 60, 60]
    hist_style = dict(ec='black', alpha=0.4)

    f, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 6))

    for ii, axes_row in enumerate(axes):
        for jj, ax in enumerate(axes_row):
            # Per-subplot setup is done once, not once per algorithm.
            ax.set_title(titles[ii][jj], fontsize=15)
            ax.grid(True)  # Enable the grid
            ax.grid(which='both', color='gray', linestyle='-', linewidth=0.5)
            ax.set_xlim(*xlims[jj])
            # All algorithms in a subplot share the same bin edges.
            bins = np.linspace(-3.0190100412978813, 3.1043467222072243, bin_size[jj])

            for kk, algorithm in enumerate(labels):
                data = all_data[ii][jj][kk]
                # Each sample weighs 1/len(data), so bar heights are fractions
                # of the sample count rather than raw counts.
                weights = np.ones_like(data) / len(data)
                ax.hist(data, weights=weights, color=colors[kk], bins=bins,
                        label=algorithm, **hist_style)

    # One shared legend, built from the last subplot's handles.
    legend_handles, legend_labels = axes[-1][-1].get_legend_handles_labels()
    f.legend(legend_handles, legend_labels, loc='lower center',
             ncol=len(legend_labels), bbox_to_anchor=(0.5, -0.1), prop={'size': 15})
    plt.tight_layout(rect=[0, -0.02, 1, 1])
    plt.savefig('{}/{}.pdf'.format(directory, name), bbox_inches='tight')
    plt.show()

if __name__ == "__main__":
    # Script entry point: load all result files, then draw, save and show the histograms.
    all_data = gather_data()
    plot_histograms(all_data)