import os
import sys
import tempfile

import igraph as ig
import louvain
import networkx as nx
import numpy as np
import scipy
import scipy.sparse
import scipy.sparse.linalg
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples, silhouette_score

from functions_for_DDCSBM import *
from static_SC import *




def dynamic_A(AT, n_clusters, real_classes):
    
    ''' Function to perform community detection using the averaged adjacency matrix
    
    At every time t the adjacency matrices AT[0], ..., AT[t] are summed (the sum has the same
    eigenvectors as their average), the n_clusters eigenvectors associated with the largest
    algebraic eigenvalues are computed and the rows of this embedding are clustered with k-means.
    
    Use : ov, mod = dynamic_A(AT, n_clusters, real_classes)
    
    Output : ov (array of size T) : ov[t] is the overlap of the estimated partition w.r.t the label vector at time t
                                    (left at 0 when the ground truth is not available)
             mod (array of size T) : mod[t] is the modularity of the estimated partition evaluated on G_t
    
    Input  : AT (set of sparse matrices) : AT[t] is the sparse adjacency matrix of G_t
           : n_clusters (scalar) : number of communities
           : real_classes (sequence of arrays) : real_classes[t] is the label vector at time t;
                                                 if real_classes[0] contains None the overlap is not computed
           
    '''
    
    T = len(AT) # number of time frames
    ov = np.zeros(T) 
    mod = np.zeros(T)
    # the ground truth is considered known only if the first label vector has no None entry;
    # the T > 0 guard keeps the empty-input case returning empty arrays
    has_labels = T > 0 and np.sum(real_classes[0] == None) == 0
    A = None
    for i in range(T):
        # out-of-place sum: `A += AT[i]` on A = AT[0] would modify AT[0] itself for dense inputs
        A = AT[0] if i == 0 else A + AT[i] # cumulative adjacency matrix up to time i
        _, X = scipy.sparse.linalg.eigsh(A.astype(float), k = n_clusters, which = 'LA') # eigenvectors of the largest eigenvalues
        kmeans = KMeans(n_clusters = n_clusters) # perform kmeans 
        kmeans.fit(X)
        estimated_labels = kmeans.predict(X)
        if has_labels:
            ov[i] = overlap(real_classes[i], estimated_labels)
            
        # modularity of the partition on the current snapshot only
        mod[i] = compute_modularity(AT[i], estimated_labels)

    return ov, mod


##############################################################################################################################


def dynamic_louvain(AT, Label,strength):
    
    ''' Function that finds the partition of the set of adjacency matrices AT according to the dynamic Louvain method
    
    Each G_t is converted to an igraph graph (through a graphml file written in a private temporary
    directory) and the T graphs are partitioned jointly by louvain.find_partition_temporal.
    
    Use : ov, mod, n_clusters = dynamic_louvain(AT, Label, strength)
    
    Output : ov (array of size T) : ov[t] is the overlap of the estimated partition w.r.t the label vector at time t
                                    (left at 0 when the ground truth is not available)
           : mod (array of size T) : mod[t] is the modularity of the estimated partition at time t
           : n_clusters (scalar) : number of distinct clusters found by Louvain in the first time frame
    
    Input  : AT (set of sparse matrices) : AT[t] is the sparse adjacency matrix of G_t
           : strength (scalar) : weight of temporal connections
           : Label (sequence of arrays) : Label[t] is the label vector at time t;
                                          if Label[0] contains None the overlap is not computed'''
    
    T = len(AT)
    # networkx 3.0 removed from_scipy_sparse_matrix; from_scipy_sparse_array (networkx >= 2.7) replaces it
    sparse_to_nx = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    G = []
    # The graphml round trip is kept so the igraph graphs carry the attributes written by networkx
    # (NOTE(review): presumably the vertex 'id' attribute find_partition_temporal relies on — confirm).
    # A private temporary directory avoids overwriting, and leaving behind, ./graph.graphml.
    with tempfile.TemporaryDirectory() as tmp_dir:
        graph_file = os.path.join(tmp_dir, 'graph.graphml')
        for i in range(T):
            Gx = sparse_to_nx(AT[i].astype(float))
            nx.write_graphml(Gx, graph_file)
            G.append(ig.read(graph_file, format="graphml"))
        
    lp = louvain.find_partition_temporal(G, louvain.ModularityVertexPartition, interslice_weight=strength)
    memberships = [np.array(m) for m in lp[0]] # memberships[t] : estimated label vector at time t
    n_clusters = len(np.unique(memberships[0]))
    
    ov = np.zeros(T)
    mod = np.zeros(T)
    # the ground truth is considered known only if Label[0] has no None entry
    has_labels = np.sum(Label[0] == None) == 0
    for i in range(T):
        if has_labels:
            ov[i] = overlap(Label[i], memberships[i])
        mod[i] = compute_modularity(AT[i], memberships[i])

    return ov, mod, n_clusters


##############################################################################


def static_BH(AT, n_clusters, Label):
    
    ''' Static baseline: every snapshot G_t is clustered on its own with the static optimized
    Bethe-Hessian (community_detection), ignoring the other time frames
    
    Use : ov, mod = static_BH(AT, n_clusters, Label)
    
    Output : ov (array of size T) : ov[t] is the overlap of the partition of G_t w.r.t. Label[t]
                                    (stays 0 when Label[0] contains None)
             mod (array of size T) : mod[t] is the modularity of the partition of G_t
    
    Input  : AT (set of sparse matrices) : AT[t] is the sparse adjacency matrix of G_t
           : n_clusters (scalar) : number of clusters
           : Label (sequence of arrays) : Label[t] is the label vector at time t'''
    
    n_frames = len(AT)
    overlaps = np.zeros(n_frames)
    modularities = np.zeros(n_frames)
    
    for t in range(n_frames):
        result = community_detection(AT[t], real_classes = Label[t], n_clusters = n_clusters)
        # overlaps are only meaningful when the ground truth has no missing (None) entry
        labels_known = not np.any(Label[0] == None)
        if labels_known:
            overlaps[t] = result.overlap
        modularities[t] = result.modularity
    
    return overlaps, modularities



##############################################################################

def dynamic_B(AT, lamb, eta, n_clusters, real_classes):
    
    ''' Function that performs dynamic spectral clustering according to Ghasemian 2016
    
    The 4Tn x 4Tn matrix B is assembled from the spatial (As, Ds) and temporal (At, Dt)
    adjacency/degree matrices; the n_clusters eigenvectors of B with largest real part are
    computed and, for every time t, the rows t*n .. (t+1)*n - 1 of their real part are
    clustered with k-means.
    
    Use : ov, mod = dynamic_B(AT, lamb, eta, n_clusters, real_classes)
    
    
    Output : ov (array of size T) : ov[t] is the overlap with respect to the ground truth at time t 
                                    (left at 0 when the ground truth is not available)
           : mod (array of size T) : mod[t] is the modularity of the estimated partition at time t    
    
    
    Input : AT (sequence of sparse arrays) : AT[t] is the adjacency matrix of the graph G[t]
          : lamb (scalar) : value of lambda
          : eta (scalar) : value of eta
          : n_clusters (scalar) : number of communities k 
          : real_classes (array) : real_classes[t] is the ground truth vector at time t;
                                   if real_classes[0] contains None the overlap is not computed
          
    
    '''
    
    T = len(AT)
    # degree vector of each snapshot; ravel() gives a 1-D vector whether the column sums come
    # back as a 1 x n matrix (sparse matrices) or as a 1-D array (dense or sparse arrays),
    # where the former `[0]` indexing would have picked a single scalar
    dT = [np.asarray(np.sum(AT[i], axis = 0)).ravel() for i in range(T)]
    n = len(dT[0]) # number of nodes
    
    As = scipy.sparse.block_diag(AT) # spatial adjacency matrix diag(AT[0], ..., AT[T-1])
    Ds = scipy.sparse.diags(np.concatenate(dT), offsets = 0) # spatial degree matrix
    I = scipy.sparse.diags(np.ones(n), offsets = 0)
    
    # temporal adjacency matrix: block (t, s) is I when |t - s| == 1 and zero otherwise.
    # The Kronecker form also yields a proper n x n zero matrix when T == 1, whereas a bmat
    # of all-None blocks produced an empty block and made the assembly of B fail.
    time_path = np.eye(T, k = 1) + np.eye(T, k = -1)
    At = scipy.sparse.kron(time_path, I)
    Dt = scipy.sparse.diags(np.asarray(At.sum(axis = 0)).ravel(), offsets = 0) # temporal degree matrix

    IT = scipy.sparse.diags(np.ones(T*n), offsets = 0)


    B = scipy.sparse.bmat([[As*lamb,-IT*lamb,As*lamb,None],
                          [(Ds-IT)*lamb, None, Ds*lamb, None],
                          [eta*At, None, eta*At, -eta*IT],
                          [eta*Dt, None, eta*(Dt-IT), None]])


    
    _, X = scipy.sparse.linalg.eigs(B, k = n_clusters, which = 'LR')
    # embedding of the nodes of G_t: rows t*n .. (t+1)*n - 1 of the eigenvectors
    Y = [X[i*n:(i+1)*n].real for i in range(T)]    

    kmeans = KMeans(n_clusters = n_clusters) # perform kmeans 
    ov = np.zeros(T)
    mod = np.zeros(T)
    # the ground truth is considered known only if the first label vector has no None entry
    has_labels = np.sum(real_classes[0] == None) == 0
    
    for i in range(T):
        kmeans.fit(Y[i])
        estimated_labels = kmeans.predict(Y[i])
        if has_labels:
            ov[i] = overlap(real_classes[i], estimated_labels)

        mod[i] = compute_modularity(AT[i], estimated_labels)
    
    return ov, mod
        