import networkx as nx
import numpy as np
import torch
from matplotlib import pylab as plt
from torch import FloatTensor

from src.explanation_algorithms.GPSHAP import GPSHAP


def plot_correlation_graph_of_stachastic_shapley_values_from_correlation_matrix(
        correlation_matrix: FloatTensor, feature_names: list[str], correlation_threshold: float
):
    """Draw the thresholded feature correlations as a circular graph and show it.

    Every feature becomes a node. Two features are joined by an edge when the
    absolute value of their correlation is strictly greater than
    ``correlation_threshold``. Edges are red for non-negative and blue for
    negative correlation; their opacity is the absolute correlation and their
    width is proportional to it.

    Args:
        correlation_matrix: 2-D tensor with one row and one column per entry
            of ``feature_names``, indexed in the same order.
        feature_names: Node labels, one per feature.
        correlation_threshold: Entries whose absolute value does not exceed
            this threshold are zeroed and produce no edge.

    Returns:
        None. The figure is displayed through ``plt.show()``.
    """
    # Zero out every entry whose magnitude does not exceed the threshold.
    keep_mask = correlation_matrix.abs() > correlation_threshold
    correlation_graph = torch.zeros_like(correlation_matrix)
    correlation_graph[keep_mask] = correlation_matrix[keep_mask]

    # detach()/cpu() make the conversion work for tensors that track gradients
    # or live on an accelerator; for plain CPU tensors they are no-ops.
    upper_correlation = np.triu(correlation_graph.detach().cpu().numpy())
    # Normaliser for the edge widths; computed once instead of once per edge.
    total_abs_correlation = np.abs(upper_correlation).sum()

    _, ax = plt.subplots(1, 1, figsize=(6, 4))
    G = nx.Graph()
    G.add_nodes_from(feature_names)

    # Only pairs that survived the thresholding become edges. Comparing against
    # 0 (the value written by the filter above) keeps perfectly correlated
    # pairs, which a comparison against 1 would silently drop.
    n_features = len(feature_names)
    edge_id_pairs = [
        (i, j)
        for i in range(n_features)
        for j in range(i + 1, n_features)
        if upper_correlation[i, j] != 0
    ]
    edge_list = [(feature_names[i], feature_names[j]) for i, j in edge_id_pairs]
    G.add_edges_from(edge_list)

    edge_values = [float(upper_correlation[i, j]) for i, j in edge_id_pairs]
    edge_colors = ["red" if value >= 0 else "blue" for value in edge_values]
    # matplotlib rejects alpha values above 1, which rounding in an upstream
    # normalisation can produce, so cap the opacity.
    edge_color_intensity = [min(1.0, abs(value)) for value in edge_values]
    # total_abs_correlation is > 0 whenever edge_values is non-empty.
    edge_sizes = [300 * abs(value) / total_abs_correlation for value in edge_values]

    pos = nx.circular_layout(G)
    nx.draw_networkx_nodes(G, pos, alpha=0.99, node_color="grey", node_size=800)
    if edge_list:
        # Pass the edge list explicitly so the per-edge attribute lists are
        # guaranteed to line up with the edges being drawn.
        nx.draw_networkx_edges(
            G, pos, edgelist=edge_list,
            alpha=edge_color_intensity, edge_color=edge_colors, width=edge_sizes,
        )
    # Shift the label positions upwards so the labels sit above the nodes.
    for name in feature_names:
        pos[name] += (0, 0.18)
    nx.draw_networkx_labels(G, pos, font_size=15,
                            font_family="sans-serif", bbox={"ec": "k", "fc": "white", "alpha": 0.7})
    plt.axis("off")
    ax.margins(0.1, 0.1)
    plt.tight_layout()

    # Raw f-string: "\propto" must reach matplotlib's mathtext with its
    # backslash intact, and "\p" is an invalid escape in a normal string.
    ax.set_title(rf"""
        Correlation Graph of Stochastic Shapley values
        both edge's sizes and color intensity are $\propto$ to their correlation,
        red denotes positive correlation while blue denotes negative.
        thresholding |correlation| < {correlation_threshold}  
        """
                 )

    plt.show()
    return None


def plot_correlation_graph_of_stochastic_shapley_values_from_gpshap(
        gpshap: GPSHAP,
        data_id: int,
        feature_names: list[str],
        scale: float,
        correlation_threshold: float,
):
    """Plot the correlation graph for one query point of a GPSHAP explainer.

    Args:
        gpshap: Explainer queried through
            ``compute_cross_covariance_for_query_i_j(data_id, data_id)``.
        data_id: Index passed as both query indices, so the covariance is
            taken of the query point with itself.
        feature_names: Node labels forwarded to the plotting function.
        scale: The covariance is multiplied by ``scale ** 2`` before it is
            normalised into a correlation matrix.
        correlation_threshold: Threshold forwarded to the plotting function.

    Returns:
        Whatever the matrix-based plotting function returns.
    """
    # NOTE(review): presumably the covariance between the stochastic Shapley
    # values of the features at this query point -- confirm against GPSHAP.
    self_covariance = gpshap.compute_cross_covariance_for_query_i_j(data_id, data_id)
    scaled_covariance = self_covariance * scale ** 2

    return plot_correlation_graph_of_stachastic_shapley_values_from_correlation_matrix(
        correlation_matrix=compute_correlation_matrix(scaled_covariance),
        feature_names=feature_names,
        correlation_threshold=correlation_threshold,
    )


def compute_correlation_matrix(covariance_matrix: FloatTensor):
    """Normalise a covariance matrix into a correlation matrix.

    Entry ``(i, j)`` of the result is ``covariance[i, j]`` multiplied by the
    inverse standard deviations of ``i`` and ``j``, where the standard
    deviations are the square roots of the diagonal of ``covariance_matrix``.
    The input tensor is not modified.
    """
    inverse_std = torch.diag(covariance_matrix).sqrt().reciprocal()
    as_column = inverse_std.unsqueeze(1)  # shape (n, 1): scales row i by 1 / std_i
    as_row = inverse_std.unsqueeze(0)  # shape (1, n): scales column j by 1 / std_j

    # Rows first, then columns, matching the original multiplication order.
    row_scaled = as_column * covariance_matrix
    return row_scaled * as_row
