import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.collections import LineCollection

from loadData import loadDigits
from pack import CANN
from network import Net_2out
from train_singledigit_CANN import train_CANN_loader

import pdb


def test4speed(speed_list, label, output_dict, color4alpha):
    # test the model with different CANN speed
    # plot the trajectory of the output and the target trajectory of the digit for each speed
    # seed_list: a list of CANN speed \alpha*0.04 [0.02,0.04,0.06,0.08,-0.02,-0.04,-0.06,-0.08]
    # label: label
    # model: the model to be tested

    # load the data
    max_len_out = 0
    num_alpha = len(speed_list)
    num_subfig = 5
    num_row = int(np.ceil((num_alpha+1)/num_subfig))
    num_col = num_subfig
    for alpha in speed_list:
        output = output_dict[alpha]
        if len(output) > max_len_out:
            max_len_out = len(output)

    train_loader = train_CANN_loader(batch_size=1)
    _, _, _target = train_loader.digit1_gendata(
                digit_label=label)
    target=_target[0]
    

    plt.figure(figsize=(8, 8))
    for i in range(num_alpha):
        ax = plt.subplot(num_row, num_col, i+1)
        alpha = speed_list[i]
        output = output_dict[alpha]
        # plt.plot(output[:, 0], output[:, 1], 'b')
        
        cmap = plt.cm.rainbow
        colors = cmap(np.linspace(0, 1, max_len_out-1))
        # normalize = colors.Normalize(vmin=0, vmax=1)

        # plt.scatter(output[:, 0], output[:, 1],c=np.arange(len(output))/max_len_out, norm=normalize, cmap='rainbow')
        points = np.array([output[:,0], output[:,1]]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        for i in range(len(segments)):
            plt.plot(segments[i, :, 0], segments[i, :, 1],color=colors[i])

        if label == 2 or label==5 or label==8:
            plt.xlim(-0.3, 0.4)
            plt.ylim(-1, 1)
        elif label == 4:
            plt.xlim(-0.4, 0.25)
            plt.ylim(-0.6, 1)
        else:
            plt.xlim(-0.3, 0.3)
            plt.ylim(-1, 0.8)
        plt.title(r'$\alpha$ = %.2f'%(alpha/0.04))
        plt.xticks([])
        plt.yticks([])
        ax.spines['right'].set_color(color4alpha[alpha])
        ax.spines['top'].set_color(color4alpha[alpha])
        ax.spines['left'].set_color(color4alpha[alpha])
        ax.spines['bottom'].set_color(color4alpha[alpha])

    # plot the target
    ax = plt.subplot(num_row, num_col, num_alpha+1)
    # plt.plot(targt[:, 0], targt[:, 1], 'r')
    colors = cmap(np.linspace(0, 1, len(target)-1))
    points = np.stack((target[:, 0], target[:, 1]), axis=1)
    segments = np.stack((points[:-1],points[1:]), axis=1)
    for i in range(len(target)-1):
        plt.plot(segments[i, :, 0], segments[i, :, 1],color=colors[i])
    # lc = LineCollection(segments, cmap='rainbow', colors=colors)
    # lc.set_linewidth(2)
    # ax.add_collection(lc)
    plt.xticks([])
    plt.yticks([])
    plt.title("Target")
    if label == 2 or label==5 or label==8:
        plt.xlim(-0.3, 0.4)
        plt.ylim(-1, 1)
    elif label == 4:
        plt.xlim(-0.4, 0.25)
        plt.ylim(-0.6, 1)
    else:
        plt.xlim(-0.3, 0.3)
        plt.ylim(-1, 0.8)
    plt.tight_layout()
    # plot a colorbar for the whole figure
    # plt.show()
    fig_name = '0510test4speed_label_%d.pdf'%(label)
    plt.savefig(fig_name, bbox_inches='tight')

    return target, max_len_out

def _alpha_color_index(alpha, trials):
    """Map a CANN speed ``alpha`` (= factor * 0.04) to an index into a ``trials``-long colour table.

    Positive speeds index upward from the middle of the table and negative
    speeds index downward from it (one entry per 0.01 of factor above 0.1);
    values past either end are clamped to the last / first entry.
    """
    if alpha > 0:
        return int(min((alpha / 0.04 - 0.1) / 0.01 + trials / 2, trials) - 1)
    return int(max(trials / 2 - (-alpha / 0.04 - 0.1) / 0.01, 0))


def plot_output(output_dict, target, speed_list, max_len_out, label):
    """Plot decoded x(t) and y(t) for forward (alpha > 0) and reverse (alpha < 0) speeds.

    Layout is a 2x2 grid: the left column holds positive speeds ("Forward
    order"), the right column holds negative speeds ("Reverse order"); the top
    row is the x coordinate (column 0) and the bottom row is y (column 1).
    Time is normalised by ``max_len_out``. The figure is saved to
    ``0510test4order_label_<label>.pdf`` and then closed.

    Args:
        output_dict: maps each speed ``alpha * 0.04`` to an array whose
            columns 0 and 1 are the decoded x/y coordinates.
        target: passed through unchanged.
        speed_list: speeds to plot, e.g. ``[-0.16, ..., -0.01, 0.01, ..., 0.16]``.
        max_len_out: length used to normalise the time axis.
        label: digit label, used in the output file name.

    Returns:
        ``(output_dict, target)`` unchanged.
    """
    # Spectral colour table with one entry per factor in +/-[0.1, 3.5), step 0.01.
    trials = 2 * len(np.arange(0.1, 3.5, 0.01))
    palette = plt.cm.Spectral(np.linspace(0, 1, trials))

    forward = [alpha for alpha in speed_list if alpha > 0]
    reverse = [alpha for alpha in speed_list if alpha < 0]

    fig = plt.figure(figsize=(8, 4))

    def _panel(position, alphas, coord, title=None):
        """Draw coordinate ``coord`` of every speed in ``alphas`` into subplot ``position``."""
        ax = plt.subplot(2, 2, position)
        for alpha in alphas:
            output = output_dict[alpha]
            plt.plot(np.arange(len(output)) / max_len_out, output[:, coord],
                     label=r'$\alpha$ = %.2f' % (alpha / 0.04),
                     color=palette[_alpha_color_index(alpha, trials)])
        plt.xlim(0, 1)
        plt.yticks([])
        if coord == 0:
            # Top row: shares the time axis with the row below, so hide its tick labels.
            plt.ylabel('X')
            ax.xaxis.set_ticklabels([])
            plt.title(title)
            if alphas:
                plt.legend(loc='upper right')
        else:
            plt.xlabel('time')
            plt.ylabel('Y')

    _panel(1, forward, 0, 'Forward order')
    _panel(3, forward, 1)
    _panel(2, reverse, 0, 'Reverse order')
    _panel(4, reverse, 1)

    plt.tight_layout()
    plt.savefig('0510test4order_label_%d.pdf' % label, bbox_inches='tight')
    plt.close(fig)
    return output_dict, target

def test4alpha(speed_list, label, model):
    # test the model for different CANN speed \alpha
    # speed_list: a list of CANN speed \alpha [0.02,0.04,0.06,0.08,-0.02,-0.04,-0.06,-0.08]
    # label: the label of the test data
    # model: the trained model

    # test for different CANN speed \alpha
    output_dict = {}
    color4alpha = {}
    max_len_out = 0
    digitDict = loadDigits()
    target = digitDict[label]
    inputData = dict()

    # set the colormap and norm to correspond to the data for which the colorbar will be used.
    cmap = plt.cm.Spectral
    factor_1 = np.arange(0.1,3.5,0.01)
    factor_2 = -factor_1
    factors = np.concatenate((factor_1,factor_2))
    alphas = factors*0.04
    trials = alphas.shape[0]
    colors = cmap(np.linspace(0, 1, trials))

    # get the averaged begin and end point of CANN bump
    # load the CANN data, record the begin and end point of the trajectory
    cann_begin = np.zeros(512)
    cann_end = np.zeros(512)
    num_alpha = 0
    for alpha in np.around(np.arange(-4,4.1,.05),2):
        if alpha>0.1 or alpha<-0.1:
            cann_alpha = alpha*0.04
            inputData[cann_alpha], _ = CANN(alpha=cann_alpha, fano=0.001)
            if alpha < 0:
                cann_begin += inputData[cann_alpha][-1,:]
                cann_end += inputData[cann_alpha][0,:]
            else:
                cann_begin += inputData[cann_alpha][0,:]
                cann_end += inputData[cann_alpha][-1,:]
            num_alpha += 1

    cann_begin = cann_begin/num_alpha
    cann_end = cann_end/num_alpha
    
    # load the model
    model_path = 'CANNmodel_400_label_%d.pth'%(label)
    model.load_state_dict(torch.load(model_path))

    # run the model for different CANN speed \alpha*0.04
    # record the output_dict
    model.eval()
    with torch.no_grad():
        digit_begin = model(torch.from_numpy(cann_begin).float())
        digit_end = model(torch.from_numpy(cann_end).float())
        for alpha in np.around(np.arange(-4,4.1,.05),2):
            if alpha>0.1 or alpha<-0.1:
                input = inputData[alpha*0.04]
                input = torch.from_numpy(input).float()
                len_input = len(input)
                _output = torch.zeros([len_input,2])
                for t in range(len_input):
                    _output[t, :] = model(input[t])
                    if t > len_input*2/3 and alpha < 0:
                        dis = sum((_output[t,:]-digit_begin)**2)
                        if dis<0.0001:
                            break
                    elif t > len_input*2/3 and alpha > 0:
                        dis = sum((_output[t,:]-digit_end)**2)
                        if dis<0.0001:
                            break
                output_dict[alpha*0.04] = _output[:t,:].numpy()
                       
    # plot alpha_0 vs alpha for each speed and plot the output
    len_alpha_0 = len(output_dict[0.04])

    plt.figure(figsize=(6, 6))
    for alpha in np.arange(-4,4.1,0.2):
        if alpha>0.1 or alpha<-0.1:
            alpha = round(alpha,2)
            # print('alpha:', alpha)
            output = output_dict[alpha*0.04]
            len_alpha = len(output)
            if alpha < 0:
                len_alpha = - len_alpha
            plt.scatter(alpha,len_alpha_0/len_alpha, color='white', edgecolors='black', s=80)
    
    plt.plot(np.arange(-4,4,0.1),np.arange(-4,4,0.1), color='black', linestyle='--')

    
    # print('speed_list:', speed_list)
    for alpha in speed_list:
        output = output_dict[alpha]
        len_alpha = len(output)
        if alpha > 0:
            color_index = (alpha/0.04-0.1)/0.01+trials/2
            color_index = min(color_index,trials)-1
        else:
            color_index = trials/2-(-alpha/0.04-0.1)/0.01
            color_index = max(color_index,0)
            len_alpha = - len_alpha
        plt.scatter(alpha/0.04,len_alpha_0/len_alpha, color=colors[int(color_index)],edgecolors='black', s=80)
        color4alpha[alpha] = colors[int(color_index)]
        plt.gca().set_aspect('equal', adjustable='box')
    plt.xlabel(r'$\alpha_0$')
    plt.ylabel(r'$\alpha$')
    plt.savefig('0510test4alpha_label_%d.pdf'%(label), bbox_inches='tight')
    return output_dict, color4alpha


if __name__ == '__main__':

    # Figure-wide settings; fonttype 42 embeds TrueType fonts in PDF/PS output.
    plt.rcParams.update({
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.size": 12,
        "font.family": "Arial",
    })

    # CANN speeds alpha*0.04; the negative half is plotted as "Reverse order".
    speed_list = [-0.16, -0.08, -0.04, -0.02, -0.01, 0.01, 0.02, 0.04, 0.08, 0.16]
    model = Net_2out()
    for digit in range(10):
        # The stages are chained: each one consumes what the previous returned.
        output_dict, color4alpha = test4alpha(speed_list, digit, model)
        target, max_len_out = test4speed(speed_list, digit, output_dict, color4alpha)
        output_dict, target = plot_output(output_dict, target, speed_list, max_len_out, digit)


