import torch
import numpy as np
import root_finding
from tqdm import tqdm
import time

# Registry of the torch optimizers that the dual-ascent solvers in this file
# can select by name through their `optimizer` argument.
OPTIMIZERS = dict(
    SGD=torch.optim.SGD,
    Adam=torch.optim.Adam,
    NAdam=torch.optim.NAdam,
    LBFGS=torch.optim.LBFGS,
)

def entropy(P: torch.Tensor,
            log: bool = False,
            ax: int = -1):
    """
        Returns the entropy of P along axis ax, supports log domain input.

        The value computed is -sum(P * (log(P) - 1)) along ax, using the
        convention 0 * log(0) = 0: exact zeros (or -inf entries when P is
        given in log domain) contribute nothing instead of producing NaN.

        Parameters
        ----------
        P: array (n,n)
            input data
        log: bool
            if True, assumes that P is in log domain
        ax: int
            axis on which entropy is computed

        Returns
        -------
        torch.Tensor
            entropy values, with axis ax summed out
    """
    if log:
        # exp(-inf) * (-inf - 1) is 0 * -inf = NaN. Zero out the second factor
        # wherever the log-probability is -inf so the product is exactly 0.
        shifted = P - 1
        neg_inf_mask = P == float('-inf')
        shifted = torch.where(neg_inf_mask, torch.zeros_like(shifted), shifted)
        return -(torch.exp(P) * shifted).sum(ax)

    # xlogy(P, P) is P * log(P) with xlogy(0, 0) defined as 0.
    return -(torch.xlogy(P, P) - P).sum(ax)


def KL(P: torch.Tensor,
        K: torch.Tensor,
        log: bool=False):
    """
        Returns the Kullback-Leibler divergence between P and K, supports log domain input for both matrices.

        The value computed is sum(P * (log(P / K) - 1)) over all entries,
        using the convention 0 * log(0) = 0: entries where P is exactly zero
        (or -inf in log domain) contribute nothing instead of producing NaN.

        Parameters
        ----------
        P: array
            input data
        K: array
            input data
        log: bool
            if True, assumes that P and K are in log domain

        Returns
        -------
        torch.Tensor
            scalar tensor holding the divergence value
    """
    if log:
        # exp(-inf) * (-inf - K - 1) is 0 * -inf = NaN. Zero out the second
        # factor wherever P is -inf so the product is exactly 0.
        log_ratio = P - K - 1
        neg_inf_mask = P == float('-inf')
        log_ratio = torch.where(neg_inf_mask, torch.zeros_like(log_ratio), log_ratio)
        return (torch.exp(P) * log_ratio).sum()

    # xlogy(P, P / K) is P * log(P / K), defined as 0 wherever P == 0.
    return (torch.xlogy(P, P / K) - P).sum()


# ----- Entropic Affinity -----


def log_Pe(C: torch.Tensor,
           eps: torch.Tensor):
    """
        Computes the log of the directed (row-normalized) affinity matrix of SNE.

        Row i of the cost is divided by eps[i], negated, and then normalized
        with a log-sum-exp over the last axis, so that the exponential of
        each returned row sums to one.

        Parameters
        ----------
        C: array (n,n)
            distance matrix
        eps: array (n)
            kernel bandwidths vector

        Returns
        -------
        torch.Tensor
            log affinity matrix, same shape as C
    """
    # One bandwidth per row: broadcast eps as a column vector.
    scaled_cost = -C / eps.unsqueeze(1)
    row_normalizer = scaled_cost.logsumexp(dim=-1, keepdim=True)
    return scaled_cost - row_normalizer


def log_e_affinity(C: torch.Tensor,
               perp: int = 30,
               proj_KL: bool = False,
               tol: float = 1e-5,
               max_iter: int = 1000,
               verbose: bool = True,
               begin: torch.Tensor = None,
               end: torch.Tensor = None):
    """
        Solves the dual problem of entropic affinities with a root search.
        Returns the entropic affinity matrix in log domain, together with the
        two bound vectors handed back by the root finder.

        Parameters
        ----------
        C: array (n,p)
            distance matrix
        perp: int
            value of the perplexity parameter K
        proj_KL: bool
            specifies if entropic affinity (False) or KL projection (True)
        tol: float
            precision threshold at which the algorithm stops
        max_iter: int
            maximum iterations of search
        verbose: bool
            if True, prints current mean and std entropy values and current bounds
        begin: array or None
            forwarded as-is to the root finder (presumably lower search bounds;
            confirm against root_finding)
        end: array or None
            forwarded as-is to the root finder (presumably upper search bounds;
            confirm against root_finding)

        Returns
        -------
        tuple
            (log affinity matrix, begin, end) where begin and end are the
            values returned by the root finder
    """
    n = C.shape[0]
    target_entropy = np.log(perp) + 1

    def entropy_gap(eps):
        # Row-wise entropy of the affinity induced by eps, minus the target.
        return entropy(log_Pe(C, eps), log=True) - target_entropy

    # The KL-projection variant uses a dedicated solver from root_finding.
    if proj_KL:
        solver = root_finding.false_position_lower_bound_one
    else:
        solver = root_finding.false_position

    eps_star, begin, end = solver(
        f=entropy_gap,
        n=n,
        begin=begin,
        end=end,
        tol=tol,
        max_iter=max_iter,
        verbose=verbose,
    )

    return log_Pe(C, eps_star), begin, end


def SNE_affinity(C: torch.Tensor,
               perp: int = 30,
               tol: float = 1e-5,
               max_iter: int = 10000,
               verbose: bool = True):
    """
        Builds the affinity matrix of SNE / t-SNE.

        The directed entropic affinity is computed in log domain and then
        symmetrized as (P + P^T) / 2, also in log domain, before being
        exponentiated.

        Parameters
        ----------
        C: array (n,n)
            cost matrix
        perp: int
            value of the perplexity parameter K
        tol: float
            precision threshold at which the algorithm stops
        max_iter: int
            maximum iterations of search
        verbose: bool
            if True, prints current mean and std entropy values and current bounds

        Returns
        -------
        torch.Tensor
            symmetrized affinity matrix (not in log domain)
    """
    if verbose:
        print('---------- Computing the Affinity Matrix ----------')

    directed_log_P, _, _ = log_e_affinity(
        C=C, perp=perp, tol=tol, max_iter=max_iter, verbose=verbose)

    # log((P + P^T) / 2) = logsumexp([log P, log P^T]) - log(2)
    both_directions = torch.stack([directed_log_P, directed_log_P.T], 0)
    symmetric_log_P = torch.logsumexp(both_directions, 0, keepdim=False) - np.log(2)
    return torch.exp(symmetric_log_P)


# ----- Symmetric Entropic Affinity -----


def se_affinity(C: torch.Tensor,
                perp: int,
                scaling: float = 1e0,
                tol: float = 1e-5,
                max_iter: int = 2000,
                verbose: bool = True,
                tolog: bool = False,
                max_iter_eaffinity: int = 100):
    """
        Performs alternating Bregman projections to compute symmetric entropic affinities.
        Returns the symmetric entropic affinity matrix.

        Parameters
        ----------
        C: array (n,n)
            symmetric cost matrix
        perp: int
            value of the perplexity parameter K
        scaling: float
            the initial log affinity is set to -C / scaling
        tol: float
            precision threshold at which the algorithm stops
        max_iter: int
            maximum number of alternating-projection iterations
        verbose: bool
            if True, shows a progress bar with current mean and std of the
            perplexities and of the marginal sums
        tolog: bool
            if True, log and returns intermediate variables
        max_iter_eaffinity: int
            max_iter forwarded to log_e_affinity at each outer iteration

        Returns
        -------
        torch.Tensor, or (torch.Tensor, dict) if tolog is True
            the affinity matrix exp(log_Ps); the dict holds the per-iteration
            'loss' values (its 'log_P' list is created but left empty)
    """
    n = C.shape[0]
    assert 1 <= perp <= n

    # Loop invariants used by the stopping criterion.
    target_entropy = np.log(perp) + 1
    one = torch.ones(n, dtype=torch.double)

    begin = torch.ones(n, dtype=torch.double)
    end = torch.ones(n, dtype=torch.double)

    # Initialize Dykstra variable for entropic constraint.
    # It is kept in double like the other state tensors: with the default
    # float32 dtype the in-place update below would silently downcast every
    # double-precision increment, which is incompatible with the 1e-9
    # tolerance requested from the inner solver.
    log_Q = torch.zeros((n, n), dtype=torch.double)

    # Scale the input cost
    log_Ps = - C / scaling

    if verbose:
        print('---------- Computing the Affinity Matrix ----------')

    if tolog:
        log = {}
        log['log_P'] = []
        log['loss'] = []

    pbar = tqdm(range(max_iter), disable=not verbose)
    for k in pbar:
        # Projection KL onto entropic constraint + stochasticity constraint
        log_Ph, begin, end = log_e_affinity(
            C=-log_Ps-log_Q, perp=perp, tol=1e-9, max_iter=max_iter_eaffinity,
            verbose=False, proj_KL=True, begin=begin, end=end)

        # Update Dykstra variable for entropic constraint
        log_Q += log_Ps - log_Ph

        # Projection KL onto the set of symmetric matrices
        log_Ps = 0.5 * (log_Ph + log_Ph.T)

        if tolog:
            # NOTE: intermediate log_P matrices are intentionally not stored.
            log['loss'].append(KL(log_Ph, -C, log=True).item())

        # Column sums of the row-stochastic projection, and row entropies of
        # its symmetrized version.
        P_sum = torch.exp(torch.logsumexp(log_Ph, 0, keepdim=False))
        H = entropy(log_Ps, log=True)
        perps = torch.exp(H-1)

        if verbose:
            pbar.set_description(
                f'perps mean : {float(perps.mean().item()): .3e}, '
                f'perps std : {float(perps.std().item()): .3e}, '
                f'marginal sum : {float(P_sum.mean().item()): .3e}, '
                f'marginal std : {float(P_sum.std().item()): .3e}, ')

        entropy_ok = (torch.abs(H - target_entropy) < tol).all()
        marginal_ok = (torch.abs(P_sum - one) < tol).all()
        symmetric_ok = (torch.abs(log_Ph - log_Ph.T) < tol).all()
        if entropy_ok and marginal_ok and symmetric_ok:
            if verbose:
                print(f'breaking at iter {k+1}')
            break

        if k == max_iter-1 and verbose:
            print('---------- Max iter attained ----------')

    if tolog:
        return torch.exp(log_Ps), log
    else:
        return torch.exp(log_Ps)
    

def log_Pse(C: torch.Tensor,
            eps: torch.Tensor,
            mu: torch.Tensor):
    """
        Computes the log of the symmetric entropic affinity matrix for given dual variables epsilon and mu.

        Entry (i, j) equals (mu[i] + mu[j] - 2 * C[i, j]) / (eps[i] + eps[j]).

        Parameters
        ----------
        C: array (n,n)
            distance matrix
        eps: array (n)
            symmetric entropic affinity dual variables associated to the entropy constraint
        mu: array (n)
            symmetric entropic affinity dual variables associated to the marginal constraint

        Returns
        -------
        torch.Tensor
            log affinity matrix, same shape as C
    """
    # Outer sums mu_i + mu_j and eps_i + eps_j via broadcasting.
    pairwise_mu = mu.unsqueeze(1) + mu.unsqueeze(0)
    pairwise_eps = eps.unsqueeze(1) + eps.unsqueeze(0)
    return (pairwise_mu - 2 * C) / pairwise_eps


# Alternative method to compute symmetric entropic affinities.

def se_affinity_dual_ascent(C: torch.Tensor,
                perp: int,
                lr: float = 1e0,
                tol: float = 1e-3,
                max_iter: int = 10000,
                optimizer: str = 'Adam',
                verbose: bool = True,
                tolog: bool = False,
                use_scheduler: bool = False):
    """
        Performs dual ascent to compute symmetric entropic affinities.
        Returns the symmetric entropic affinity matrix.

        Parameters
        ----------
        C: array (n,n)
            symmetric cost matrix
        perp: int
            value of the perplexity parameter K
        lr: float
            learning rate used for gradient ascent
        tol: float
            precision threshold at which the algorithm stops
        max_iter: int
            maximum number of dual ascent iterations
        optimizer: str
            key of OPTIMIZERS selecting which pytorch optimizer to use.
            The optimizer is stepped without a closure, so optimizers that
            require one (LBFGS) are not usable here.
        verbose: bool
            if True, shows a progress bar with current mean and std of the
            perplexities and of the marginal sums
        tolog: bool
            if True, log and returns intermediate variables
        use_scheduler: bool
            if True, a ReduceLROnPlateau scheduler (factor 0.9) is stepped
            after iteration 100 on a constraint-violation score

        Returns
        -------
        torch.Tensor, or (torch.Tensor, dict) if tolog is True
            the affinity matrix exp(log_P); the dict holds the history of
            'eps', 'mu', 'loss', 'time_to_k', the initial 'log_P' and
            'total_time'

        Raises
        ------
        Exception
            if eps or mu contains NaN after an optimizer step
    """
    st = time.time()
    n = C.shape[0]
    assert 1 <= perp <= n
    target_entropy = np.log(perp) + 1
    eps = torch.ones(n, dtype=torch.double)
    mu = torch.zeros(n, dtype=torch.double)
    log_P = log_Pse(C, eps, mu)

    optimizer = OPTIMIZERS[optimizer]([eps, mu], lr=lr)
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9)

    if tolog:
        log = {}
        log['eps'] = [eps.clone().detach()]
        log['mu'] = [mu.clone().detach()]
        log['log_P'] = []
        log['log_P'].append(log_P)
        log['loss'] = []
        log['time_to_k'] = []
        log['time_to_k'].append(0)

    if verbose:
        print('---------- Computing the Affinity Matrix ----------')

    one = torch.ones(n, dtype=torch.double)
    # Only show the progress bar when verbose, as se_affinity does.
    pbar = tqdm(range(max_iter), disable=not verbose)
    st0 = time.time()
    for k in pbar:
        with torch.no_grad():
            # The dual gradients are known in closed form (constraint
            # violations), so they are written to .grad directly.
            optimizer.zero_grad()
            H = entropy(log_P, log=True)
            eps.grad = H - target_entropy
            P_sum = torch.exp(torch.logsumexp(log_P, -1, keepdim=False))
            mu.grad = P_sum - one
            optimizer.step()
            eps.clamp_(min=0)
            log_P = log_Pse(C, eps, mu)

            if torch.isnan(eps).any() or torch.isnan(mu).any():
                raise Exception(f'NaN in variables at iteration {k}')

            if tolog:
                log['time_to_k'].append(time.time()-st0)
                # NOTE: intermediate log_P matrices are intentionally not stored.
                log['eps'].append(eps.clone().detach())
                log['mu'].append(mu.clone().detach())
                # Lagrangian expects the affinity in log domain, so log_P is
                # passed as-is (not exponentiated).
                log['loss'].append(-Lagrangian(C, log_P, eps, mu, perp).item())

            # H and P_sum were measured before this iteration's step.
            perps = torch.exp(H-1)
            if verbose:
                pbar.set_description(
                    f'perps mean : {float(perps.mean().item()): .3e}, '
                    f'perps std : {float(perps.std().item()): .3e}, '
                    f'marginal sum : {float(P_sum.mean().item()): .3e}, '
                    f'marginal std : {float(P_sum.std().item()): .3e}, ')

            if (torch.abs(H - target_entropy) < tol).all() and (torch.abs(P_sum - one) < tol).all():
                if verbose:
                    print(f'breaking at iter {k}')
                break

            if k == max_iter-1 and verbose:
                print('---------- Max iter attained ----------')

            if k>100 and use_scheduler:
                scheduler_loss = (perps.mean() - perp)**2 + perps.std() + (P_sum.mean() - 1)**2 + P_sum.std()
                scheduler.step(scheduler_loss)

    ed = time.time()
    if tolog:
        log['total_time'] = ed - st
        return torch.exp(log_P), log
    else:
        return torch.exp(log_P)



# Stabilized variant: the entropy dual variable enters through eps**2, so its effective value stays non-negative.

def log_Pse2(C: torch.Tensor,
            eps: torch.Tensor,
            mu: torch.Tensor):
    """
        Computes the log of the symmetric entropic affinity matrix for given parameters epsilon and mu,
        where epsilon enters through its square.

        Entry (i, j) equals (mu[i] + mu[j] - 2 * C[i, j]) / (eps[i]**2 + eps[j]**2).

        Parameters
        ----------
        C: array (n,n)
            distance matrix
        eps: array (n)
            symmetric entropic affinity dual variables associated to the entropy constraint
        mu: array (n)
            symmetric entropic affinity dual variables associated to the marginal constraint

        Returns
        -------
        torch.Tensor
            log affinity matrix, same shape as C
    """
    eps_squared = eps ** 2
    # Outer sums mu_i + mu_j and eps_i^2 + eps_j^2 via broadcasting.
    pairwise_mu = mu.unsqueeze(1) + mu.unsqueeze(0)
    pairwise_eps_squared = eps_squared.unsqueeze(1) + eps_squared.unsqueeze(0)
    return (pairwise_mu - 2 * C) / pairwise_eps_squared


def se_affinity_dual_ascent2(C: torch.Tensor,
                perp: int,
                lr: float = 1e0,
                tol: float = 1e-3,
                max_iter: int = 10000,
                optimizer: str = 'Adam',
                verbose: bool = True,
                tolog: bool = False,
                rho: float = 1.0):
    """
        Performs dual ascent to compute symmetric entropic affinities, with
        the entropy dual variable parameterized as eps**2.
        Returns the symmetric entropic affinity matrix.

        Parameters
        ----------
        C: array (n,n)
            symmetric cost matrix
        perp: int
            value of the perplexity parameter K
        lr: float
            learning rate passed to the optimizer
        tol: float
            precision threshold at which the algorithm stops
        max_iter: int
            maximum number of dual ascent iterations
        optimizer: str
            key of OPTIMIZERS selecting which pytorch optimizer to use;
            'LBFGS' is driven through a closure
        verbose: bool
            if True, shows a progress bar with current mean and std of the
            perplexities and of the marginal sums
        tolog: bool
            if True, log and returns intermediate variables
        rho: float
            initial value of every entry of eps

        Returns
        -------
        torch.Tensor, or (torch.Tensor, dict) if tolog is True
            the affinity matrix exp(log_P); the dict holds the history of
            'eps', 'mu', 'time_to_k', the initial 'log_P' and 'total_time'
            (its 'loss' list is created but left empty)

        Raises
        ------
        Exception
            if eps or mu contains NaN after an optimizer step
    """
    st = time.time()
    n = C.shape[0]
    assert 1 <= perp <= n
    target_entropy = np.log(perp) + 1
    eps = rho*torch.ones(n, dtype=torch.double)
    mu = torch.zeros(n, dtype=torch.double)
    log_P = log_Pse2(C, eps, mu)
    optimizer_name = optimizer
    optimizer = OPTIMIZERS[optimizer]([eps, mu], lr=lr)

    if tolog:
        log = {}
        log['log_P'] = []
        log['log_P'].append(log_P)
        log['eps'] = [eps.clone().detach()]
        log['mu'] = [mu.clone().detach()]
        log['loss'] = []
        log['time_to_k'] = []
        log['time_to_k'].append(0)

    if verbose:
        print('---------- Computing the Affinity Matrix ----------')

    one = torch.ones(n, dtype=torch.double)

    # Only show the progress bar when verbose, as se_affinity does.
    pbar = tqdm(range(max_iter), disable=not verbose)
    st0 = time.time()
    for k in pbar:
        with torch.no_grad():
            if optimizer_name == 'LBFGS':
                def closure():
                    # LBFGS may call the closure several times per step after
                    # moving eps and mu, so the affinity and its gradients
                    # must be re-evaluated at the current dual variables.
                    optimizer.zero_grad()
                    cur_log_P = log_Pse2(C, eps, mu)
                    cur_H = entropy(cur_log_P, log=True)
                    # Chain rule through the eps**2 parameterization.
                    eps.grad = 2*eps.clone().detach()*(cur_H - target_entropy)
                    cur_P_sum = torch.exp(torch.logsumexp(cur_log_P, -1, keepdim=False))
                    mu.grad = cur_P_sum - one
                    # The effective entropy dual variable is eps**2.
                    return -Lagrangian(C, cur_log_P, eps**2, mu, perp=perp)
                optimizer.step(closure)
            else:
                optimizer.zero_grad()
                H = entropy(log_P, log=True)
                eps.grad = 2*eps.clone().detach()*(H - target_entropy) #do not forget to multiply by the jacobian here
                P_sum = torch.exp(torch.logsumexp(log_P, -1, keepdim=False))
                mu.grad = P_sum - one
                optimizer.step()

            log_P = log_Pse2(C, eps, mu)
            if optimizer_name == 'LBFGS':
                # H and P_sum are not set by the LBFGS branch above; compute
                # them from the updated affinity for reporting and stopping.
                H = entropy(log_P, log=True)
                P_sum = torch.exp(torch.logsumexp(log_P, -1, keepdim=False))

            if torch.isnan(eps).any() or torch.isnan(mu).any():
                print('Iteration = {}'.format(k))
                print('Eps is nan : {}'.format(torch.isnan(eps).any()))
                print('Mu is nan : {}'.format(torch.isnan(mu).any()))
                print('Entropy', H)
                print('P_sum', P_sum)
                raise Exception(f'NaN in variables at iteration {k}')

            if tolog:
                log['time_to_k'].append(time.time()-st0)
                log['eps'].append(eps.clone().detach())
                log['mu'].append(mu.clone().detach())
                # NOTE: the per-iteration loss is not recorded in this variant;
                # log['loss'] stays empty.

            perps = torch.exp(H-1)
            if verbose:
                pbar.set_description(
                    f'perps mean : {float(perps.mean().item()): .3e}, '
                    f'perps std : {float(perps.std().item()): .3e}, '
                    f'marginal sum : {float(P_sum.mean().item()): .3e}, '
                    f'marginal std : {float(P_sum.std().item()): .3e}, ')

            if (torch.abs(H - target_entropy) < tol).all() and (torch.abs(P_sum - one) < tol).all():
                if verbose:
                    print(f'breaking at iter {k}')
                break

            if k == max_iter-1 and verbose:
                print('---------- Max iter attained ----------')

    ed = time.time()
    if tolog:
        log['total_time'] = ed - st
        return torch.exp(log_P), log
    else:
        return torch.exp(log_P)
    
def Lagrangian(C, log_P, eps, mu, perp=30):
    """
        Evaluates the Lagrangian of the symmetric entropic affinity problem.

        The value is the sum of three terms: the total cost sum(P * C), the
        inner product of eps with (target entropy - row entropies of P), and
        the inner product of mu with (1 - row sums of P), where P = exp(log_P).

        Parameters
        ----------
        C: array (n,n)
            cost matrix; its logarithm is taken, so entries are expected to
            be non-negative
        log_P: array (n,n)
            affinity matrix in log domain
        eps: array (n)
            dual variables associated to the entropy constraint
        mu: array (n)
            dual variables associated to the marginal constraint
        perp: int
            value of the perplexity parameter K

        Returns
        -------
        torch.Tensor
            scalar tensor holding the Lagrangian value
    """
    target_entropy = np.log(perp) + 1
    one = torch.ones(C.shape[0], dtype=torch.double)

    row_entropy = entropy(log_P, log=True, ax=1)
    row_mass = torch.exp(torch.logsumexp(log_P, -1, keepdim=False))

    # sum(P * C), evaluated in log domain.
    total_cost = torch.exp(torch.logsumexp(log_P + torch.log(C), (0, 1), keepdim=False))
    entropy_term = torch.inner(eps, (target_entropy - row_entropy))
    marginal_term = torch.inner(mu, (one - row_mass))
    return total_cost + entropy_term + marginal_term
