from datetime import datetime
import pdb
import copy

import numpy as np

import matplotlib.pyplot as plt
import torch
import torch.optim as optim

# Net_2out is a network in network.py that outputs 2 values: x and y
from network import Net_2out
# loadDigits is a function in loadData.py that loads data from ScaledDigits.mat
from loadData import loadDigits
from pack import CANN


class train_CANN_loader():
    """Training-data generator pairing CANN input sequences with digit targets.

    For every sample, ``CANN(alpha=0.04)`` yields an input sequence and an
    array ``dis``. Each time step's ``dis`` value is rescaled to a position in
    [0, 1000] (``dis.max()`` maps to 0, ``dis.min()`` maps to 1000), and the
    (x, y) target is linearly interpolated from the digit trajectory
    ``loadDigits()[label]`` at that position. Positions at or beyond row 999
    are clamped to row 999.
    """

    # load data for training
    def __init__(self, batch_size=10):
        # NOTE(review): digitDict is indexed as digitDict[label][row, col];
        # presumably one trajectory with >= 1000 rows per digit -- confirm.
        self.digitDict = loadDigits()
        self.batch_size = batch_size
        # Initialised here so __len__ works before digit1_gendata() is called
        # (previously it raised AttributeError).
        self.labelData = np.zeros(0, dtype=int)
        self.inputData = dict()
        self.targetData = dict()

    def __len__(self):
        # Number of samples produced by the most recent digit1_gendata() call.
        return len(self.labelData)

    @staticmethod
    def _positions(dis):
        """Map every entry of ``dis`` to a trajectory position in [0, 1000].

        The largest ``dis`` maps to 0 and the smallest to 1000. A constant
        ``dis`` (zero range) maps every step to position 0 instead of
        dividing by zero.
        """
        dis = np.asarray(dis)
        dis_max = dis.max()
        dis_min = dis.min()
        span = dis_max - dis_min
        if span == 0:
            return np.zeros(len(dis))
        return (dis_max - dis) / span * 1000

    @staticmethod
    def _interp_point(trajectory, position):
        """Return the (x, y) of ``trajectory`` linearly interpolated at ``position``.

        Rows ``floor(position)`` and ``floor(position) + 1`` are blended; any
        position whose floor is >= 999 is clamped to row 999.
        """
        idx = int(np.floor(position))
        if idx >= 999:
            return trajectory[999, 0], trajectory[999, 1]
        w_lo = idx + 1 - position
        w_hi = position - idx
        return (trajectory[idx, 0] * w_lo + trajectory[idx + 1, 0] * w_hi,
                trajectory[idx, 1] * w_lo + trajectory[idx + 1, 1] * w_hi)

    def digit1_gendata(self, digit_label=5):
        """Generate ``batch_size`` samples that all target the same digit.

        Args:
            digit_label: one digit in 0-9, used as the key into digitDict.

        Returns:
            (inputData, labelData, targetData):
            - inputData: dict mapping sample index to the CANN input sequence.
            - labelData: int array of length batch_size filled with digit_label.
            - targetData: dict mapping sample index to a (len(input), 2) array
              of interpolated (x, y) targets.
        """
        self.labelData = (np.ones(self.batch_size)*digit_label).astype(int)

        # generate input and target data
        self.inputData = dict()
        self.targetData = dict()
        for i in range(self.batch_size):
            cann_input, dis = CANN(alpha=0.04)
            n_steps = len(cann_input)
            trajectory = self.digitDict[self.labelData[i]]
            positions = self._positions(dis)
            target = np.zeros([n_steps, 2])
            for t_step in range(n_steps):
                target[t_step] = self._interp_point(trajectory, positions[t_step])
            self.inputData[i] = cann_input
            self.targetData[i] = target

        return self.inputData, self.labelData, self.targetData
    

def test_CANNloader(num_CANN=512, digit_label=999):
    """Build a test set from one real CANN run paired with digit trajectories.

    The bump of a CANN does not move uniformly, so the targets are not
    uniformly spaced either: each time step's ``dis`` value is mapped to a
    position in [0, 1000] along the digit trajectory (largest ``dis`` -> 0,
    smallest -> 1000), and (x, y) is linearly interpolated there. Positions
    >= 999 are clamped to row 999.

    Args:
        num_CANN: number of neurons in CANN. Not used here, because CANN() is
            only given ``alpha``; kept for interface symmetry with testloader.
        digit_label: the label of a digit 0-9, or 999 for all ten digits.

    Returns:
        (inputData, labelData, targetData):
        - inputData: the input returned by ``CANN(alpha=0.04)``.
        - labelData: int array of labels.
        - targetData: (len(labelData), len(dis), 2) array of (x, y) targets.
    """
    digitDict = loadDigits()
    if digit_label != 999:
        # Cast to int like testloader/train_CANN_loader do; a float label only
        # works if digitDict happens to accept float keys.
        labelData = (np.ones(1)*digit_label).astype(int)
    else:
        # generate label data, 0-9
        labelData = np.arange(0, 10)

    batch_size = len(labelData)

    # generate input data
    inputData, dis = CANN(alpha=0.04)

    # Positions depend only on dis, so compute them once for all digits.
    dis_max = dis.max()
    dis_min = dis.min()
    span = dis_max - dis_min
    if span == 0:
        # Constant dis: avoid division by zero and pin every step to the start.
        positions = np.zeros(len(dis))
    else:
        positions = (dis_max - dis) / span * 1000

    targetData = np.zeros([batch_size, len(dis), 2])
    for i in range(batch_size):
        trajectory = digitDict[labelData[i]]
        for j in range(len(dis)):
            digit_index = positions[j]
            digit_index_int = np.floor(digit_index).astype(int)
            if digit_index >= 999:
                targetData[i, j, 0] = trajectory[999, 0]
                targetData[i, j, 1] = trajectory[999, 1]
            else:
                w_lo = digit_index_int + 1 - digit_index
                w_hi = digit_index - digit_index_int
                targetData[i, j, 0] = trajectory[digit_index_int, 0]*w_lo + \
                    trajectory[digit_index_int+1, 0]*w_hi
                targetData[i, j, 1] = trajectory[digit_index_int, 1]*w_lo + \
                    trajectory[digit_index_int+1, 1]*w_hi
    return inputData, labelData, targetData


def testloader(num_CANN=512, num_Step=1000, digit_label=999):
    """Generate a synthetic moving Gaussian bump and digit-trajectory targets.

    The input is a Gaussian bump over ``num_CANN`` units, placed on
    ``x = linspace(-pi, pi, num_CANN)``. Over ``num_Step`` steps it moves
    linearly from ``initial_z`` to ``target_z = 0``. The start lies
    ``int(num_CANN/5)`` unit spacings of ``2*pi/num_CANN`` away from the
    target.

    Args:
        num_CANN: number of neurons in CANN.
        num_Step: number of time steps. If it differs from the length of the
            digit trajectory, the trajectory is linearly resampled to
            ``num_Step`` points (previously this raised a broadcast error).
        digit_label: the label of a digit 0-9, or 999 for all ten digits.

    Returns:
        (inputData, labelData, targetData):
        - inputData: (num_Step, num_CANN) array.
        - labelData: int array of labels.
        - targetData: (len(labelData), num_Step, 2) array of (x, y) targets.
    """
    digitDict = loadDigits()
    if digit_label != 999:
        labelData = (np.ones(1)*digit_label).astype(int)
    else:
        # generate label data, 0-9
        labelData = np.arange(0, 10)

    batch_size = len(labelData)

    # generate input data
    inputData = np.zeros([num_Step, num_CANN])
    # set initial and target position of the bump
    target_z = 0
    initial_z = target_z+int(num_CANN/5)*np.pi*2/num_CANN
    dz_per_timeStep = (target_z-initial_z)/num_Step
    a = 0.5  # standard deviation of the Gaussian bump
    x = np.linspace(-np.pi, np.pi, num_CANN)
    # move z each time step; (i+1) makes the last step land exactly on target_z
    for i in range(num_Step):
        z = initial_z+(i+1)*dz_per_timeStep
        inputData[i, :] = np.exp(-(x-z)**2/(2*a**2))*0.025

    # generate target data
    targetData = np.zeros([batch_size, num_Step, 2])
    for i in range(batch_size):
        trajectory = np.asarray(digitDict[labelData[i]])
        if len(trajectory) == num_Step:
            targetData[i, :, 0] = trajectory[:, 0]
            targetData[i, :, 1] = trajectory[:, 1]
        else:
            # Resample the trajectory uniformly onto num_Step points.
            src = np.arange(len(trajectory))
            dst = np.linspace(0, len(trajectory) - 1, num_Step)
            targetData[i, :, 0] = np.interp(dst, src, trajectory[:, 0])
            targetData[i, :, 1] = np.interp(dst, src, trajectory[:, 1])
    return inputData, labelData, targetData


def _plot_comparison(model, test_loader, test_CANN_loader, criterion, digit_dic, label):
    """Draw the model's output on both test sets next to the reference digit."""
    # subplot 1: output for the synthetic moving-bump input
    test_out = test(test_loader, model, criterion)
    plt.subplot(1, 3, 1)
    plt.plot(test_out[0, :, 0], test_out[0, :, 1])
    plt.title("label:{}".format(label))
    # subplot 2: output for the real CANN input
    test_out = test(test_CANN_loader, model, criterion)
    plt.subplot(1, 3, 2)
    plt.plot(test_out[0, :, 0], test_out[0, :, 1])
    plt.title("label:{}".format(label))
    # subplot 3: the reference digit trajectory
    plt.subplot(1, 3, 3)
    plt.plot(digit_dic[label][:, 0], digit_dic[label][:, 1])
    plt.title("label:{}".format(label))


def train_model(model, train_loader, test_CANN_loader, test_loader, batch_size=10,
                num_epochs=400, learning_rate=0.001, digit_label=999):
    """Train ``model`` with Adam/MSE, periodically plotting and checkpointing.

    Args:
        model: the network.
        train_loader: data loader for training; must provide
            ``digit1_gendata(digit_label=...)`` (or ``gendata(epoch)`` when
            digit_label == 999).
        test_CANN_loader: (input, labels, targets) tuple for testing on CANN input.
        test_loader: (input, labels, targets) tuple for testing on the
            synthetic digital-trajectory input.
        batch_size: number of training steps per epoch. Samples from the
            loader are cycled through; each step gets fresh N(0, 0.01) noise.
        num_epochs: number of epochs.
        learning_rate: Adam learning rate.
        digit_label: the label of a digit 0-9, or 999 for all digits.

    Returns:
        The model copy with the lowest epoch loss seen (also saved every 200
        epochs).

    Raises:
        ValueError: if the loader generated no samples.
    """
    loss_total = np.zeros(num_epochs)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()
    digit_dic = loadDigits()
    # Track the best epoch loss so far. Previously min_loss was never updated,
    # so best_model was simply the latest model, and a run whose loss never
    # fell below 100 crashed with NameError when saving.
    min_loss = float('inf')
    best_model = copy.deepcopy(model)
    for epoch in range(num_epochs):
        if digit_label != 999:
            inputs, label_data, target_data = train_loader.digit1_gendata(
                digit_label=digit_label)
        else:
            # NOTE(review): train_CANN_loader in this file defines no gendata();
            # this branch needs a loader that provides gendata(epoch).
            inputs, label_data, target_data = train_loader.gendata(epoch)

        n_samples = len(label_data)
        if n_samples == 0:
            raise ValueError("training loader generated no samples")

        # train and print loss
        loss = float('nan')  # stays NaN if batch_size == 0
        for batch_idx in range(batch_size):
            # Cycle through the generated samples. With a single-sample loader
            # this reuses sample 0 every step, as before.
            sample = batch_idx % n_samples
            input_noise = inputs[sample] + np.random.normal(0, 0.01, inputs[sample].shape)
            loss = train(input_noise, label_data[sample], target_data[sample],
                         model, optimizer, criterion)
            if (batch_idx+1) % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx+1), batch_size,
                    100. * (batch_idx+1) / batch_size, loss))
        label = label_data[0]
        if loss < min_loss:
            min_loss = loss
            best_model = copy.deepcopy(model)
        loss_total[epoch] = loss

        # test and print loss; plot the output of CANN and the target trajectory
        if epoch % 20 == 0:
            _plot_comparison(model, test_loader, test_CANN_loader, criterion,
                             digit_dic, label)
            plt.show(block=False)
            plt.pause(1)

        # save the best model
        if (epoch+1) % 200 == 0:
            torch.save(best_model.state_dict(), "1CANNmodel_{}_label_{}.pth".format(epoch+1, label))
            _plot_comparison(best_model, test_loader, test_CANN_loader, criterion,
                             digit_dic, label)
            plt.savefig("CANNmodel_{}_label_{}.png".format(epoch+1, label))

        # plot loss curve
        if (epoch+1) % 400 == 0:
            plt.figure()
            plt.plot(np.arange(num_epochs)+1, loss_total)
            plt.title("Loss")
            plt.xlabel("Epoch")
            plt.savefig("loss_{}_label_{}.pdf".format(epoch+1, label))

    return best_model


def train(input, label, target, model, optimizer, criterion):
    """Run one optimisation step on a single sequence.

    The model is applied to each time step ``input[i]`` separately. The
    per-step criterion values are summed and back-propagated once, followed
    by a single ``optimizer.step()``.

    Args:
        input: numpy array of per-step model inputs, indexed by time on axis 0.
        label: digit label; not used by this function (kept for interface).
        target: numpy array whose row i is the target for time step i.
        model: the network.
        optimizer: the optimizer.
        criterion: the loss function.

    Returns:
        float: the summed loss computed before the parameter update, or 0.0
        for an empty sequence (no step is taken then).
    """
    timestep = input.shape[0]
    model.train()
    if timestep == 0:
        # Nothing to learn from; the old code crashed calling backward() on int 0.
        return 0.0
    # Clear gradients once per step. backward() runs only after the loop, so
    # the old per-time-step zero_grad() was redundant.
    optimizer.zero_grad()
    loss = 0
    for i in range(timestep):
        input_i = torch.from_numpy(input[i]).float()
        target_i = torch.from_numpy(target[i, :]).float()
        loss = loss + criterion(model(input_i), target_i)
    loss.backward()
    optimizer.step()
    return loss.item()


def test(test_loader, model, criterion):
    # test the model
    # test_loader: the data loader for testing
    # model: the network
    # criterion: the loss function
    
    input, label_data, traget_data = test_loader
    output = np.zeros(traget_data.shape)
    timestep = input.shape[0]  # 1000
    model.eval()
    loss = 0
    with torch.no_grad():
        for digit_index in range(len(label_data)):
            label = label_data[digit_index]
            target = traget_data[digit_index, :, :]
            for i in range(timestep):
                input_i = input[i]
                target_i = target[i, :]
                input_i = torch.from_numpy(input_i).float()
                target_i = torch.from_numpy(target_i).float()
                output_i = model(input_i)
                loss += criterion(output_i, target_i)
                output[digit_index, i, :] = output_i
    loss /= (timestep * len(label_data))
    print('Test set: Average loss: {:.4f}'.format(loss))
    return output


def main():
    """Train and evaluate one Net_2out model per digit 0-9."""
    # set parameters
    num_CANN = 512
    batch_size = 10
    for digit_label in range(10):
        model = Net_2out(num_CANN=num_CANN)
        print(model)

        # Load test data for the digit being trained. This was hard-coded to 5,
        # so the reported test loss compared every model against digit 5.
        test_CANN_loader = test_CANNloader(
            num_CANN=num_CANN, digit_label=digit_label)
        test_loader = testloader(
            num_CANN=num_CANN, digit_label=digit_label)

        # A single-sample loader; train_model reuses it batch_size times per
        # epoch, adding fresh noise each time.
        train_loader = train_CANN_loader(batch_size=1)

        train_model(model, train_loader, test_CANN_loader, test_loader,
                    batch_size=batch_size, digit_label=digit_label)


if __name__ == "__main__":
    # Run the per-digit training sweep only when executed as a script.
    main()
