import numpy as np
import networkx as nx
from cdt.metrics import precision_recall
from dodiscover.metrics import structure_hamming_dist
import warnings
# Promote every warning to an exception (this changes the global warnings
# filter). f1() relies on it: its `except RuntimeWarning` branch can only
# trigger when a 0 / 0 division warning is raised instead of printed.
warnings.filterwarnings("error")


def reversed(A_truth, A_pred):
    """Count predicted directed edges that point against the ground truth.

    A predicted edge ``i -> j`` (``A_pred[i, j] == 1``) is counted as reversed
    when the ground truth has no entry at ``[i, j]`` but has one at ``[j, i]``.
    Pairs that the prediction leaves undirected
    (``A_pred[i, j] + A_pred[j, i] == 2``) are never counted, so the result is
    0 for a fully undirected prediction.

    Neither input is modified. Earlier versions overwrote undirected pairs of
    ``A_pred`` in place, which silently changed the matrix seen by every
    metric computed afterwards.

    NOTE: the name shadows the builtin ``reversed`` inside this module; it is
    kept unchanged for backward compatibility.

    Parameters
    ----------
    A_truth : np.ndarray
        Ground truth adjacency matrix (square, same shape as ``A_pred``).
    A_pred : np.ndarray
        Predicted adjacency matrix; may contain undirected edges encoded as
        a symmetric pair of ones.

    Returns
    -------
    int
        Number of reversed edges.
    """
    # Undirected pairs, marked once on the strict upper triangle.
    undirected = np.triu((A_pred + A_pred.T) == 2, k=1)

    # Directed predictions: entries equal to 1 that are not part of an
    # undirected pair (the comparison creates a fresh array, A_pred is intact).
    directed = A_pred == 1
    directed[undirected | undirected.T] = False

    has_edge = A_truth != 0
    # Predicted i -> j, truth lacks i -> j but contains j -> i.
    flipped = directed & ~has_edge & has_edge.T
    return int(np.count_nonzero(flipped))

def false_negatives(A_truth, A_pred):
    """Count skeleton edges of the truth that the prediction misses.

    Both graphs are compared as undirected skeletons (strict upper triangle of
    ``A + A.T``), so the function works for DAGs and CPDAGs alike and a
    reversed edge is not a false negative.

    Returns
    -------
    int
        Number of node pairs adjacent in ``A_truth`` but not in ``A_pred``.
    """
    truth_adjacent = np.triu(A_truth + A_truth.T, k=1) != 0
    pred_adjacent = np.triu(A_pred + A_pred.T, k=1) != 0
    # Adjacent in the truth, absent from the prediction.
    missed = truth_adjacent & ~pred_adjacent
    return int(np.count_nonzero(missed))

def false_positives(A_truth, A_pred):
    """Count skeleton edges predicted between nodes that are not adjacent.

    Both graphs are reduced to undirected skeletons (strict upper triangle of
    ``A + A.T``), so the function works for DAGs and CPDAGs alike and a
    reversed edge is not a false positive.
    Verified equal to the NOTEARS implementation.

    Returns
    -------
    int
        Number of node pairs adjacent in ``A_pred`` but not in ``A_truth``.
    """
    truth_adjacent = np.triu(A_truth + A_truth.T, k=1) != 0
    pred_adjacent = np.triu(A_pred + A_pred.T, k=1) != 0
    # Adjacent in the prediction, absent from the truth.
    spurious = pred_adjacent & ~truth_adjacent
    return int(np.count_nonzero(spurious))

def true_positives(A_truth, A_pred):
    """Count correctly predicted edges for DAG and CPDAG predictions.

    Two contributions are summed:

    * undirected predictions (both ``A_pred[i, j]`` and ``A_pred[j, i]``
      nonzero) count as a true positive whenever the ground truth has an edge
      between the two nodes, in either direction;
    * directed predictions (``A_pred[i, j] == 1`` outside an undirected pair)
      count when the ground truth has an entry at the same position.

    Verified equal to the NOTEARS implementation.

    Neither input is modified. Earlier versions overwrote undirected pairs of
    ``A_pred`` in place, so a second call on the same array lost every
    undirected true positive.

    Parameters
    ----------
    A_truth : np.ndarray
        Ground truth adjacency matrix (square, same shape as ``A_pred``).
    A_pred : np.ndarray
        Predicted adjacency matrix of a DAG or CPDAG.

    Returns
    -------
    int
        Number of true positive edges.
    """
    # Undirected predictions matched against the undirected truth skeleton.
    pred_undirected = np.triu(A_pred * A_pred.T) != 0
    truth_skeleton = np.triu(A_truth + A_truth.T) != 0
    tp_undirected = np.count_nonzero(pred_undirected & truth_skeleton)

    # Directed predictions: entries equal to 1 that do not belong to an
    # undirected pair (the comparison creates a fresh array, A_pred is intact).
    paired = np.triu((A_pred + A_pred.T) == 2, k=1)
    directed = A_pred == 1
    directed[paired | paired.T] = False
    tp_directed = np.count_nonzero(directed & (A_truth != 0))

    return int(tp_directed) + int(tp_undirected)

def true_negatives(A_truth, A_pred):
    """Count node pairs that are non-adjacent in both truth and prediction.

    Adding a lower-triangular matrix of ones (diagonal included) to the
    symmetrized adjacency keeps the diagonal and lower triangle from
    evaluating to zero for non-negative adjacencies, so each unordered pair
    is inspected once.

    Returns
    -------
    int
        Number of positions that are zero in both offset matrices.
    """
    n_nodes = A_truth.shape[0]
    lower_offset = np.tril(np.ones((n_nodes, n_nodes)), k=0)
    truth_free = (A_truth + A_truth.T + lower_offset) == 0
    pred_free = (A_pred + A_pred.T + lower_offset) == 0
    return int(np.count_nonzero(truth_free & pred_free))

# ---------------------------- Rates ---------------------------- #
def fnr(A_truth, A_pred):
    """Compute the false negative rate.

    Reversed edges are included in the numerator, as this is needed for the
    topological order comparison. The denominator is the number of edges in
    the undirected skeleton of ``A_truth``.

    Returns
    -------
    float or int
        ``(false negatives + reversed edges) / number of true edges``, or 0
        when the ground truth has no edges (mirrors the zero-denominator
        handling of ``fpr`` and ``tnr`` instead of raising
        ``ZeroDivisionError``).
    """
    num_edges = len(np.flatnonzero(np.triu(A_truth + A_truth.T, k=1)))
    if num_edges == 0:
        # Nothing can be missed in an empty graph.
        return 0
    missed = false_negatives(A_truth, A_pred) + reversed(A_truth, A_pred)
    return missed / num_edges

def fpr(A_truth, A_pred):
    """Compute the false positive rate.

    Reversed edges are excluded: they are not relevant to observe effects of
    interest such as confounders, and excluding them keeps the F1 score
    computation fair.

    The denominator is the number of node pairs that are not adjacent in the
    ground truth, ``d * (d - 1) / 2 - num_edges``.

    Returns
    -------
    float or int
        False positives over non-adjacent pairs; 0 when the ground truth is a
        complete graph (no pair can be a false positive).
    """
    n_nodes = A_truth.shape[0]
    skeleton = np.triu(A_truth + A_truth.T, k=1)
    n_true_edges = int(np.count_nonzero(skeleton))
    n_non_adjacent = 0.5 * n_nodes * (n_nodes - 1) - n_true_edges
    try:
        rate = false_positives(A_truth, A_pred) / n_non_adjacent
    except ZeroDivisionError:
        rate = 0
    return rate

def tpr(A_truth, A_pred):
    """Compute the true positive rate (recall).

    The denominator is the number of edges in the undirected skeleton of
    ``A_truth``.

    Returns
    -------
    float or int
        ``true positives / number of true edges``, or 0 when the ground truth
        has no edges (mirrors the zero-denominator handling of ``fpr`` and
        ``tnr`` instead of raising ``ZeroDivisionError``).
    """
    num_edges = len(np.flatnonzero(np.triu(A_truth + A_truth.T, k=1)))
    if num_edges == 0:
        # No edge exists to be recovered.
        return 0
    return true_positives(A_truth, A_pred) / num_edges

def tnr(A_truth, A_pred):
    """Compute the true negative rate.

    The denominator is the number of node pairs that are not adjacent in the
    ground truth, ``d * (d - 1) / 2 - num_edges``.

    Returns
    -------
    float or int
        True negatives over non-adjacent pairs; 0 when the ground truth is a
        complete graph (there is no negative pair).
    """
    n_nodes = A_truth.shape[0]
    n_correct_absent = true_negatives(A_truth, A_pred)
    skeleton = np.triu(A_truth + A_truth.T, k=1)
    n_true_edges = int(np.count_nonzero(skeleton))
    n_non_adjacent = 0.5 * n_nodes * (n_nodes - 1) - n_true_edges
    try:
        rate = n_correct_absent / n_non_adjacent
    except ZeroDivisionError:
        rate = 0
    return rate


def f1(A_truth, A_pred):
    """Compute the F1 score as ``tp / (tp + 0.5 * (fp + fn))``.

    Reversed edges are added to the false negatives; false positives are
    counted on the skeleton only.

    Returns
    -------
    float or int
        The F1 score, or 0 when the denominator is zero (no true positives,
        false positives or false negatives).
    """
    hits = true_positives(A_truth, A_pred)
    spurious = false_positives(A_truth, A_pred)
    missed = false_negatives(A_truth, A_pred) + reversed(A_truth, A_pred)
    try:
        return hits / (hits + 0.5 * (spurious + missed))
    except (ZeroDivisionError, RuntimeWarning):
        # 0 / 0 division: numerator and denominator are both zero. The
        # RuntimeWarning form only occurs when warnings are raised as errors.
        return 0


#####################################################################
############################ GET METRICS ############################
#####################################################################

def get_metrics(A_pred: np.ndarray, A_truth: np.ndarray):
    """Compare a predicted adjacency matrix with the ground truth DAG.

    ``A_truth`` is always the ground truth DAG; ``A_pred`` may be a DAG or a
    CPDAG (undirected edges encoded as a symmetric pair of ones).

    Every metric is evaluated on its own copy of ``A_pred``, so no helper can
    alter the matrix seen by the next metric and the caller's array is left
    untouched. The result therefore does not depend on the order in which the
    metrics are computed.

    Parameters
    ----------
    A_pred : np.ndarray
        Adjacency matrix predicted by a causal discovery algorithm.
        Can be either a DAG or a CPDAG.
    A_truth : np.ndarray
        Ground truth DAG adjacency matrix.

    Returns
    -------
    shd : int
        Structural Hamming Distance (a reversed edge counts once).
    tpr : float
        True positive rate. Equals 1 for a perfect prediction.
    fpr : float
        False positive rate. Equals 0 for a perfect prediction.
    tnr : float
        True negative rate. Equals 1 for a perfect prediction.
    fnr : float
        False negative rate. Equals 0 for a perfect prediction.
    f1 : float
        F1 score, equivalent to 2*(precision*recall)/(precision+recall).
    aupr : float
        Area under the precision-recall curve.
    """
    tpr_accuracy = tpr(A_truth, A_pred.copy())
    fpr_accuracy = fpr(A_truth, A_pred.copy())
    tnr_accuracy = tnr(A_truth, A_pred.copy())
    fnr_accuracy = fnr(A_truth, A_pred.copy())
    f1_accuracy = f1(A_truth, A_pred.copy())
    shd = structure_hamming_dist(
        true_graph=nx.from_numpy_array(A_truth, create_using=nx.DiGraph),
        pred_graph=nx.from_numpy_array(A_pred.copy(), create_using=nx.DiGraph),
        double_for_anticausal=False
    )
    # precision_recall returns a pair; only its first element (AUPR) is kept.
    aupr, _ = precision_recall(A_truth, A_pred.copy())
    return shd, tpr_accuracy, fpr_accuracy, tnr_accuracy, fnr_accuracy, f1_accuracy, aupr


def d_top(order : np.array, A : np.array):
    """Topological order divergence.

    For every node in ``order``, sums the entries ``A[later, node]`` over all
    nodes placed after it, i.e. the ground truth edges that point from a later
    node back to an earlier one and thus contradict the order.

    Parameters
    ----------
    order: np.array
        Inferred topological order (order[0] source node)
    A : np.array
        Ground truth adjacency matrix

    Returns
    -------
    Sum of the violating entries of ``A`` (0 for a consistent order).
    """
    violations = 0
    for position, node in enumerate(order):
        successors = order[position + 1:]
        # Entries from any later node into the current one.
        violations += A[successors, node].sum()
    return violations

def dtop_fnr(dtop : int, A : np.ndarray):
    """Normalize a topological order divergence by the sum of entries of ``A``.

    Parameters
    ----------
    dtop : int
        Topological order divergence, e.g. the value returned by ``d_top``.
    A : np.ndarray
        Ground truth adjacency matrix.

    Returns
    -------
    float
        ``dtop / A.sum()``, or 0.0 when ``A`` sums to zero. An edge-free
        graph cannot be violated by any order, and dividing by zero would
        otherwise yield nan/inf (or raise when warnings are errors).
    """
    total = np.sum(A)
    if total == 0:
        return 0.0
    return dtop / total



#########################################
############## PAG METRICS ##############
#########################################

def tp_pag(A_pred, A_truth):
    """Count the true positives of a PAG query.

    A_pred, A_truth are the adjacencies of one of the following queries
    1. IsPotentialParent
    2. IsAncestor
    3. IsPotentialAncestor

    A TP is given when A_pred[i,j] = A_truth[i,j] = 1. The count is obtained
    as the sum of ``A_truth`` minus the number of truth entries that the
    prediction leaves at zero.
    """
    in_truth = A_truth != 0
    in_pred = A_pred != 0
    # Truth entries with no counterpart in the prediction.
    n_missed = int(np.count_nonzero(in_truth & ~in_pred))
    return A_truth.sum() - n_missed


def fp_pag(A_pred, A_truth):
    """Count the false positives of a PAG query.

    A_pred, A_truth are the adjacencies of one of the following queries
    1. IsPotentialParent
    2. IsAncestor
    3. IsPotentialAncestor

    A FP is given when A_pred[i,j] = 1, A_truth[i,j] = 0
    """
    in_pred = A_pred != 0
    in_truth = A_truth != 0
    # Predicted entries with no counterpart in the truth.
    return int(np.count_nonzero(in_pred & ~in_truth))


def fn_pag(A_pred, A_truth):
    """Count the false negatives of a PAG query.

    A_pred, A_truth are the adjacencies of one of the following queries
    1. IsPotentialParent
    2. IsAncestor
    3. IsPotentialAncestor

    A FN is given when A_pred[i,j] = 0, A_truth[i,j] = 1. It is computed as
    the sum of ``A_truth`` minus the true positives.
    """
    n_truth = A_truth.sum()
    n_recovered = tp_pag(A_pred, A_truth)
    return n_truth - n_recovered
    

def f1_pag(A_pred, A_truth):
    """Compute the F1 score of a PAG query as ``tp / (tp + 0.5 * (fp + fn))``.

    A_pred, A_truth are the adjacencies of one of the PAG queries accepted by
    ``tp_pag``, ``fp_pag`` and ``fn_pag``.

    Returns
    -------
    float or int
        The F1 score, or 0 when there are no true positives, false positives
        or false negatives. The zero denominator is tested explicitly so the
        result does not depend on whether warnings are raised as errors, and
        unrelated exceptions are no longer swallowed by a bare ``except``.
    """
    tp = tp_pag(A_pred, A_truth)
    fp = fp_pag(A_pred, A_truth)
    fn = fn_pag(A_pred, A_truth)
    denominator = tp + 0.5 * (fp + fn)
    if denominator == 0:
        # 0 / 0: both matrices are empty for this query.
        return 0
    return tp / denominator