import numpy as np

# Small imaginary offset: every Stieltjes transform in this file is evaluated
# at points of the form x + i*eta
eta = 1e-4

# Fixed-point iterations stop once the relative change ||s - s_prev|| / ||s||
# drops below tol, or after at most maxiter sweeps per attempt
tol, maxiter = 1e-3, 2000

def compute_t(L,alpha,beta,z,w,s,b,xgrid,spec=None,dgrid=None):
  """Evaluate the nested rational transform t_k, where k = len(alpha) - 1.

  Base case (k = 0): the average of (beta[0]*x - w) / (alpha[0]*x - z) over
  the input spectrum. The spectrum is given either as a sample `spec` (plain
  mean) or as density values `dgrid` on `xgrid` (Riemann sum). `spec` takes
  precedence when both are given.

  Recursive case (k > 0): the last coefficient pair is eliminated.
  ratio = beta[-1]/alpha[-1] is split off, and s[-1] is folded into the
  remaining coefficients and into (z, w). The function then returns
  ratio + t_{k-1}.

  Args:
    L: not used by this function; passed through unchanged to recursive calls.
    alpha, beta: coefficient sequences of equal length k+1.
    z, w: complex scalars.
    s: sequence of length k; s[-1] enters the reduction step.
    b: scalar; b**2 enters the reduction step.
    xgrid: grid on which `dgrid` is sampled. The Riemann sum assumes uniform
      spacing and at least two points.
    spec: sample of input-spectrum values.
    dgrid: density values of the input spectrum at the points of `xgrid`.

  Returns:
    Complex scalar t_k.

  Raises:
    ValueError: if no spectrum is given, `dgrid` and `xgrid` differ in
      length, or the lengths of alpha, beta and s are inconsistent.
  """
  if dgrid is None and spec is None:
    raise ValueError('Missing input spectrum')
  if dgrid is not None and len(dgrid) != len(xgrid):
    raise ValueError('Incorrect representation of input spectrum')
  if len(alpha) != len(beta) or len(alpha) != len(s)+1:
    raise ValueError('Incorrect inputs alpha, beta, s')
  # Base case t_0: average over the input spectrum
  if len(alpha) == 1:
    if spec is not None:
      spec = np.asarray(spec)
      return np.mean((beta[0]*spec-w)/(alpha[0]*spec-z))
    xgrid = np.asarray(xgrid)
    # Riemann sum against the density; the step is taken from the first two grid points
    return np.sum(np.asarray(dgrid)*(beta[0]*xgrid-w)/(alpha[0]*xgrid-z))*(xgrid[1]-xgrid[0])
  # Reduction t_k -> t_{k-1}
  ratio = beta[-1]/alpha[-1]
  # Force complex dtype: b**2/s[-1] is complex in general, and adding it into
  # a real-valued copy of alpha would silently discard its imaginary part.
  alphanew = np.array(alpha[:-1],dtype=complex)
  alphanew[-1] += b**2/s[-1]
  znew = z-(1-b**2)/s[-1]
  betanew = np.asarray(beta[:-1])-ratio*np.asarray(alpha[:-1])
  wnew = w-ratio*z
  return ratio + compute_t(L,alphanew,betanew,znew,wnew,s[:-1],b,xgrid,spec=spec,dgrid=dgrid)

def _random_s(L):
  # Random complex starting point for s_1..s_L with non-negative imaginary
  # parts (two draws from numpy's global RNG, real part first).
  return np.random.normal(size=L)+1j*np.abs(np.random.normal(size=L))

def _fixed_point_map(L,alpha,z,gamma,b,sprev,xgrid,spec=None,dgrid=None):
  """Apply one sweep of the fixed-point map to `sprev` and return the new s.

  Levels are processed from l = L down to 1. Descending to level l adds
  b**2/sprev[l-1] to coefficient l-1 and subtracts (1-b**2)/sprev[l-1] from
  z; these updates accumulate across levels. Each new s[l-1] depends only on
  `sprev`, never on other new entries. `alpha` is not modified.
  """
  s = np.zeros(L,dtype=complex)
  alphacur = np.array(alpha,dtype=complex)  # private, accumulating copy
  zcur = z
  w = b**2-1
  for l in range(L,0,-1):
    alphacur[l-1] += b**2/sprev[l-1]
    zcur = zcur-(1-b**2)/sprev[l-1]
    beta = np.zeros(l,dtype=complex)
    beta[l-1] = b**2
    t = compute_t(L,alphacur[:l],beta,zcur,w,sprev[:(l-1)],b,xgrid,spec=spec,dgrid=dgrid)
    s[l-1] = 1/alphacur[l]+gamma[l-1]*t
  return s

def compute_spectrum(L,alpha,zbase,gamma,b,xgrid,spec=None,dgrid=None,verbose=True,seed=123):
  """Approximate a spectral density on `xgrid` by Stieltjes inversion.

  For every grid point x, processed from the last to the first, set
  z = x + i*eta - zbase and solve the fixed-point equations for s_1..s_L by
  iteration. The first point starts from a random guess; later points
  warm-start from the neighbouring solution(s), with linear extrapolation
  once two are known. A solution with a non-positive imaginary part in any
  entry is discarded and the iteration restarts from a random point. There
  is no cap on the number of restarts.

  Afterwards m(z) = compute_t(..., beta = 0, w = -1) is evaluated with the
  solved s at each point, and (1/pi) * Im m is returned.

  Args:
    L: number of unknowns s_1..s_L.
    alpha: coefficient sequence of length L+1; converted to complex.
    zbase: shift subtracted from every evaluation point x + i*eta.
    gamma: sequence with at least L entries; gamma[l-1] multiplies the t-term
      in the equation for s_l.
    b: scalar entering the fixed-point map through b**2.
    xgrid: evaluation points. Also the integration grid when `dgrid` is used.
    spec, dgrid: input spectrum, as accepted by compute_t.
    verbose: print progress and non-convergence warnings.
    seed: seed for numpy's global RNG (np.random.seed), used for the random
      starting points.

  Returns:
    Real numpy array of length len(xgrid).
  """
  np.random.seed(seed)
  # Complex dtype so that no imaginary part is lost in downstream updates
  alpha = np.asarray(alpha,dtype=complex)
  n = len(xgrid)
  # Row i holds the solved s_1..s_L at z = xgrid[i] + i*eta - zbase
  sgrid = np.zeros((n,L),dtype=complex)
  for (i,x) in reversed(list(enumerate(xgrid))):
    if verbose:
      print('Computing density at x = %f' % x)
    z = x+1j*eta-zbase
    # Initialization for s(z); linear extrapolation speeds up convergence
    if i == n-1:
      s = _random_s(L)
    elif i == n-2:
      # Copy: a view would let later updates overwrite the solution in sgrid[i+1]
      s = np.copy(sgrid[i+1,:])
    else:
      s = sgrid[i+1,:]+(sgrid[i+1,:]-sgrid[i+2,:])
    while True:
      converged = False
      for _ in range(maxiter):
        sprev = s
        s = _fixed_point_map(L,alpha,z,gamma,b,sprev,xgrid,spec=spec,dgrid=dgrid)
        # Check after every update, including the last one
        if np.linalg.norm(s-sprev)/np.linalg.norm(s) < tol:
          converged = True
          break
      if min(np.imag(s)) > 0:
        if verbose and not converged:
          print('  Warning: Did not converge to desired relative error')
        break
      # Some Im s_l <= 0: discard and restart from a random point
      s = _random_s(L)
    sgrid[i,:] = s
  # Stieltjes transform m(z) with beta = 0 and w = -1
  mgrid = np.zeros(n,dtype=complex)
  beta = np.zeros(L+1,dtype=complex)
  w = -1
  for (i,x) in enumerate(xgrid):
    z = x+1j*eta-zbase
    mgrid[i] = compute_t(L,alpha,beta,z,w,sgrid[i,:],b,xgrid,spec=spec,dgrid=dgrid)
  # Get density from Stieltjes inversion
  return 1/np.pi*np.imag(mgrid)

def CK(L,gamma,b,xgrid,spec=None,dgrid=None,verbose=True,seed=123):
  """Spectral density for the CK model on `xgrid` (presumably the conjugate kernel).

  Uses alpha = (0, ..., 0, 1) of length L+1 and no shift (zbase = 0), then
  delegates to compute_spectrum, which documents the remaining parameters and
  the returned array.
  """
  # No separate RNG seeding here: compute_spectrum calls np.random.seed(seed)
  # as its first step, so reseeding beforehand had no effect.
  alpha = np.zeros(L+1,dtype=complex)
  alpha[-1] = 1
  zbase = 0
  return compute_spectrum(L,alpha,zbase,gamma,b,xgrid,spec=spec,dgrid=dgrid,verbose=verbose,seed=seed)

def NTK(L,gamma,a,b,xgrid,spec=None,dgrid=None,verbose=True,seed=123):
  """Spectral density for the NTK model on `xgrid` (name suggests the neural tangent kernel).

  Coefficients: alpha[l] = b**(2*(L-l)) for l < L and alpha[L] = 1.
  Shift: zbase = sum over l < L of a**(L-l) - b**(2*(L-l)), accumulated in
  order l = 0, 1, ..., L-1.
  Everything else is delegated to compute_spectrum.
  """
  alpha = np.ones(L+1,dtype=complex)
  zbase = 0
  for l,k in enumerate(range(L,0,-1)):  # k = L - l
    bpow = b**(2*k)
    alpha[l] = bpow
    zbase += a**k-bpow
  return compute_spectrum(L,alpha,zbase,gamma,b,xgrid,spec=spec,dgrid=dgrid,verbose=verbose,seed=seed)

