# Visualization of Rashomon ratios.
# Run after compute_sets (presumably it produces the 'tree_farms_rset_param' pickle read below).

import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl


# Read the pickled Rashomon-set sizes (the numerators of the ratios).
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open('tree_farms_rset_param', 'rb') as f:
    res = pickle.load(f)

for key, value in res.items():
    print(key, value)

file = '../datasets/monks3.csv'

# Every key in `res` uses the same dataset file, so read it once rather than
# once per key. The last column is excluded from the feature count
# (presumably the label — TODO confirm).
df = pd.read_csv(file)
n_dataset_features = df.shape[1] - 1

num_features = {}
for data in res:
    num_features[data] = n_dataset_features
    print(data, n_dataset_features)
    
# Denominator: total number of depth-limited trees, used to normalize the Rashomon set size.

def number_of_trees(max_depth, num_features):
    """Return the size of the tree space for the given depth and feature count.

    Evaluates the recurrence

        T(d, f) = 2                          if d <= 0
        T(d, f) = 2 + f * T(d - 1, f - 1)**2 otherwise

    That is, a tree is either a leaf (2 choices) or a split on one of ``f``
    features whose two subtrees come from ``T(d - 1, f - 1)``.
    """
    # Descend to the base case, remembering the feature count at each level.
    level_features = []
    depth_left, feats = max_depth, num_features
    while depth_left > 0:
        level_features.append(feats)
        depth_left -= 1
        feats -= 1

    # Fold back up from the deepest level to the root.
    total = 2
    for feats in reversed(level_features):
        total = 2 + feats * total ** 2
    return total


def _plot_log_rashomon_ratio(depths, out_path, rashomon_counts, feature_counts,
                             *, skip_keys=(0.01,), show_legend=True,
                             x_nbins=None):
    """Plot log(Rashomon set size / number_of_trees) against tree depth.

    Opens a new figure, draws one line per key of *rashomon_counts*, and saves
    the figure to *out_path*.

    Args:
        depths: Tree depths on the x-axis. The i-th depth is paired with the
            i-th element of each ``rashomon_counts`` value, so those sequences
            are assumed to be ordered by depth starting at ``depths[0]`` —
            TODO confirm against the script that wrote the pickle.
        out_path: File name passed to ``plt.savefig``.
        rashomon_counts: Mapping from key (shown as theta in the legend) to a
            sequence of Rashomon set sizes.
        feature_counts: Mapping from the same keys to the feature count passed
            to ``number_of_trees``.
        skip_keys: Keys that are not plotted.
        show_legend: Whether to draw the legend to the right of the axes.
        x_nbins: If given, passed as ``nbins`` to ``plt.locator_params`` for x.
    """
    plt.figure()
    for key, counts in rashomon_counts.items():
        if key in skip_keys:
            continue
        log_ratios = [
            np.log(counts[depth_id] / number_of_trees(depth, feature_counts[key]))
            for depth_id, depth in enumerate(depths)
        ]
        plt.plot(depths, log_ratios, linewidth=1, label=r"$\theta=$" + str(key))

    plt.xlabel("Tree depth", size=20)
    # NOTE(review): np.log gives a natural-log ratio, not a percentage; the
    # "%" in this label may be misleading.
    plt.ylabel("log Rashomon ratio, %", size=20)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    if show_legend:
        plt.legend(loc=(1.05, 0), fontsize=16)
    if x_nbins is not None:
        plt.locator_params(axis='x', nbins=x_nbins)
    plt.savefig(out_path, bbox_inches='tight')


# Figure 1: depths 1-7, with a legend.
depth_arr = [1, 2, 3, 4, 5, 6, 7]
_plot_log_rashomon_ratio(depth_arr, 'rratio_7.png', res, num_features)

# Figure 2: depths 1-3, no legend, at most 3 x-axis tick bins.
depth_arr = [1, 2, 3]
_plot_log_rashomon_ratio(depth_arr, 'rratio_3.png', res, num_features,
                         show_legend=False, x_nbins=3)