import torch
from functools import reduce
from .optimizer import Optimizer

import numpy as np
from scipy.linalg import eigvalsh_tridiagonal

def dPinv(d):
    """Return the reciprocal of ``d``, or ``0.`` when ``d`` is numerically zero.

    A value counts as numerically zero when its magnitude does not exceed
    ``1e-32``.
    """
    tiny = 1e-32
    if abs(d) <= tiny:
        return 0.
    return 1. / d


class STBFGS(Optimizer):
    """Optimizer wrapper that updates the flattened parameters of a single
    parameter group using a two-vector (``p``, ``q``) recurrence.

    The wrapped optimizer is used only as a container: its ``param_groups``,
    ``state`` and ``defaults`` are shared with this object, and its own
    ``step`` is never invoked by this class. ``Optimizer.__init__`` is not
    called.

    Arguments:
        optimizer: the wrapped optimizer; must own exactly one parameter group.
        beta (float): initial step size applied to the residual (the negated
            gradient plus weight decay); must be non-negative.
        tao (float): relative threshold; the recurrence is restarted when
            ``|p.q|`` drops to ``tao`` times its first value since the last
            restart.
        m (int): the recurrence is restarted when its index reaches ``m``.
        precision (int): ``0`` keeps the work buffers in the dtype of the
            first parameter; any other value uses ``torch.float64``.

    Raises:
        ValueError: if ``optimizer`` is ``None``, ``beta`` is negative, or the
            wrapped optimizer has more than one parameter group.
    """

    def __init__(self, optimizer=None, beta=1.0, tao=1e-16, m=100, precision=1):
        if optimizer is None:
            raise ValueError("optimizer cannot be None")
        if beta < 0.0:
            raise ValueError("Invalid STBFGS beta parameter: {}".format(beta))

        self.optimizer = optimizer
        self.beta = beta
        self.tao = tao
        self.m = m
        self.precision = precision
        # Share (not copy) the wrapped optimizer's containers.
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
        self.defaults = self.optimizer.defaults
        self.eig = None    # eigenvalues from the most recent step-size update
        self.betas = []    # step size used at every call to step()

        if len(self.param_groups) != 1:
            raise ValueError("Conjugate Anderson doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        N = self._numel()
        device = self._params[0].device
        dtype = self._params[0].dtype if self.precision == 0 else torch.float64

        state = self.state
        state.setdefault('step', 0)
        state.setdefault('mk', -1)
        for key in ('p', 'q', 'x_prev', 'r_prev'):
            state.setdefault(key, torch.zeros(N, device=device, dtype=dtype))
        self._reset_recurrence(state)
        state['beta_prev'] = state['beta'] = self.beta

    def __setstate__(self, state):
        super(STBFGS, self).__setstate__(state)

    @staticmethod
    def _reset_recurrence(state):
        """Reset the scalar bookkeeping of the recurrence stored in ``state``."""
        state['etas'] = ([], [], [])
        state['pq'] = 0.
        state['pqs'] = []
        state['gamma1'] = 0.
        state['gamma2'] = 0.
        state['phi'] = 0.

    def _numel(self):
        """Total number of scalar entries over all parameters (cached)."""
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate all gradients into one 1-D tensor.

        Parameters without a gradient contribute zeros; sparse gradients are
        densified.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new_zeros(p.numel())
            elif p.grad.is_sparse:
                view = p.grad.to_dense().reshape(-1)
            else:
                # reshape (not view) so non-contiguous gradients are accepted
                view = p.grad.reshape(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _gather_flat_data(self):
        """Concatenate all parameter values into one new 1-D tensor."""
        # reshape (not view) so non-contiguous parameters are accepted
        return torch.cat([p.data.reshape(-1) for p in self._params], 0)

    def _chunks(self, flat):
        """Yield ``(param, piece)`` pairs, where ``piece`` is the slice of
        ``flat`` belonging to ``param``, shaped like ``param``."""
        offset = 0
        for p in self._params:
            numel = p.numel()
            yield p, flat[offset:offset + numel].view_as(p)
            offset += numel
        assert offset == self._numel()

    def _store_data(self, other):
        """Copy the flat tensor ``other`` into the parameters."""
        for p, piece in self._chunks(other):
            p.copy_(piece)

    def _store_grad(self, other):
        """Copy the flat tensor ``other`` into the (existing) gradients."""
        for p, piece in self._chunks(other):
            p.grad.copy_(piece)

    def _add_grad(self, step_size, update):
        """In-place ``param += step_size * update`` for a flat ``update``."""
        for p, piece in self._chunks(update):
            p.add_(piece, alpha=step_size)

    def _directional_evaluate(self, closure, x, g, t, d):
        """Evaluate ``closure`` at ``params + t * d`` and then restore the
        parameters to ``x`` and the gradients to ``g``.

        Returns the loss, the flat trial point and the flat gradient there.
        """
        self._add_grad(t, d)
        loss = closure()
        xk = self._gather_flat_data()
        flat_grad = self._gather_flat_grad()
        self._store_data(x)
        self._store_grad(g)
        return loss, xk, flat_grad

    def setfullgrad(self, length):
        """Store the current flat gradient divided by ``length`` in
        ``self.fullgrad``."""
        self.fullgrad = self._gather_flat_grad().div(length)

    def settmpx(self):
        """Store a flat copy of the current parameters in ``self.xk``."""
        self.xk = self._gather_flat_data()

    def _get_x_delta(self, Xk, Rk, delta):
        # NOTE(review): this method references names that are not defined in
        # the visible part of this file (simpleQR, inv, res, alpha, beta) and
        # is not called by this class; calling it would raise NameError unless
        # those names are provided elsewhere. Left untouched -- confirm
        # whether it can be removed.
        Q, R, G = simpleQR(Rk)
        Xk = Xk.mm(G)
        Rk = Rk.mm(G)
        H_inv = inv(R + delta * inv(R).t() @ (Xk.t() @ Xk))
        Gamma = H_inv @ (Q.t() @ res)
        x_delta = beta * res - (alpha * Xk + alpha * beta * Rk) @ Gamma
        return x_delta

    def geteig(self):
        """Return the eigenvalues computed by the last step-size update
        (``None`` if no update has happened yet)."""
        return self.eig

    @staticmethod
    def _tridiag_eigvals(etas):
        """Eigenvalues of the tridiagonal matrix built from ``etas``.

        ``etas[1]`` is the diagonal, ``etas[0][1:]`` the super-diagonal and
        ``etas[2][:-1]`` the sub-diagonal. Returns ``None`` when the matrix
        contains non-finite entries (``np.linalg.eigvals`` would raise).
        """
        Tk = (np.diag(etas[1])
              + np.diag(etas[0][1:], k=1)
              + np.diag(etas[2][:-1], k=-1))
        if not np.isfinite(Tk).all():
            return None
        return np.linalg.eigvals(Tk)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        assert len(self.param_groups) == 1

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        state = self.state

        p, q = state['p'], state['q']
        r_prev, x_prev = state['r_prev'], state['x_prev']
        # Work in the dtype of the state buffers so that every dot product
        # below sees matching dtypes, whatever value ``precision`` has.
        dtype = p.dtype

        xk = self._gather_flat_data().to(dtype)
        # Wrapped optimizers without a weight_decay option are treated as
        # having none.
        weight_decay = group.get('weight_decay', 0)
        # Residual: negated (gradient + weight_decay * x).
        rk = self._gather_flat_grad().to(dtype).add(xk, alpha=weight_decay).neg()

        mk = state['mk']
        pq = state['pq']
        pqs = state['pqs']
        beta_prev = state['beta_prev']
        beta = state['beta']
        gamma1 = state['gamma1']
        gamma2 = state['gamma2']
        etas = state['etas']

        restart = False
        if mk >= 0:
            delta_x = xk - x_prev
            delta_r = rk - r_prev
            if mk == 0:
                p.copy_(delta_x)
                q.copy_(delta_r)
            else:
                # Remove the component along the previous pair so that the
                # old p is orthogonal to the new q.
                zeta = torch.dot(p, delta_r) / pq
                p.copy_(delta_x - p * zeta)
                q.copy_(delta_r - q * zeta)

            state['pq'] = pq = torch.dot(p, q)
            pqs.append(abs(pq))
            if pqs[-1] <= self.tao * pqs[0] or mk == self.m:
                restart = True
            elif 1 <= mk <= 20:
                # Adapt the step size from the spectrum of a small tridiagonal
                # matrix; only done for the first 20 recurrence steps.
                phi_prev = state['phi']
                phi = gamma1 + gamma2 + zeta
                eta0 = phi_prev / (beta_prev * (1 - gamma1))
                eta1 = (1. / beta_prev - phi / beta) / (1 - gamma1)
                eta2 = -1. / (beta * (1 - gamma1))
                etas[0].append(float(eta0))
                etas[1].append(float(eta1))
                etas[2].append(float(eta2))

                eig = self._tridiag_eigvals(etas)
                if eig is None:
                    # Non-finite coefficients: discard the recurrence.
                    restart = True
                elif eig.real.min() < 0:
                    print('negative eigenvalue! restart!')
                    restart = True
                else:
                    magnitudes = np.abs(eig)
                    spread = magnitudes.min() + magnitudes.max()
                    if spread > 0:
                        beta_prev = beta
                        # Plain float keeps tensor arithmetic and the state
                        # free of NumPy scalars.
                        beta = float(2. / spread)
                        state['beta_prev'] = beta_prev
                        state['beta'] = beta
                        state['phi'] = phi
                        self.eig = eig
                    else:
                        # All-zero spectrum would give an infinite step.
                        restart = True

        if restart:
            p.zero_()
            q.zero_()
            self._reset_recurrence(state)
            state['beta_prev'] = state['beta']
            # Also reset the locals, otherwise the stale values would be
            # written back over the reset state at the end of this step.
            gamma1 = gamma2 = 0.
            mk = -1
            print('restart!')

        # Remember the current iterate/residual before they are modified.
        x_prev.copy_(xk)
        r_prev.copy_(rk)
        mk = mk + 1

        if mk == 0:
            xk += beta * rk
        else:
            gamma1 = torch.dot(rk, p) / pq
            xk -= p * gamma1
            rk -= q * gamma1
            xk += beta * rk
            gamma2 = beta * torch.dot(rk, q) / pq
            xk -= p * gamma2

        self.betas.append(beta)

        self._store_data(xk)
        state['step'] = state['step'] + 1
        state['mk'] = mk
        state['gamma1'] = gamma1
        state['gamma2'] = gamma2

        return loss
