import matplotlib

from library.rnn_architectures.rank import Rank
from library.helper.functions import sdeint_aaeh
from library.sde.systems import SDE
from library.rnn_architectures.base_rnn import RNN
from library.plotting import utils, plot_2d
from adjoint import *
import rank_plot

import torch
import numpy as np
from matplotlib import pyplot as plt
from torchsde import sdeint
import scipy
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import random

def identity(x):
    """Return the argument unchanged (the linear / no-op activation)."""
    return x

def outer(x, y):
    """Return the outer product of a vector and a matrix.

    The einsum equation 'i,jk->ijk' fixes x to one index and y to two, and
    the result satisfies out[i, j, k] = x[i] * y[j, k].
    """
    equation = 'i,jk->ijk'
    return torch.einsum(equation, x, y)

def mm(x, y):
    """Contract a matrix with the leading index of a 3-index tensor.

    The einsum equation 'ij,jkl->ikl' gives
    out[i, k, l] = sum_j x[i, j] * y[j, k, l].
    """
    equation = 'ij,jkl->ikl'
    return torch.einsum(equation, x, y)

if __name__ == '__main__':
    # Script entry point. For `sample_size` random initializations it trains
    # the recurrent weights W of an RNN with gradient descent, compares the
    # autograd gradient with one obtained by integrating an adjoint system
    # backward in time, and finally plots the loss, the activity and the
    # singular values / eigenvalues of the weights and of their updates.

    device = ('cuda' if torch.cuda.is_available() else 'cpu')

    parameters = {
        'activation': 'tanh',
        'rank': 3,
        'decoder_dim': 2,
        'input_dim': 2,

        'full_rank_weights': True,  # Overwrites rank
        'full_rank_init': True,

        'dim': 100,
        'batch_size': 1,
        'weight_init_std': 1.5,
        'init_state_std': 1.5,
        'input_std': 0.5,

        'time_steps': 101,
        'T': 5,

        'steps': 100,  # gradient descent steps
        'learning_rate': 10**-3,
        'optimizer': 'sgd',
        'sample_size': 2,
        'singular_values_plotted': 20,
        'tol_sv': -11,  # Singular values below 10**-11 are numerical errors
        'max_tensor_rank': 6,
        'rtol': 10**-8,  # Adaptive step SDE solver relative tolerance
        'atol': 10**-7,  # Adaptive step SDE solver absolute tolerance
        'tca_max_iter': 10**5,  # Maximum iterations to fit PARAFAC
        'seed': 1,
    }

    # Echo the configuration as written above (i.e. before 'rank' is
    # possibly overwritten just below).
    for k, v in parameters.items():
        print(k, ': ', v)

    if parameters['full_rank_weights'] and parameters['full_rank_init']: parameters['rank'] = parameters['dim']

    # NOTE: 'relu' has an activation but no entry in derivative_activations,
    # so selecting it raises a KeyError when the state adjoint is built.
    activations = {'identity': identity, 'tanh': torch.tanh, 'softplus': torch.nn.functional.softplus,
                   'relu': torch.relu, 'retanh': retanh}
    derivative_activations = {'identity': didentity, 'tanh': dtanh, 'retanh' : dretanh, 'softplus': dsoftplus}

    # Seed torch and numpy. Python's `random` module is not seeded in this
    # file; it is only used for the output file-name suffix.
    torch.manual_seed(parameters['seed'])
    np.random.seed(parameters['seed'])

    # Per-sample accumulators.
    # Ss: singular values of [W_0, W_end, dW, dW0, xs0[0], xs0[1], xs_max[0], xs_max[1]]
    # Ls: eigenvalues of [W_0, W_end, dW, dW0]
    Ss = [[] for _ in range(8)]
    Ls = [[] for _ in range(4)]
    Xs = []
    Ws = []
    Losses = []

    # Running maxima (over all samples and steps) of the forward/backward
    # consistency measures returned by rank_plot.var_exp.
    max_dx_forward_backward = 10**-6
    max_dW_forward_backward = 10**-6

    # Loop over random initializations
    for sample_iteration in range(parameters['sample_size']):
        print()
        print('Sample:', sample_iteration)

        # Low rank RNN to build low rank weights if needed
        net2 = Rank(parameters['dim'], parameters['rank'], device=device).double()

        # RNN to optimize
        net = RNN(parameters['dim'], in_dims=(parameters['input_dim'],), init='optimized',
                  activation=activations[parameters['activation']], device=device).double()

        net.W_in.requires_grad_(False) # Here we only optimize W

        if not parameters['full_rank_weights']:
            with torch.no_grad(): net.W.copy_(net2.construct_weight().detach())
        with torch.no_grad(): net.W *= parameters['weight_init_std']

        # Define inputs which are low-d LDS
        inputs = LDSinput(parameters['input_dim'], std=parameters['input_std'], device=device)
        inputs.set_initial_state(parameters['batch_size'])

        # Main SDE, which is here just an ODE, see SDE docstring
        sde = SDE([net, inputs], [[1, FunctionToModule(net.activation)],
                                  [0, 1]])
        sde.build_parameterization()

        #========= Adjoint ==========

        # State adjoint
        adj = AdjointRNN(net, derivative_activations[parameters['activation']], device=device).double()

        # Parameter adjoint
        adjW = AdjointW(parameters['dim'] ** 2, device=device).double()

        # To evaluate the RNN and inputs backward in time
        backward_rnn = BackwardDynamicalSystem(net, device=device).double()
        backward_inputs = BackwardDynamicalSystem(inputs, device=device).double()

        # Coupling between the four backward systems, in the same order as
        # the list passed to SDE below (see the SDE docstring for semantics).
        graph = [[1, FunctionToModule(net.activation), 0, 0],
                 [0, 1, 0, 0],
                 [1, 0, 1, 0],
                 [FunctionToModule(net.activation), 0, 1, 0]]
        sde_backward = SDE([backward_rnn, backward_inputs, adj, adjW], graph)
        #=====================

        # Here the decoder and target inputs are fixed
        D = torch.randn(parameters['decoder_dim'], parameters['dim'], device=device, dtype=torch.double)/np.sqrt(parameters['dim'])
        y = (torch.rand(parameters['batch_size'], parameters['decoder_dim'], device=device, dtype=torch.double)*2-1)

        # The initial state is also fixed, but potentially low rank
        if parameters['full_rank_init']:
            init = (torch.rand_like(net2.get_initial_state(parameters['batch_size']))*2-1)*parameters['init_state_std']
        else:
            init = net2.get_initial_state(parameters['batch_size']).detach()*parameters['init_state_std']
        init = torch.cat([init, inputs.get_initial_state(parameters['batch_size'])], dim=-1).detach()

        ts = torch.linspace(0, parameters['T'], parameters['time_steps'], device=device, dtype=torch.double)

        # Optimizers
        if parameters['optimizer'] == 'sgd':
            optim = torch.optim.SGD(net.parameters(), lr=parameters['learning_rate'])
        elif parameters['optimizer'] == 'adam':
            optim = torch.optim.Adam(net.parameters(), lr=parameters['learning_rate'])
        else:
            # Fail here with a clear message instead of a NameError on
            # `optim` further down.
            raise ValueError("Unknown optimizer %r, expected 'sgd' or 'adam'" % (parameters['optimizer'],))

        W_0 = net.W.numpy(force=True).copy()

        Ws.append([W_0]) # Weights over learning
        Losses.append([]) # Loss over learning
        Xs.append([]) # Activity over learning

        for i in range(parameters['steps']):
            print(i, end=',' if i<parameters['steps']-1 else '\n')

            # Forward pass: integrate the RNN together with its inputs
            xs = sdeint(sde, init, ts, adaptive=True, method='euler_heun',
                        rtol=parameters['rtol'], atol=parameters['atol'], dt=10**-4, dt_min=10**-10)
            cut_xs = sde.cut_states(xs)
            x = cut_xs[0]

            # Mean squared error between the decoded final activity and the target
            l = torch.mean((net.activation(x[-1]) @ D.T - y) ** 2)

            # Adjoint computation doesn't require gradients
            with torch.no_grad():

                adj.set_terminal_state(x[-1], D, y)
                adj.terminal_state /= parameters['decoder_dim']*parameters['batch_size']
                backward_rnn.set_terminal_state(x[-1])
                backward_inputs.set_terminal_state(cut_xs[1][-1])
                xs2 = sdeint_aaeh(sde_backward, sde_backward.get_initial_state(parameters['batch_size']), ts,
                                  rtol=parameters['rtol'], atol=parameters['atol'], dt=10**-4, dt_min=10**-10)
                cut_xs2 = sde_backward.cut_states(xs2)

                # Estimate of gradient using adjoint
                gradW_adj = cut_xs2[-1][-1].reshape((-1, parameters['dim'], parameters['dim'])).sum(dim=0)

            optim.zero_grad()
            l.backward()

            # Computing an RNN backward in time can lead to numerical errors, we check for that
            max_dx_forward_backward = max(max_dx_forward_backward, rank_plot.var_exp(cut_xs2[0].flip(dims=(0,)), x))
            max_dW_forward_backward = max(max_dW_forward_backward, rank_plot.var_exp(net.W.grad, gradW_adj))

            optim.step()

            Losses[-1].append(l.item())
            Xs[-1].append(net.activation(x).numpy(force=True).copy())

            Ws[-1].append(net.W.numpy(force=True).copy())

            if i in [0, parameters['steps']-1]: print('iteration:', i, 'L:', l.item())
            if i == 0:
                # Activity and third backward state (the `adj` system) of the
                # first batch element at the first gradient step
                xs0 = [net.activation(x)[:, 0, :].numpy(force=True).copy(), cut_xs2[2][:, 0, :].numpy(force=True).copy()]

        W_end = net.W.numpy(force=True).copy()

        # The column, row, decoder and input map space of the initialization, only meaningful for low rank RNN
        w_o = scipy.linalg.orth(np.concatenate([net2.W_column.weight[:, :net2.rank].numpy(force=True),
                                                net2.W_row.weight[:, :net2.rank].numpy(force=True),
                                                D.T.numpy(force=True),
                                                net.W_in[0].numpy(force=True)], axis=1))

        # Orthogonal projection onto that subspace
        x_projected = x.numpy(force=True) @ w_o @ w_o.T
        W_projected = W_end @ w_o @ w_o.T

        # Fraction of variance retained by the projection
        v_e_x = 1-np.mean((x.numpy(force=True)-x_projected)**2)/x.numpy(force=True).var()
        v_e_w = 1-np.mean((W_end - W_projected)**2)/W_end.var()

        print('var exp original subspace x:', v_e_x, ' W:', v_e_w)

        dW = (W_end - W_0)
        # First weight update divided by the learning rate
        dW0 = (Ws[-1][1] - Ws[-1][0])/parameters['learning_rate']
        print('dW_0,0:', dW[0,0], ' W_0,0(0):', W_0[0,0], ' W_0,0(end):',  W_end[0,0])

        # Check for numerical errors
        print('max forward backward var exp x:', max_dx_forward_backward, 'dW:', max_dW_forward_backward)

        # Same quantities as xs0, taken from the last gradient step
        xs_max = [net.activation(x)[:, 0, :].numpy(force=True).copy(), cut_xs2[2][:, 0, :].numpy(force=True).copy()]

        for i, w in enumerate([W_0, W_end, dW, dW0] + xs0 + xs_max):
            U, S, V = np.linalg.svd(w, full_matrices=False)

            Ss[i].append(S)

        for i, w in enumerate([W_0, W_end, dW, dW0]):
            L, V = np.linalg.eig(w)
            Ls[i].append(L)

    #==================================================

    Ss = np.array(Ss)
    Ls = np.array(Ls)
    Ws = np.array(Ws)
    Losses = np.array(Losses)
    Xs = np.array(Xs)

    # Element-wise minimum of (largest SV of row 4) * (SVs of row 5) and
    # (largest SV of row 5) * (SVs of row 4); inserted as a new row 4, so the
    # nine rows of Ss line up with the nine entries of `labels` below.
    temp = np.min(np.stack([Ss[4,:,0:1]*Ss[5], Ss[5,:,0:1]*Ss[4]], axis=-1), axis=-1)
    Ss = np.insert(Ss, 4, temp, axis=0)

    Ss = np.log10(Ss)

    #==================== Plotting ====================
    utils.set_font(font_size=16)

    cmap = matplotlib.colormaps['Set2']
    cmap_eig = utils.get_cmap_interpolated(np.array((137, 230, 62))/255, np.array((255, 208, 54))/255, np.array((247, 79, 36))/255)

    inset_size = 1.5

    rows, columns = 1, 7
    axs = rank_plot.get_axs(rows, columns)

    # Loss
    rank_plot.loss(axs[0,0], Losses, cmap_eig)

    # Activation
    ax_temp = inset_axes(axs[0,0], width=inset_size,  height=inset_size, loc=1)
    rank_plot.activation(ax_temp, net)

    # Activity
    rank_plot.activity(axs[1,0], ts, Xs, cmap_eig, parameters)

    # Activity over trials
    # NOTE(review): this inset is anchored to the previous inset (ax_temp),
    # not to axs[1,0] -- confirm that this is intended.
    ax_temp = inset_axes(ax_temp, width=inset_size,  height=inset_size, loc=1)
    rank_plot.activity_over_trials(ax_temp, parameters, ts, cmap_eig, Xs)

    # Weights SV
    # Raw strings keep the LaTeX backslashes without relying on invalid
    # escape sequences.
    labels = ['$W_0$',
              '$W_{'+str(parameters['steps'])+'}$',
              r'$\Delta W_{0:'+str(parameters['steps'])+'}$',
              r'$\nabla_W L^{(0)}$',
              r'$min(\sigma^{\phi(\mathbf{x})}_1\sigma^\mathbf{a}_i, \sigma^\mathbf{a}_1\sigma^{\phi(\mathbf{x})}_i)$',

              r'$\phi(\mathbf{x}^{(0)})$',
              r'$\mathbf{a}_\mathbf{x}^{(0)}$',
              r'$\phi(\mathbf{x}^{('+str(parameters['steps'])+')})$',
              r'$\mathbf{a}_\mathbf{x}^{('+str(parameters['steps'])+')}$']

    colors  = [
            [0.85, 0.85, 0.85],
            [0.2, 0.2, 0.2],
            [0.5, 0.7, 0.9],
            [0.95, 0.6, 0.95],
            [0.8, 0.1, 0.8],

            [0.95, 0.6, 0.6],
            [0.6, 0.6, 0.95],
            [0.7, 0.1, 0.1],
            [0.1, 0.1, 0.7],
            ]

    linestyles = ['o', '-o', '-o', '-o', '-o',
                  'o', 'o', '-o', '-o',]

    zorders = [1, 0, 0, 0, 0,
               1, 1, 0, 0]

    axs_plot = [axs[4, 0], axs[3, 0]]
    axs_insets = [inset_axes(ax, width=inset_size, height=inset_size, loc=1) for ax in axs_plot]
    # Index into axs_plot for each series.
    # NOTE(review): 11 entries for 9 series -- presumably the two extra
    # entries are ignored by rank_plot.singular_values; confirm.
    ax_id = [0]*5 + [1]*6

    rank_plot.singular_values(Ss, parameters, axs_plot, axs_insets, ax_id, colors, labels, zorders, linestyles)

    # Eig
    rank_plot.eig(axs[2,0], Ws, cmap_eig, Ls, labels)

    # Tensor rank
    ax_temp = axs[5,0]
    rank_plot.tensor_rank(axs[5,0], Ws, parameters, cmap)

    # Text
    ax_temp = axs[-1,0]
    text = ''.join(k + ': ' + str(v) + '\n' for k, v in parameters.items())
    ax_temp.text(0, 0, text, fontsize=8, va='bottom', ha='left')
    ax_temp.axis('off')

    # Save
    # File name: the first ten parameters plus a random suffix, so that
    # repeated runs with the same settings do not overwrite each other.
    file_name = ''.join(k + str(parameters[k]) + '_' for k in list(parameters)[:10])
    file_name += str(random.randint(10**4, 10**5))

    # Create the output directory so the figure is not lost after the whole
    # computation just because './plots' is missing.
    import os
    os.makedirs('./plots', exist_ok=True)

    plt.savefig('./plots/'+file_name+'.pdf')
    plt.show()
