import numpy as np
import torch
import dgl
import orcastr
from tqdm import tqdm
from scipy.sparse import csr_matrix

from graph_tool import Graph as GTGraph
from graph_tool.stats import remove_self_loops, remove_parallel_edges
from graph_tool.centrality import closeness, betweenness, pagerank
from graph_tool.inference import minimize_blockmodel_dl

from graphrole import RecursiveFeatureExtractor, RoleExtractor


def dgl_to_gt(dgl_graph):
    """Convert a DGL graph into a simple, undirected graph-tool graph.

    Self-loops and parallel edges are removed. A vertex is created for every
    DGL node (including isolated ones), so vertex ``i`` of the returned graph
    corresponds to DGL node ``i`` and per-vertex results line up with the
    DGL node ids.

    Args:
        dgl_graph: the DGL graph to convert.

    Returns:
        A ``graph_tool.Graph`` with ``dgl_graph.num_nodes()`` vertices.
    """
    edge_list = torch.stack(dgl_graph.edges()).T.cpu().numpy()

    gt_graph = GTGraph(directed=False)
    # add_edge_list only creates vertices up to the largest id appearing in an
    # edge, so isolated nodes with high ids would otherwise be missing and
    # per-vertex outputs would be shorter than num_nodes().
    gt_graph.add_vertex(dgl_graph.num_nodes())
    gt_graph.add_edge_list(edge_list)
    remove_self_loops(gt_graph)
    remove_parallel_edges(gt_graph)

    return gt_graph


def compute_centrality_measures(graph):
    """Compute three centrality scores for every vertex of a DGL graph.

    The graph is first converted with ``dgl_to_gt`` (simple, undirected).

    Args:
        graph: DGL graph.

    Returns:
        Tensor with one row per graph-tool vertex and three columns, in order:
        harmonic closeness, betweenness, PageRank.
    """
    gt_graph = dgl_to_gt(graph)

    print('Computing harmonic closeness...')
    harmonic_closeness = closeness(gt_graph, harmonic=True)

    print('Computing betweenness...')
    # betweenness returns (vertex_betweenness, edge_betweenness); only the vertex part is used.
    vertex_betweenness, _edge_betweenness = betweenness(gt_graph)

    print('Computing PageRank...')
    page_rank = pagerank(gt_graph)

    columns = [list(scores) for scores in (harmonic_closeness, vertex_betweenness, page_rank)]
    return torch.tensor(columns).T


def get_sbm_groups(graph, num_fits=10):
    """Infer stochastic-block-model groups for every vertex of a DGL graph.

    ``minimize_blockmodel_dl`` is stochastic, so it is run ``num_fits`` times
    and the fit with the lowest ``state.entropy()`` is kept. The block labels
    of that fit are remapped to consecutive ids ``0..k-1`` in increasing
    order of the original labels.

    Args:
        graph: DGL graph (converted with ``dgl_to_gt``).
        num_fits: number of independent fits; must be at least 1.

    Returns:
        int64 tensor with one group id per graph-tool vertex.

    Raises:
        ValueError: if ``num_fits`` is smaller than 1.
    """
    if num_fits < 1:
        raise ValueError('num_fits should be a positive integer.')

    graph = dgl_to_gt(graph)

    print(f'The inference algorithm is stochastic and it will be run {num_fits} times...')
    best_state = None
    best_entropy = None
    for _ in tqdm(range(num_fits)):
        state = minimize_blockmodel_dl(graph)
        # Evaluate each fit's entropy once and remember the best value instead
        # of re-evaluating best_state.entropy() on every iteration.
        entropy = state.entropy()
        if best_entropy is None or entropy < best_entropy:
            best_state, best_entropy = state, entropy

    blocks = np.asarray(list(best_state.get_blocks()))
    # return_inverse yields, for each vertex, the index of its label among the
    # sorted unique labels, i.e. exactly the consecutive relabelling.
    _, groups = np.unique(blocks, return_inverse=True)

    return torch.tensor(groups.reshape(-1), dtype=torch.long)


def compute_rolx_features(graph, max_roles=25):
    """Compute RolX role-membership features for a DGL graph.

    The graph is converted to an undirected networkx graph. Recursive
    structural features are then extracted and factorised into between 2 and
    ``max_roles`` roles with graphrole.

    Args:
        graph: DGL graph.
        max_roles: upper end of the role-count search range passed to ``RoleExtractor``.

    Returns:
        float32 tensor built from ``RoleExtractor.role_percentage``. Its row
        order is whatever graphrole produces; presumably node order, but confirm.
    """
    undirected_nx_graph = dgl.to_networkx(graph).to_undirected()

    recursive_features = RecursiveFeatureExtractor(undirected_nx_graph).extract_features()

    roles = RoleExtractor(n_role_range=(2, max_roles))
    roles.extract_role_factors(recursive_features)
    role_shares = roles.role_percentage.to_numpy()

    return torch.tensor(role_shares).float()


def compute_graphlet_degree_vectors(graph, max_graphlet_size=5):
    """Compute graphlet orbit counts for every node with ORCA.

    The graph is handed to ORCA as a simple undirected graph. Self-loops are
    dropped, and every unordered node pair appears at most once regardless of
    the direction or multiplicity of the DGL edges.

    Args:
        graph: DGL graph.
        max_graphlet_size: largest graphlet size to count; must be 4 or 5.

    Returns:
        Integer tensor with one row per line of ORCA output (presumably one
        row per node; confirm against orcastr).

    Raises:
        ValueError: if ``max_graphlet_size`` is not 4 or 5.
    """
    if max_graphlet_size not in [4, 5]:
        raise ValueError('max_graphlet_size should be either 4 or 5.')

    source_nodes, target_nodes = graph.edges()
    edges = np.stack([source_nodes.cpu().numpy(), target_nodes.cpu().numpy()], axis=1)

    # Canonicalise with vectorised NumPy instead of a per-edge Python loop:
    # drop self-loops, orient each pair as (min, max), then deduplicate.
    edges = edges[edges[:, 0] != edges[:, 1]]
    edges = np.sort(edges, axis=1)
    if len(edges) > 0:
        edges = np.unique(edges, axis=0)

    n = len(graph.nodes())
    m = len(edges)
    # ORCA input format: header "n m", then one "u v" line per edge.
    lines = [f'{n} {m}\n']
    lines.extend(f'{u} {v}\n' for u, v in edges.tolist())

    orca_input_string = ''.join(lines)
    orca_output_string = orcastr.motif_counts_str('node', max_graphlet_size, orca_input_string)
    graphlet_degree_vectors = [[int(num) for num in line.split()] for line in orca_output_string.splitlines()]
    graphlet_degree_vectors = torch.tensor(graphlet_degree_vectors)

    return graphlet_degree_vectors


GRAPHLET_COUNT_BOUNDS = (
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 21, 25, 30, 35, 40, 50, 75, 100, 125, 150, 200, 250, 300, 400,
    500, 750, 1000, 1250, 1500, 2000, 2500, 3500, 5000, 7500, 10000, 15000, 20000, 25000, 32000, 40000, 50000,
    70000, 100000, 150000, 200000, 250000,
)


def transform_graphlet_degree_vectors_to_binary_features(graphlet_degree_vectors, bounds=None):
    """One-hot encode every graphlet orbit count into count bins.

    For orbit column ``i`` and bound index ``j``:

    - bin ``j`` covers ``bounds[j] <= count < bounds[j + 1]``;
    - the last bin covers ``count >= bounds[-1]``;
    - counts below ``bounds[0]`` fall into no bin.

    Args:
        graphlet_degree_vectors: 2-D tensor of shape ``(num_nodes, num_orbits)``.
        bounds: bin lower bounds; defaults to ``GRAPHLET_COUNT_BOUNDS``.

    Returns:
        float32 tensor of shape ``(num_nodes, num_orbits * len(bounds))``. Column
        ``i * len(bounds) + j`` is the indicator of bin ``j`` for orbit ``i``.

    Raises:
        ValueError: if ``bounds`` is empty.
    """
    if bounds is None:
        bounds = GRAPHLET_COUNT_BOUNDS
    if len(bounds) == 0:
        raise ValueError('bounds should contain at least one value.')

    num_nodes, num_orbits = graphlet_degree_vectors.shape
    boundaries = torch.as_tensor(bounds, device=graphlet_degree_vectors.device)

    # at_least[n, i, j] is True iff the count of orbit i at node n is >= bounds[j].
    at_least = graphlet_degree_vectors.unsqueeze(-1) >= boundaries
    # Bin j additionally requires count < bounds[j + 1]; the last bin is open-ended.
    below_next = torch.ones_like(at_least)
    below_next[..., :-1] = ~at_least[..., 1:]
    in_bin = at_least & below_next

    # Explicit target shape so that num_nodes == 0 or num_orbits == 0 still reshapes.
    return in_bin.reshape(num_nodes, num_orbits * len(bounds)).float()


def dgl_to_csr_matrix(graph):
    """Build a SciPy CSR adjacency matrix of a DGL graph, without self-loops.

    Args:
        graph: DGL graph; its edge tensors may live on any device.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(n, n)`` with a 1 for every edge
        ``(u, v)``. SciPy sums duplicate entries, so a pair joined by ``k``
        parallel edges gets the value ``k``.
    """
    graph = dgl.remove_self_loop(graph)

    n = graph.num_nodes()
    # Move ids to host memory explicitly: NumPy/SciPy cannot read CUDA tensors.
    row_ids, col_ids = (ids.cpu().numpy() for ids in graph.edges())
    values = np.ones_like(row_ids)
    adj_matrix = csr_matrix((values, (row_ids, col_ids)), shape=(n, n))

    return adj_matrix


def compute_spectral_embeddings(graph, dim=128):
    """Compute spectral node embeddings with a Julia routine.

    Loads ``spectral_embeddings.jl`` into Julia's ``Main`` module and calls
    its ``compute_spectral_embeddings`` on the graph's CSR adjacency matrix.
    The path is relative; presumably it is resolved against the working
    directory, so confirm where callers run from.

    Args:
        graph: DGL graph.
        dim: embedding dimension passed to the Julia function.

    Returns:
        float32 tensor built from the Julia result.
    """
    from julia.api import Julia
    # Bring up the embedded Julia runtime before importing julia.Main.
    _julia_runtime = Julia(compiled_modules=False)
    from julia import Main

    Main.include('spectral_embeddings.jl')
    julia_embeddings = Main.compute_spectral_embeddings(dgl_to_csr_matrix(graph), dim)

    return torch.tensor(julia_embeddings).float()
