# This script computes the overlap between the Rashomon set and the set of grounded models

import numpy as np
import pandas as pd
import pickle
from treefarms import TREEFARMS
from sklearn import metrics


def run_grounded_models(df, rashomon_parameter=0.05, depth=4, regularization=0.01, good_features=None):
    """Fit TREEFARMS on ``df`` and count the trees that mention a good feature.

    All columns of ``df`` except the last are used as features; the last
    column is used as the label.

    Args:
        df: DataFrame holding the feature columns followed by the label column.
        rashomon_parameter: Value passed to TREEFARMS as ``rashomon_bound_adder``.
        depth: Value passed to TREEFARMS as ``depth_budget``.
        regularization: Value passed to TREEFARMS as ``regularization``.
        good_features: Iterable of feature names to look for in the string
            form of each tree. ``None`` (or an empty collection) means no
            tree can match, so the returned count is 0.

    Returns:
        A tuple ``(tree_size, num_trees_contain)``: the number of trees
        reported by the fitted model, and how many of them mention at least
        one of ``good_features``.
    """
    import re

    X, Y = df[df.columns[0:-1]], df[df.columns[-1]]

    config = {
        # Regularization penalizes trees with more leaves; a relatively high
        # value is recommended to find sparse trees.
        "regularization": regularization,
        # The Rashomon bound is supplied through the additive parameter; the
        # multiplier below is set to 0.
        "rashomon_bound_adder": rashomon_parameter,
        "rashomon_bound_multiplier": 0,
        "depth_budget": depth,
        "allow_small_reg": True,
        "verbose": False
    }

    model = TREEFARMS(config)
    model.fit(X, Y)
    tree_size = model.get_tree_count()

    if not good_features:
        return tree_size, 0

    # A plain substring test would let "feature_1" match "feature_10"; the
    # negative lookahead rejects a name that is immediately followed by
    # another digit. Assumes the tree text uses these names -- TODO confirm.
    alternatives = "|".join(re.escape(str(name)) for name in good_features)
    feature_pattern = re.compile("(?:" + alternatives + r")(?!\d)")

    num_trees_contain = 0
    for i in range(tree_size):
        if feature_pattern.search(str(model[i])):
            num_trees_contain += 1

    return tree_size, num_trees_contain


def computeTPR(y_values, f_values, threshold):
    """Return the true positive rate of the rule ``f_values > threshold``.

    Only positions where ``y_values`` equals 1 are considered; the result is
    the fraction of those positions whose score is strictly greater than
    ``threshold``. Raises ``ZeroDivisionError`` when no label equals 1.
    """
    # Scores of the positive examples, looked up by position.
    positive_scores = [f_values[pos] for pos, y in enumerate(y_values) if y == 1]
    detected = sum(1 for score in positive_scores if score > threshold)
    return detected / len(positive_scores)


def computeFPR(y_values, f_values, threshold):
    """Return the false positive rate of the rule ``f_values > threshold``.

    Only positions where ``y_values`` equals 0 are considered; the result is
    the fraction of those positions whose score is strictly greater than
    ``threshold``. Raises ``ZeroDivisionError`` when no label equals 0.
    """
    # Scores of the negative examples, looked up by position.
    negative_scores = [f_values[pos] for pos, y in enumerate(y_values) if y == 0]
    false_alarms = sum(1 for score in negative_scores if score > threshold)
    return false_alarms / len(negative_scores)


def compute_ROC_AUC_feature(df_data, column, label):
    """Compute the ROC AUC obtained by thresholding a single column.

    The values of ``column`` are thresholded at 1000 evenly spaced cut-offs
    ranging from just below the minimum to just above the maximum value. For
    each cut-off the TPR and FPR against ``label`` are computed and the area
    under the resulting curve is taken. If that area is below 0.5, the labels
    are inverted (``1 - labels``) and the curve and area are recomputed.

    Returns:
        A list ``[AUC, FPR_array, TPR_array]`` for the curve that was kept.
    """
    targets = df_data[label].values
    scores = df_data[column].values
    cutoffs = np.linspace(min(scores) - 0.01, max(scores) + 0.01, num=1000)

    def sweep(current_targets):
        # One (TPR, FPR) pair per cut-off, evaluated in cut-off order.
        pairs = [(computeTPR(current_targets, scores, cutoff), computeFPR(current_targets, scores, cutoff))
                 for cutoff in cutoffs]
        return [pair[1] for pair in pairs], [pair[0] for pair in pairs]

    fpr_points, tpr_points = sweep(targets)
    area = metrics.auc(fpr_points, tpr_points)

    if area < 0.5:
        # The column ranks the classes the other way round: flip the labels.
        fpr_points, tpr_points = sweep(1 - targets)
        area = metrics.auc(fpr_points, tpr_points)

    return [area, fpr_points, tpr_points]


def compute_good_features(df, gamma=0.05):
    """Select the features whose single-feature AUC is within ``gamma`` of the best.

    The last column of ``df`` is the label; every other column is scored with
    ``compute_ROC_AUC_feature``. Features are identified by position as
    ``"feature_<k>"`` (k counted from 0), not by their column name.

    Returns:
        The set of ``"feature_<k>"`` names whose AUC is at least
        ``best AUC - gamma``.
    """
    target_column = df.columns[-1]
    auc_by_name = {
        "feature_" + str(position): compute_ROC_AUC_feature(df, column_name, target_column)[0]
        for position, column_name in enumerate(df.columns[:-1])
    }

    cutoff = max(auc_by_name.values()) - gamma
    return {name for name, auc_value in auc_by_name.items() if auc_value >= cutoff}


# Names of the binarized datasets; each one is read from
# ../datasets/binarized/<name>.csv.
data_list = [
    'carryout_takeaway', 'amsterdam', 'coffee_house', 'australian_credit', 'compas', 'NIJ_Recidivism', 'bankfull',
    'fico', 'occupancy_detection', 'bar7', 'german_credit', 'polish_companies', 'bar', 'GiveMeSomeCredit',
    'restaurant_20', 'bcw_bin', 'telco_churn', 'broward', 'iranian_churn'
]

# Maps dataset name -> [number of features, number of good features,
#                       Rashomon set size, number of trees using a good feature].
resulting_dic = {}

for data in data_list:

    depth = 4
    dataset = '../datasets/binarized/' + data + '.csv'

    df = pd.read_csv(dataset)

    good_features = compute_good_features(df, gamma=0.05)

    rset_size, rset_grounded_size = run_grounded_models(df,
                                                        rashomon_parameter=0.05,
                                                        depth=depth,
                                                        regularization=0.01,
                                                        good_features=good_features)

    resulting_dic[data] = [df.shape[1] - 1, len(good_features), rset_size, rset_grounded_size]

# Print the results as rows of a LaTeX table.
for key, value in resulting_dic.items():
    # value = [n_features, n_good_features, rset_size, rset_grounded_size]
    # Guard the ratio so an empty Rashomon set cannot abort the run before
    # the results are saved below; it is then printed as "nan".
    percentage = value[3] / value[2] * 100 if value[2] else float("nan")
    # The last literal spells out three backslashes followed by "hline";
    # every backslash is escaped explicitly to avoid an invalid "\h" escape.
    print(key, "&", value[0], "&", value[1], "&", "{:.2f}".format(percentage), "\\%",
          f"({value[3]}/{value[2]})", "\\\\\\hline")

# Save results
with open("./results_overlap_4", 'wb') as f:
    # Dump data to the file
    pickle.dump(resulting_dic, f)
