""" 3. Evaluate causal discovery results. """
from . import utils
import pandas as pd
import numpy as np
from pathlib import Path
import os
import sys
# Array printing: scientific notation is allowed (suppress=False) and large
# arrays are never truncated (threshold=sys.maxsize prints every element).
np.set_printoptions(suppress=False, threshold=sys.maxsize)
# rpy2 and R's SID package are imported lazily inside Evaluator.evaluate, and
# only when MEC metrics are requested, so R is not required otherwise.


class Evaluator:
    """Evaluate the causal-discovery results of one experiment.

    Every ``*results.csv`` below ``<exp_dir>/_res`` is loaded. Each estimated
    weight matrix is thresholded and made acyclic, then compared with the true
    graph. One CSV of metrics per ``evaluate`` call is written to
    ``<exp_dir>/_eval``.
    """

    # Metric columns that ``evaluate`` appends to the columns of the input rows.
    _METRIC_KEYS = (
        "fdr",
        "tpr",
        "fpr",
        "nnz",
        "shd",
        "sid",
        "sid_upper",
        "sid_lower",
        "mec_sid",
        "mec_sid_upper",
        "mec_sid_lower",
        "mec_shd",
        "was_dag",
        "effective_thresholding",
        "w_threshold",
    )
    # Accepted values of the ``thresholding`` argument of ``evaluate``.
    _THRESHOLDINGS = ("standard", "dynamic", "favourable")

    def __init__(self, opt):
        """
        Set up the input and output folders of one experiment.

        Args:
            opt: options object. Attributes read here:
                base_dir (str or path): directory holding the experiment folder.
                exp_name (str): experiment suffix. The experiment folder is
                    ``<base_dir>/<basename(base_dir)><exp_name>``.
                overwrite: forwarded to ``utils.create_folder`` for the output
                    folder (presumably removes previous evaluations - TODO confirm).
                ``evaluate`` additionally reads ``opt.thres`` and ``opt.MEC``.
        """
        self.opt = opt
        base_dir = os.fspath(opt.base_dir)
        # Strip trailing separators first: basename("runs/exp/") is "", which
        # would silently drop the "<basename>" prefix of the experiment folder.
        base_name = os.path.basename(base_dir.rstrip(os.sep) or base_dir)
        # folder of this experiment; holds the "_res" inputs and "_eval" outputs
        self.exp_dir = os.path.join(opt.base_dir, base_name + opt.exp_name)
        self.input_folder = os.path.join(self.exp_dir, "_res")
        self.output_folder = os.path.join(self.exp_dir, "_eval")
        utils.create_folder(self.output_folder, opt.overwrite)

    def read_files(self):
        """Load every result file of the experiment.

        Returns:
            tuple: ``(inputs, names)``. ``inputs`` maps each column name to the
            list of its values over all loaded rows. ``names[i]`` is the name
            of the file that row ``i`` came from.

        Raises:
            ValueError: if no ``*results.csv`` file (other than varsortability
                files) exists below ``self.input_folder``.
        """
        frames = []
        names = []
        # gather all results files recursively; varsortability files are skipped
        for path in Path(self.input_folder).rglob("*results.csv"):
            if "varsortability" in path.name:
                continue
            frame = utils.load_results(path)
            frames.append(frame)
            # one name per row, so names[i] lines up with row i of the concat
            names.extend([path.name] * len(frame))
        if not frames:
            raise ValueError(f"no *results.csv files found under {self.input_folder}")
        return pd.concat(frames, axis=0).to_dict("list"), names

    @staticmethod
    def _dynamic_threshold(w_threshold, scaling_factors, edge_weight_range):
        """Rescale ``w_threshold``: divide by max(scaling_factors), multiply by min |edge weight|."""
        min_edge_weight = np.min(np.abs(edge_weight_range))
        return w_threshold / np.max(scaling_factors) * min_edge_weight

    @staticmethod
    def _favourable_threshold(W_est, n_true_edges):
        """Return a threshold that keeps the ``n_true_edges`` largest |weights| of ``W_est``.

        The threshold is the mean of the n-th and (n+1)-th largest absolute
        weights. If the true graph has no edges, ``inf`` is returned so that
        every estimated edge is removed. (An empty slice would otherwise give
        a NaN threshold that removes nothing.)
        """
        if n_true_edges <= 0:
            return np.inf
        magnitudes = np.sort(np.abs(W_est).ravel())[::-1]
        return np.mean(magnitudes[n_true_edges - 1 : n_true_edges + 1])

    @staticmethod
    def _apply_threshold(W, w_threshold):
        """Return a thresholded copy of ``W``; ``W`` itself is not modified.

        ``"bidirected"`` zeroes only the negative entries. Any numeric
        threshold zeroes the entries whose absolute value is below it.
        """
        W_thr = np.copy(W)
        if isinstance(w_threshold, str) and w_threshold == "bidirected":
            W_thr[W_thr < 0.0] = 0
        else:
            W_thr[np.abs(W_thr) < w_threshold] = 0
        return W_thr

    @staticmethod
    def _setup_r():
        """Load R's SID package and define the R helpers used for MEC metrics.

        Returns:
            tuple: ``(ro, r_sid_fun, r_dag2cpdag)``. ``r_sid_fun`` returns
            ``c(sid, sidUpperBound, sidLowerBound)`` from SID's
            ``structIntervDist``. ``r_dag2cpdag`` wraps ``pcalg::dag2cpdag``.
        """
        import rpy2.robjects.numpy2ri as rn
        import rpy2.robjects as ro

        rn.activate()
        ro.r("""library(SID)""")
        ro.r("""sid_fun <- function(b_true, b_est){
                res = structIntervDist(b_true, b_est)
                return(c(res$sid, res$sidUpperBound, res$sidLowerBound))}""")
        ro.r("""dag2cpdag <- function(x){return(pcalg::dag2cpdag(x))}""")
        return ro, ro.globalenv["sid_fun"], ro.globalenv["dag2cpdag"]

    @staticmethod
    def _mec_metrics(r_env, algo, W_est, B_true, B_est):
        """Compute MEC-SID and MEC-SHD for one result row.

        MEC-SID compares the estimated CPDAG with the true DAG.
        MEC-SHD compares the CPDAGs of the true and estimated graphs.
        For "pc" and "fges" the non-thresholded output is used directly as
        the estimated CPDAG. Other algorithms have their thresholded DAG
        converted with ``dag2cpdag``.

        Returns:
            tuple: ``([mec_sid, mec_sid_upper, mec_sid_lower], mec_shd)``.
        """
        ro, r_sid_fun, r_dag2cpdag = r_env
        nr, nc = B_true.shape
        r_B_true = ro.r.matrix(B_true, nrow=nr, ncol=nc)
        if algo in ["pc", "fges"]:
            # use the non-thresholded version
            r_CPDAG_est = ro.r.matrix(W_est != 0, nrow=nr, ncol=nc)
        else:
            r_CPDAG_est = r_dag2cpdag(ro.r.matrix(B_est, nrow=nr, ncol=nc))
        mec_sid_res = list(r_sid_fun(r_B_true, r_CPDAG_est))
        mec_shd = utils.shd_cpdag(
            np.array(r_dag2cpdag(r_B_true)), np.array(r_CPDAG_est)
        )
        return mec_sid_res, mec_shd

    def evaluate(self, thresholding="standard", eligible=None):
        """
        Evaluate every loaded result row and write the metrics to a CSV file.

        Args:
            thresholding(String): One of ["standard", "dynamic", "favourable"].
                "standard" uses the per-algorithm threshold from
                ``utils.thresholds(opt.thres)``. "dynamic" rescales it with the
                row's scaling factors and smallest absolute edge weight.
                "favourable" keeps as many edges as the true graph has.
            eligible(list): Algorithms eligible for non-standard thresholding;
                all other algorithms use "standard". Defaults to none.

        Writes ``<output_folder>/<thresholding>_<opt.thres>.csv`` with the
        input columns plus ``_METRIC_KEYS``. SID values are -1 placeholders.
        MEC values are -1 unless ``opt.MEC`` is set.

        Raises:
            ValueError: if ``thresholding`` is not one of the options above.
        """
        if thresholding not in self._THRESHOLDINGS:
            raise ValueError(f"no such thresholding: {thresholding!r}")
        if eligible is None:
            eligible = []

        w_threshold_mapping = utils.thresholds(self.opt.thres)
        inputs, names = self.read_files()
        results = {k: [] for k in list(inputs) + list(self._METRIC_KEYS)}
        # R setup is loop-invariant: do it once, and only when MEC is requested.
        r_env = self._setup_r() if self.opt.MEC else None

        n_rows = len(inputs["W_est"])
        faults = 0
        for i in range(n_rows):
            W_true = inputs["W_true"][i][0]
            # NOTE(review): heuristic for rows whose W_true was not loaded as a
            # matrix - confirm against utils.load_results.
            if "list" in str(W_true):
                faults += 1
                print("faulty:", names[i])
                continue
            W_est = inputs["W_est"][i][0]
            B_true = W_true != 0
            # copy all the input columns of this row
            for k, v in inputs.items():
                results[k].append(v[i])

            was_dag = utils.is_dag(W_est)

            # choose the threshold for this row
            algo = inputs["algorithm"][i]
            w_threshold = w_threshold_mapping[algo]
            effective_thresholding = "standard"
            if thresholding == "dynamic" and algo in eligible:
                # NOTE(review): eval() executes arbitrary code read from the
                # results CSV; only run this on trusted files. ast.literal_eval
                # would suffice if the column only holds literal tuples/lists.
                edge_weight_range = eval(inputs["edge_weight_range"][i])
                w_threshold = self._dynamic_threshold(
                    w_threshold, inputs["scaling_factors"][i], edge_weight_range
                )
                effective_thresholding = "dynamic"
            elif thresholding == "favourable" and algo in eligible:
                w_threshold = self._favourable_threshold(W_est, int(B_true.sum()))
                effective_thresholding = "favourable"

            # threshold, then break cycles until the estimate is a DAG
            W_est_thr = self._apply_threshold(W_est, w_threshold)
            while not utils.is_dag(W_est_thr):
                W_est_thr = utils.dagify_break_cycles(W_est_thr)
            B_est = W_est_thr != 0

            # accuracy measures
            for k, v in utils.count_accuracy(B_true, B_est).items():
                results[k].append(v)

            # SID against the true DAG is not computed; -1 is a placeholder.
            for k in ("sid", "sid_upper", "sid_lower"):
                results[k].append(-1)

            if r_env is None:
                mec_sid_res, mec_shd = (-1, -1, -1), -1
            else:
                mec_sid_res, mec_shd = self._mec_metrics(
                    r_env, algo, W_est, B_true, B_est
                )
            results["mec_sid"].append(mec_sid_res[0])
            results["mec_sid_upper"].append(mec_sid_res[1])
            results["mec_sid_lower"].append(mec_sid_res[2])
            results["mec_shd"].append(mec_shd)

            results["was_dag"].append(was_dag)
            results["effective_thresholding"].append(effective_thresholding)
            results["w_threshold"].append(w_threshold)

        fault_fraction = faults / n_rows if n_rows else 0.0
        print(f"fraction of faulty results: {fault_fraction}")
        res_df = pd.DataFrame(results)
        # backwards compatibility with older result files
        if "graph_type" in res_df.columns:
            res_df = res_df.rename(columns={"graph_type": "graph"})
            if "noise_dist" in res_df.columns:
                # Assign back: a chained inplace replace on res_df.noise_dist is
                # not propagated to res_df under pandas copy-on-write.
                res_df["noise_dist"] = res_df["noise_dist"].replace(
                    {"uniform": "unif"}
                )
        res_df = res_df.sort_values(by=["algorithm"])
        out_name = f"{thresholding}_{str(self.opt.thres)}.csv"
        print(f"Writing {out_name}")
        res_df.to_csv(os.path.join(self.output_folder, out_name), index=False)
