import numpy as np
import cvxpy as cp

def fastest_mixing_discrete(A):
    """Return the transition matrix of the fastest mixing discrete-time Markov chain.

    Solves the fastest-mixing Markov chain problem on the undirected graph
    whose adjacency matrix is ``A``. It finds a symmetric, row-stochastic
    matrix ``P`` that only allows transitions along edges of the graph and
    minimizes

        || P - (1/n) 11^T ||_2 .

    For a symmetric stochastic ``P`` this spectral norm is the second-largest
    eigenvalue modulus. Subtracting ``(1/n) 11^T`` sends the trivial
    eigenvalue 1 (eigenvector ``1``) to 0 and leaves all other eigenvalues
    unchanged. Self-loops (the diagonal of ``P``) are never restricted.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Adjacency matrix of the graph. An off-diagonal entry ``A[i, j] == 0``
        forbids the transition between i and j. Only the strict upper
        triangle is read, so ``A`` is assumed to be symmetric.

    Returns
    -------
    numpy.ndarray of shape (n, n), or None
        The optimal transition matrix (``P.value``). This is ``None`` if the
        solver did not produce a solution.
    """
    # Accept dense arrays, array-likes and scipy sparse matrices alike.
    A = A.toarray() if hasattr(A, "toarray") else np.asarray(A)
    n = A.shape[0]

    P = cp.Variable((n, n), symmetric=True)  # transition matrix
    s = cp.Variable()  # upper bound on the second-largest eigenvalue modulus

    eye = np.eye(n)
    avg = np.ones((n, n)) / n  # orthogonal projector onto span{1}

    # Non-edges strictly above the diagonal. P is symmetric, so fixing the
    # upper triangle to zero also fixes the mirrored lower-triangle entries.
    forbidden = np.triu(A == 0, k=1).astype(float)

    constraints = [
        P >= 0,
        cp.sum(P, axis=1) == 1,          # each row is a probability vector
        cp.multiply(forbidden, P) == 0,  # no transitions along non-edges
        # -s I <= P - J/n <= s I  <=>  ||P - J/n||_2 <= s
        s * eye >> P - avg,
        P - avg >> -s * eye,
    ]

    prob = cp.Problem(cp.Minimize(s), constraints)
    prob.solve()

    return P.value

def fastest_mixing_continuous(A):
    """Return the edge weights of the fastest mixing continuous-time Markov chain.

    Finds symmetric, nonnegative edge weights ``W`` on the undirected graph
    with adjacency matrix ``A``. The weights maximize the algebraic
    connectivity ``lambda_2(L)`` of the weighted Laplacian
    ``L = diag(W 1) - W``, subject to an average leave rate of one. Each
    node's leave rate is its row sum of ``W``, so the constraint is
    ``sum(W) == n``.

    ``lambda_2(L) >= s`` is expressed exactly by the linear matrix inequality
    ``L >= s (I - (1/n) 11^T)``. Since ``L 1 = 0``, this constrains ``L`` only
    on the complement of ``1``. (Shifting by ``(1/n) 11^T`` instead would give
    the trivial eigenvector the eigenvalue 1 and cap the objective at 1, while
    the optimal ``lambda_2`` can be as large as ``n / (n - 1)``.)

    Parameters
    ----------
    A : array_like, shape (n, n)
        Adjacency matrix of the graph. An off-diagonal entry ``A[i, j] == 0``
        forces ``W[i, j] = W[j, i] = 0``. Only the strict upper triangle is
        read, so ``A`` is assumed to be symmetric. The diagonal of ``W`` is
        always zero, because self-loops carry no rate in a continuous-time
        chain.

    Returns
    -------
    numpy.ndarray of shape (n, n), or None
        The optimal symmetric weight matrix (``W.value``). This is ``None`` if
        the solver did not produce a solution, e.g. when the graph has no
        edges and the budget ``sum(W) == n`` is infeasible.
    """
    # Accept dense arrays, array-likes and scipy sparse matrices alike.
    A = A.toarray() if hasattr(A, "toarray") else np.asarray(A)
    n = A.shape[0]

    W = cp.Variable((n, n), symmetric=True)  # edge weights (transition rates)
    L = cp.Variable((n, n), symmetric=True)  # weighted Laplacian of W
    s = cp.Variable()  # lower bound on lambda_2(L)

    eye = np.eye(n)
    ones_mat = np.ones((n, n))

    # Zero pattern of W: non-edges in the strict upper triangle (symmetry of W
    # mirrors them) plus the whole diagonal.
    forbidden = (np.triu(A == 0, k=1) | np.eye(n, dtype=bool)).astype(float)

    constraints = [
        W >= 0,
        cp.sum(W) == n,                  # average leave rate equals one
        cp.multiply(forbidden, W) == 0,  # weights only on real edges
        # Off-diagonal L[i, j] = -W[i, j]; diagonal L[i, i] = sum_j W[i, j]
        # (W has a zero diagonal, so this is a proper graph Laplacian).
        L == cp.diag(cp.sum(W, axis=1)) - W,
        # lambda_2(L) >= s, measured on the subspace orthogonal to 1.
        L >> s * (eye - ones_mat / n),
    ]

    prob = cp.Problem(cp.Maximize(s), constraints)
    prob.solve()

    return W.value
    