import pickle

from library.tasks.common import TargetSphere, Cue, OU, DiscreteTargetSequence
from library.tasks.reach import PerturbedHand, HandVelocity
from library.helper.functions import sdeint_aaeh
from library.plotting import plot_2d, plot_3d, utils
from library.sde.systems import SDE
from library.rnn_architectures.base_rnn import RNN

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torchsde
import time
import library


def input_plotting(ax, hand_movement, condition):
    """Draw hand trajectories, sampled hand positions and the test targets on ``ax``.

    Args:
        ax: 2-D matplotlib axes to draw on.
        hand_movement: array indexed as ``[time, trial, coordinate]`` (axis 0 is
            strided/last-indexed as time, axis 1 iterates over trials, axis 2 holds x/y).
        condition: per-trial value fed to the HSV colormap (the caller in ``train``
            passes target angle / 2*pi).

    Reads the module-level globals ``parameters``, ``positions`` and ``angles``
    created in the ``__main__`` block.
    """
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the registry works from 3.5 on.
    cmap = matplotlib.colormaps['hsv']
    plot_2d.trajectories_gradient(ax, hand_movement, condition=condition, zorder=8)

    # Spacing of the white position markers (~4 per trajectory). Clamped to 1 so that an
    # execution phase shorter than 4 steps does not produce a zero slice step (ValueError).
    time_steps = max(1, int(parameters['execution_time_steps']/4))

    # Test targets, coloured by their angle.
    p, a = positions.cpu().numpy(), angles.cpu().numpy()
    for i in range(parameters['number_targets_test']):
        ax.scatter(p[i, 0], p[i, 1], color=cmap(a[i]/(np.pi*2)), s=100, zorder=7)

    for i in range(parameters['batch_size_test']):
        colour = cmap(condition[i])

        # Markers at regular time intervals plus one at the final hand position.
        ax.scatter(hand_movement[::time_steps, i, 0], hand_movement[::time_steps, i, 1],
                   edgecolor=colour, facecolor=(1, 1, 1), s=10, alpha=1, zorder=11, marker='o')
        ax.scatter(hand_movement[-1, i, 0], hand_movement[-1, i, 1],
                   edgecolor=colour, facecolor=(1, 1, 1), s=10, alpha=1, zorder=11, marker='o')

    utils.remove_axes(ax)
    ax.set_xlim(-1.25, 1.25), ax.set_ylim(-1.25, 1.25)
    ax.set_xticks([]), ax.set_yticks([])


def tnp(x):
    """Convert tensor ``x`` to a NumPy array after detaching it from autograd and moving it to the CPU."""
    detached_on_host = x.detach().cpu()
    return detached_on_host.numpy()


def _make_optimizer(params):
    """Build the optimizer named by ``parameters['optimizer']`` ('Adam' or 'SGD').

    Raises:
        ValueError: if the configured optimizer name is not supported.
    """
    name = parameters['optimizer']
    if name == 'Adam':
        return torch.optim.Adam(params, lr=parameters['learning_rate'])
    if name == 'SGD':
        return torch.optim.SGD(params, lr=parameters['learning_rate'])
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'Adam' or 'SGD'.")


def _rollout(system, target_obj, batch_size, cue_time, bm=None):
    """Simulate one reach trial: a preparatory phase followed by an execution phase.

    The preparatory phase integrates ``system`` from its initial state over
    ``[0, cue_time]``. At the go cue, the following happens without gradient tracking:
    - 1 is subtracted from the last state slice (only when ``parameters['cued']`` is set);
    - the target position is cloned and then zeroed on ``target_obj``;
    - state slice 2 is zeroed.
    The execution phase then integrates from the last preparatory state over
    ``[0, parameters['execution_duration']]``.

    Args:
        system: SDE to integrate (``sde`` or ``sde_test``).
        target_obj: target object belonging to ``system``; its ``pos`` is zeroed in place.
        batch_size: number of trials passed to ``system.get_initial_state``.
        cue_time: duration of the preparatory phase.
        bm: optional Brownian motion forwarded to ``sdeint_aaeh``; not passed when None.

    Returns:
        ``(ts1, ts2, xs1, xs2, cut_xs1, cut_xs2, t_pos)``:
        - ``xs1`` is the concatenation of the (edited) preparatory state slices;
        - ``xs2`` holds the execution-phase states;
        - ``t_pos`` is a clone of ``target_obj.pos`` taken before it was zeroed.
    """
    # Forward ``bm`` only when given, so the training call matches sdeint_aaeh's own default.
    bm_kwargs = {} if bm is None else {'bm': bm}

    ts1 = torch.linspace(0, cue_time, parameters['preparatory_time_steps'])
    xs1 = sdeint_aaeh(system, system.get_initial_state(batch_size), ts1, **bm_kwargs)

    # ==== Go cue: edit the state in place between the two phases ====
    # NOTE(review): slice indices presumably follow the SDE system order
    # [net, hand, target, noise, hand_velocity, cue] — confirm against SDE.cut_states.
    cut_xs1 = system.cut_states(xs1)
    with torch.no_grad():
        if parameters['cued']: cut_xs1[-1].data -= 1

        t_pos = target_obj.pos.clone()
        target_obj.pos *= 0
        cut_xs1[2].data *= 0
    xs1 = torch.cat(cut_xs1, dim=-1)

    ts2 = torch.linspace(0, parameters['execution_duration'], parameters['execution_time_steps'])
    xs2 = sdeint_aaeh(system, xs1[-1], ts2, **bm_kwargs)
    cut_xs2 = system.cut_states(xs2)

    return ts1, ts2, xs1, xs2, cut_xs1, cut_xs2, t_pos


def _reach_error(t_pos, hand_trajectory):
    """Mean squared error between a hand trajectory and the target position.

    With ``parameters['integrated_loss']``, every execution time step is compared
    against the target scaled by a ramp from 0 to 1 (shape ``(T, 1, 1)``, broadcast
    over the remaining axes). Otherwise only ``hand_trajectory[-1]`` is compared
    against ``t_pos``. Assumes ``hand_trajectory`` is time-major — TODO confirm.
    """
    if parameters['integrated_loss']:
        ramp = torch.linspace(0, 1, parameters['execution_time_steps'],
                              device=device).unsqueeze(-1).unsqueeze(-1)
        return torch.mean((t_pos.unsqueeze(0) * ramp - hand_trajectory) ** 2)
    return torch.mean((t_pos - hand_trajectory[-1]) ** 2)


def train():
    """Train ``sde`` on the reaching task while periodically evaluating and plotting.

    Each iteration samples targets and simulates a preparatory phase of random length
    followed by an execution phase (see ``_rollout``). It then back-propagates the sum of:
    - ``preparatory_penalty`` * mean squared state slice 1 during preparation;
    - the reach error (``_reach_error``);
    - ``end_velocity_penalty`` * mean squared final value of state slice -2;
    - 1e-2 * mean squared ``net.activation`` of state slice 0 during preparation.

    Gradients accumulate over ``parameters['gradient_accumulation']`` iterations
    before an optimizer step.

    In perturbation mode (``parameters_perturb['perturb']``):
    - ``hand_velocity.coef_perturbation`` is set between ``perturbation_iteration``
      and ``washout_iteration``;
    - optimizer steps only happen from ``perturbation_iteration`` on;
    - evaluation runs every iteration, and the collected data is pickled to ``data.pkl``.

    Otherwise the weight matrix after every optimizer step is pickled to ``W.pkl``.

    Every evaluation also:
    - redraws the figure and saves it as a PDF in ``directory``;
    - saves ``model.pt``;
    - appends the test loss to ``losses.txt``.

    Reads module-level globals created in ``__main__``: ``parameters``,
    ``parameters_perturb``, ``directory``, ``device``, ``optimize``, ``sde``,
    ``sde_test``, ``net``, ``target``, ``target_test``, ``t_pos_test``,
    ``bm_test``, ``hand_velocity`` and ``angles``.
    """

    rows, columns = 2, 5
    fig = plt.figure(figsize=(columns*4+1.1, rows*4), constrained_layout=True, dpi=60)
    gs = fig.add_gridspec(ncols=columns, nrows=rows)
    axs_3d = [[0, 0]]
    axs_ignored = []
    # NOTE: axs is indexed [column][row] because the outer comprehension runs over columns.
    axs = [[fig.add_subplot(gs[i, j], projection=('3d' if [i, j] in axs_3d else None)) if [i, j] not in axs_ignored else
            None for i in range(rows)] for j in range(columns)]

    fig.suptitle('Dir: ' + directory.split('/')[-1])

    losses = []
    Ws = []
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the registry works from 3.5 on.
    cmap = matplotlib.colormaps['hsv']

    plt.show(block=False)

    # Only needed (and only validated) when weights are actually updated.
    optim = _make_optimizer(list(sde.parameters())) if optimize else None

    # Iterations at which the hand movement is snapshotted for the perturbation panels.
    hand_movements_idx = {parameters_perturb['perturbation_iteration'] - 1: 'Baseline',
                          parameters_perturb['perturbation_iteration']: 'Perturbed - early',
                          parameters_perturb['washout_iteration'] - 1: 'Perturbed - late',
                          parameters_perturb['washout_iteration']: 'Washout - early'}
    # Insertion order fixes the panel order; 'Washout - late' always shows the latest test.
    hand_movements = {name: None for name in hand_movements_idx.values()}
    hand_movements['Washout - late'] = None

    data = {'rnn_activity': [], 'condition': [], 'hand_movement': [], 'weights': []}

    for iteration in range(parameters['training_iterations']):

        perturbed = (parameters_perturb['perturbation_iteration'] <= iteration < parameters_perturb['washout_iteration']
                     and parameters_perturb['perturb'])
        hand_velocity.coef_perturbation = parameters_perturb['coef_perturbation'] if perturbed else 0.0

        # ==== Training trial ====
        target.set_pos(parameters['batch_size'])
        sde.build_parameterization()

        t_start = time.time()

        cue_time = parameters['preparatory_duration'] + np.random.rand() * parameters['random_duration']
        _, _, _, _, cut_xs1, cut_xs2, t_pos = _rollout(sde, target, parameters['batch_size'], cue_time)

        l1 = torch.mean(cut_xs1[1] ** 2)
        l2 = _reach_error(t_pos, cut_xs2[1])
        l3 = (cut_xs2[-2][-1] ** 2).mean()
        l = l1 * parameters['preparatory_penalty'] + l2 + l3 * parameters['end_velocity_penalty']
        l += (net.activation(cut_xs1[0]) ** 2).mean() * 10 ** -2

        l.backward()
        sde.backward_parameterization()

        print(iteration, l1.item(), l2.item(), l3.item(), time.time() - t_start)

        # ==== Evaluation, plotting and checkpointing ====
        if iteration % parameters['test_freq'] == 0 or parameters_perturb['perturb']:
            with torch.no_grad():

                for column in axs:
                    for ax in column:
                        if ax is not None: ax.cla()

                condition = angles[target_test.sequence_ids[:, 0]].cpu().numpy() / (np.pi*2)
                target_test.pos = t_pos_test.clone()

                # Fixed (mid-range) cue time and the shared bm_test Brownian motion for evaluation.
                cue_time = parameters['preparatory_duration'] + 0.5 * parameters['random_duration']
                ts1, ts2, xs, xs2, cut_xs1, cut_xs2, _ = _rollout(
                    sde_test, target_test, parameters['batch_size_test'], cue_time, bm=bm_test)

                l = _reach_error(t_pos_test, cut_xs2[1])

                # Stitch both phases together (dropping the duplicated cue sample).
                xs = torch.cat([xs, xs2[1:]], dim=0)
                cut_xs = sde_test.cut_states(xs)

                rnn_trajectories = net.activation(cut_xs[0]).detach().cpu().numpy()
                ts = torch.cat([ts1, ts2[1:] + ts1[-1]]).numpy()

                if iteration in hand_movements_idx and parameters_perturb['perturb']:
                    hand_movements[hand_movements_idx[iteration]] = {'movement': cut_xs2[1].detach().cpu().numpy(),
                                                                     'condition': condition}

                # ============PCA============
                axs[0][0].set_title('PCA')
                _, _, V = torch.pca_lowrank(net.activation(cut_xs2[0]).reshape(-1, parameters['rec_dim']), q=6)
                rnn_trajectories_pca = (net.activation(cut_xs2[0]) @ V[:, :3]).cpu().numpy()
                plot_3d.trajectories_gradient_shadow(axs[0][0], rnn_trajectories_pca, condition, zorder=5)
                # Mark each condition's median state at the first execution time step.
                for cond in np.unique(condition):
                    temp = np.median(rnn_trajectories_pca[0, condition == cond], axis=0)
                    axs[0][0].scatter(temp[0], temp[1], temp[2], color=cmap(cond), s=40, zorder=11)
                axs[0][0].computed_zorder = False

                # ============Single Neurons============
                axs[1][0].set_title('Single neurons')
                plot_2d.trajectories_over_time(axs[1][0], ts, rnn_trajectories[:, :5, :5], condition[:5], alpha=0.5)
                axs[1][0].axvline(ts1[-1], color='grey', linestyle='--')
                axs[1][0].set_xlabel('Time')
                axs[1][0].set_ylim(-1, 1)

                # ============Heatmap Sorted Activity============
                axs[2][0].set_title('Sorted by peak activity, 1 trial')
                im = plot_2d.sorted_activity(axs[2][0], rnn_trajectories[:, 0], ts)
                axs[2][0].axvline(ts1[-1], color='grey', linestyle='--')
                axs[2][0].set_xlabel('Time'), axs[2][0].set_ylabel('Neuron')
                # cla() does not remove the colorbar, so it is added only once.
                if iteration == 0: plt.colorbar(im, ax=axs[2][0], orientation='vertical')

                # ============Hand movement============
                hand_movements['Washout - late'] = {'movement': cut_xs2[1].detach().cpu().numpy(), 'condition': condition}
                for col, (title, snapshot) in enumerate(hand_movements.items()):
                    if snapshot is not None:
                        input_plotting(axs[col][1], snapshot['movement'], snapshot['condition'])
                    axs[col][1].set_title(title)

                # ============Loss============
                axs[4][0].set_title('MSE hand to target')
                losses.append(l.item())
                axs[4][0].plot(losses, color='black', label=parameters['beta'])
                axs[4][0].set_xlabel('Iteration'), axs[4][0].set_ylabel('MSE')
                utils.set_bottom_axis(axs[4][0])
                if parameters_perturb['perturb']:
                    axs[4][0].axvline(parameters_perturb['perturbation_iteration'], color=(0.9, 0.9, 0.9), linestyle='--')
                    axs[4][0].axvline(parameters_perturb['washout_iteration'], color=(0.9, 0.9, 0.9), linestyle='--')

                # ============Eigenspectrum============
                axs[3][0].set_title('Eigenspectrum')
                plot_2d.eigenspectrum(axs[3][0], net.W.detach().cpu().numpy())
                utils.set_set_equal_lim(axs[3][0])

                plt.pause(0.1)
                plt.draw()
                plt.savefig(directory+'/'+directory.split('/')[-1]+'.pdf')

                torch.save(sde.state_dict(), directory+'/model.pt')

                with open(directory+'/losses.txt', 'a') as f:
                    f.write(str(iteration)+' '+str(l.item())+'\n')

                if parameters_perturb['perturb']:
                    # Shuffle trials along axis 1, then keep the first `batch_save` of them.
                    perm = np.random.permutation(parameters['batch_size_test'])
                    n_save = parameters_perturb['batch_save']

                    data['rnn_activity'].append(rnn_trajectories[:, perm][:, :n_save])
                    data['weights'].append(tnp(net.W.clone()))
                    data['condition'].append(condition[perm][:n_save])
                    data['hand_movement'].append(tnp(cut_xs2[1])[:, perm][:, :n_save])

        # ==== Optimizer step after `gradient_accumulation` backward passes ====
        if iteration % parameters['gradient_accumulation'] == parameters['gradient_accumulation'] - 1 and optimize:
            # In perturbation mode no weight updates happen before the perturbation onset.
            if parameters_perturb['perturbation_iteration'] <= iteration or not parameters_perturb['perturb']:
                optim.step()
                Ws.append(net.W.numpy(force=True))
            optim.zero_grad()

    if parameters_perturb['perturb']:
        for key in data.keys():
            data[key] = np.stack(data[key])
        data['time'] = ts
        data['epochs'] = {'perturbation': parameters_perturb['perturbation_iteration'],
                          'washout': parameters_perturb['washout_iteration']}
        data['additional_information'] = parameters
        data['additional_information_perturb'] = parameters_perturb
        with open(directory+'/data.pkl', 'wb') as f:
            pickle.dump(data, f)
    else:
        with open(directory+'/W.pkl', 'wb') as f:
            pickle.dump(np.stack(Ws), f)

if __name__=='__main__':

    # Every name bound here is a module-level global read by train() / input_plotting().
    device = ('cuda' if torch.cuda.is_available() else 'cpu')

    utils.set_font()

    optimize = True

    load_directory = '.'

    directory = library.helper.functions.make_directory('task_runs')
    parameters_perturb = library.helper.functions.load_yaml('.', directory, '/perturbation_parameters.yaml')
    print()
    # In perturbation mode, the base parameters (and, below, the weights) come from a previous run.
    if parameters_perturb['perturb']:
        load_directory = parameters_perturb['load_directory']
    parameters = library.helper.functions.load_yaml(load_directory, directory)
    print()

    # Perturbation settings override any base parameter with the same key.
    if parameters_perturb['perturb']:
        for key in parameters_perturb.keys():
            if key in parameters.keys():
                parameters[key] = parameters_perturb[key]

    np.random.seed(parameters['seed'])
    torch.manual_seed(parameters['seed'])

    hand = PerturbedHand(2, parameters['init_std'], parameters['noise_hand'], rotation=0.0, device=device)
    target = TargetSphere(parameters['drift_target'], noise=parameters['noise_target'], device=device)

    noise_obj = OU(2, parameters['ou_drift'], parameters['ou_noise'], device=device)
    cue_obj = Cue(noise=0.0, device=device)

    in_dims = [2, 2, 2, 1] if parameters['hand_feedback'] else [2, 1]

    net = RNN(parameters['rec_dim'], in_dims,  # init='zero',
              noise=parameters['noise'], time_constant=parameters['time_constant'], device=device)
    with torch.no_grad():
        net.W *= parameters['weights_init_std']

    # Read-out: tanh nonlinearity followed by a bias-free linear map to 2 outputs.
    decoder = nn.Linear(parameters['rec_dim'], 2, bias=False, device=device)
    decoder = nn.Sequential(nn.Tanh(), decoder)

    with torch.no_grad():
        net.W_in[0] *= parameters['in_weights_init_std']
        decoder[1].weight *= parameters['out_weights_init_std']

    # NOTE(review): connectivity passed to SDE; rows/columns presumably follow the system order
    # [net, hand, target, noise_obj, hand_velocity, cue_obj] — confirm against library.sde.systems.SDE.
    graph = [[1, 1 if parameters['hand_feedback'] else 0, 1, 0, 1 if parameters['hand_feedback'] else 0, 1],
             [0, 0, 0, 0, 1, 0],
             [0, 0, 1, 0, 0, 0],
             [0, 0, 0, 1, 0, 0],
             [decoder, 1, 0, 1, 0, 0],
             [0, 0, 0, 0, 0, 1]]

    hand_velocity = HandVelocity(2, parameters['init_std'], parameters['noise_hand'], coef_perturbation=0.0, device=device)
    sde = SDE([net, hand, target, noise_obj, hand_velocity, cue_obj], graph)

    # Evaluation setup: evenly spaced targets on the unit circle and a shared Brownian motion.
    angles = torch.arange(0, np.pi*2, np.pi*2/parameters['number_targets_test'], device=device)
    positions = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
    target_test = DiscreteTargetSequence(positions, 1, parameters['drift_target'], zero_start=False,
                                         noise=parameters['noise_target'], device=device)

    sde_test = SDE([net, hand, target_test, noise_obj, hand_velocity, cue_obj], graph)
    target_test.set_pos(parameters['batch_size_test'])
    t_pos_test = target_test.pos.clone()

    max_duration = parameters['preparatory_duration'] + parameters['random_duration'] + parameters['execution_duration']
    # Use the detected device (previously hard-coded to 'cuda', which broke CPU-only runs).
    bm_test = torchsde.BrownianInterval(t0=0, t1=max_duration, size=(parameters['batch_size_test'], sde_test.dim),
                                        device=device)

    if load_directory != '.':
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        sde.load_state_dict(torch.load(load_directory+'/model.pt', map_location=device))

    if not parameters['optimize_input_map']: net.W_in.requires_grad_(False)
    if not parameters['optimize_decoder']: decoder.requires_grad_(False)

    train()
