from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import sys
sys.path.insert(0, '../graph_methods')
from dataloader import *
from graph_util import num_atom_features, num_bond_features

from collections import OrderedDict
import logging
# Configure the root logger (getLogger() with no name) at INFO level.
# NOTE(review): `logger` is not referenced by the code below; output uses print().
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def rmse(X, Y):
    """Return the root-mean-squared error between arrays ``X`` and ``Y``.

    The intermediate mean squared error is also printed. The MSE is computed
    once and reused, rather than evaluating ``(X - Y) ** 2`` twice.
    """
    mse = np.mean((X - Y) ** 2)
    print('mse: {}'.format(mse))
    return np.sqrt(mse)


def tensor_to_variable(x):
    """Cast ``x`` to float32 and wrap it, moving it to the GPU first when CUDA is available."""
    source = x.cuda() if torch.cuda.is_available() else x
    as_float = source.float()
    return Variable(as_float)


def variable_to_numpy(x):
    """Return the data of ``x`` as a numpy array, copying it back to the host when CUDA is available."""
    on_host = x.cpu() if torch.cuda.is_available() else x
    return on_host.data.numpy()


class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one."""

    def forward(self, x):
        batch_size = x.size(0)
        flat = x.contiguous().view(batch_size, -1)
        return flat


class GraphModel(nn.Module):
    """Fully connected regressor over random-projected n-gram graph features.

    A stack of Linear layers maps the last input dimension
    (``random_projection_dimension * segmentation_num``) down to 30 features.
    The result is flattened per sample and a final Linear layer produces one
    scalar per sample.
    """

    def __init__(self, n_gram_num, random_projection_dimension, segmentation_num):
        super(GraphModel, self).__init__()
        self.n_gram_num = n_gram_num
        self.random_projection_dimension = random_projection_dimension
        self.segmentation_num = segmentation_num

        # NOTE(review): only the first Linear is followed by a nonlinearity; the
        # Linear(2048->1024->512->256->30) chain composes to a single affine map.
        # Left unchanged so previously saved models remain compatible.
        self.fc_layer = nn.Sequential(
            nn.Linear(self.random_projection_dimension * self.segmentation_num, 2048),
            nn.Sigmoid(),
            nn.Linear(2048, 1024),
            nn.Linear(1024, 512),
            nn.Linear(512, 256),
            nn.Linear(256, 30),
            Flatten(),
            # Presumably the input is (batch, n_gram_num, rp_dim * segmentation_num),
            # so flattening yields n_gram_num * 30 features — TODO confirm.
            nn.Linear(self.n_gram_num * 30, 1),
        )

    def forward(self, random_projected_matrix):
        """Return predictions of shape (batch, 1) for ``random_projected_matrix``."""
        x = self.fc_layer(random_projected_matrix)
        return x

    def loss_(self, y_predicted, y_actual, size_average=True):
        """Mean-squared-error loss between predictions and targets.

        Args:
            y_predicted: model output.
            y_actual: regression targets. NOTE(review): should match
                ``y_predicted``'s shape; a 1-D target against (batch, 1)
                output would be broadcast by MSELoss — confirm label shape.
            size_average: truthy (or None) averages over elements; otherwise
                the squared errors are summed.

        Returns:
            A 0-dim loss tensor.
        """
        # nn.MSELoss's `size_average` argument is deprecated; translate it to
        # `reduction` with the same legacy semantics (None behaves like True).
        reduction = 'mean' if (size_average is None or size_average) else 'sum'
        criterion = nn.MSELoss(reduction=reduction)
        loss = criterion(y_predicted, y_actual)
        return loss


def visualize(model):
    """Print each state_dict entry's shape (sorted by key), then every parameter's grad flag and values."""
    state = model.state_dict()
    for key in sorted(state):
        print(key, state[key].shape)
    for param_name, tensor in model.named_parameters():
        print(param_name, '\t', tensor.requires_grad, '\t', tensor.data)


def train(data_loader):
    """Run one training epoch and return the summed loss divided by the dataset size.

    Reads the module-level globals ``graph_model`` and ``optimizer`` (created
    in the ``__main__`` section).
    """
    graph_model.train()
    total_loss = 0.0
    for random_projected_matrix, y_label in data_loader:
        random_projected_matrix = tensor_to_variable(random_projected_matrix)
        y_label = tensor_to_variable(y_label)
        y_pred = graph_model(random_projected_matrix=random_projected_matrix)
        # Summed (not averaged) batch loss, so that dividing by the dataset size
        # below is correct even when the last batch is smaller.
        loss = graph_model.loss_(y_predicted=y_pred, y_actual=y_label, size_average=False)
        # `.item()` replaces `loss.data[0]`, which raises on 0-dim tensors in PyTorch >= 0.5.
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    total_loss /= len(data_loader.dataset)
    return total_loss


def make_predictions(data_loader):
    """Run the module-level ``graph_model`` over ``data_loader``.

    Returns:
        ``(None, None)`` if ``data_loader`` is None; otherwise a pair of numpy
        arrays ``(labels, predictions)`` with one row per sample, in loader order.
    """
    if data_loader is None:
        return None, None
    y_label_list = []
    y_pred_list = []
    # Inference only: don't build an autograd graph for every batch.
    with torch.no_grad():
        for random_projected_matrix, y_label in data_loader:
            random_projected_matrix = tensor_to_variable(random_projected_matrix)
            y_label = tensor_to_variable(y_label)
            y_pred = graph_model(random_projected_matrix=random_projected_matrix)
            y_label_list.extend(variable_to_numpy(y_label))
            y_pred_list.extend(variable_to_numpy(y_pred))
    return np.array(y_label_list), np.array(y_pred_list)


def test(train_dataloader=None, test_dataloader=None):
    """Put ``graph_model`` in eval mode and print RMSE for each loader given.

    Either loader may be None, in which case that split is skipped.
    """
    graph_model.eval()
    if train_dataloader is not None:
        y_train, y_pred_on_train = make_predictions(train_dataloader)
        rmse_train = rmse(y_pred_on_train, y_train)
        print('RMSE on train set: {}'.format(rmse_train))
    if test_dataloader is not None:
        y_test, y_pred_on_test = make_predictions(test_dataloader)
        rmse_test = rmse(y_pred_on_test, y_test)
        print('RMSE on test set: {}'.format(rmse_test))
    return


def save_model(weight_path):
    """Serialize the entire module-level ``graph_model`` object (not only its state_dict) to ``weight_path``."""
    print('Saving weight path:\t', weight_path)
    with open(weight_path, mode='wb') as checkpoint_file:
        torch.save(obj=graph_model, f=checkpoint_file)


def load_best_model(weight_path):
    """Load and return a whole model object previously written by ``save_model``.

    When CUDA is unavailable, tensors are mapped to the CPU so that a checkpoint
    saved on a GPU machine can still be loaded.

    SECURITY: this unpickles arbitrary Python objects (``weights_only=False``);
    only load checkpoint files from trusted sources.
    """
    map_location = None if torch.cuda.is_available() else 'cpu'
    with open(weight_path, 'rb') as f_:
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which cannot restore
            # the full nn.Module pickled by save_model().
            model = torch.load(f_, map_location=map_location, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not know `weights_only`; retry without it.
            f_.seek(0)
            model = torch.load(f_, map_location=map_location)
    return model


if __name__ == '__main__':
    import time
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', action='store', dest='epoch',
                        type=int, required=False, default=300)
    parser.add_argument('--batch_size', action='store', dest='batch_size',
                        type=int, required=False, default=128)
    parser.add_argument('--learning_rate', action='store', dest='learning_rate',
                        type=float, required=False, default=1e-3)
    parser.add_argument('--min_learning_rate', action='store', dest='min_learning_rate',
                        type=float, required=False, default=1e-5)
    parser.add_argument('--seed', action='store', dest='seed',
                        type=int, required=False, default=123)
    # NOTE: --target_name is still accepted for CLI compatibility but is not used below.
    parser.add_argument('--target_name', action='store', dest='target_name',
                        type=str, required=False, default='NR-AhR')
    parser.add_argument('--n_gram_num', dest='n_gram_num', type=int,
                        action='store', required=False, default=6)
    parser.add_argument('--random_projection_dimension', dest='random_projection_dimension', type=int,
                        action='store', required=False, default=50)
    parser.add_argument('--segmentation_num', dest='segmentation_num', type=int,
                        action='store', required=False, default=8)
    parser.add_argument('--is_sum_up_feature_segment', action='store_true', dest='is_sum_up_feature_segment')
    parser.set_defaults(is_sum_up_feature_segment=False)
    given_args = parser.parse_args()

    n_gram_num = given_args.n_gram_num
    random_projection_dimension = given_args.random_projection_dimension
    segmentation_num = given_args.segmentation_num
    is_sum_up_feature_segment = given_args.is_sum_up_feature_segment
    if is_sum_up_feature_segment:
        # GraphModel's first Linear expects random_projection_dimension *
        # segmentation_num inputs; presumably the summed-up dataset yields a
        # single segment, hence 1 — confirm against the dataset class.
        segmentation_num = 1

    # K data files: the first K - 1 are used for training, the last is held out for testing.
    K = 5
    file_template = '../datasets/malaria/{}_grammed_random_{}_graph.npz'
    file_list = [file_template.format(i, random_projection_dimension) for i in range(K)]

    EPOCHS = given_args.epoch
    BATCH = given_args.batch_size
    torch.manual_seed(given_args.seed)

    # `graph_model` and `optimizer` must stay module-level names: train(),
    # make_predictions(), test() and save_model() read them as globals.
    graph_model = GraphModel(n_gram_num=n_gram_num,
                             random_projection_dimension=random_projection_dimension,
                             segmentation_num=segmentation_num)
    if torch.cuda.is_available():
        graph_model.cuda()
    print(graph_model)

    train_graph_matrix_file = file_list[:K - 1]
    # NOTE(review): a single path (str) here vs. a list for training — confirm
    # GraphDataset_N_Gram_Random_Projection accepts both forms.
    test_graph_matrix_file = file_list[K - 1]

    train_dataset = GraphDataset_N_Gram_Random_Projection(train_graph_matrix_file,
                                                          n_gram_num=n_gram_num,
                                                          is_sum_up_feature_segment=is_sum_up_feature_segment)
    test_dataset = GraphDataset_N_Gram_Random_Projection(test_graph_matrix_file,
                                                         n_gram_num=n_gram_num,
                                                         is_sum_up_feature_segment=is_sum_up_feature_segment)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH, shuffle=False)

    optimizer = optim.Adam(graph_model.parameters(), lr=given_args.learning_rate, weight_decay=5e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=10,
                                                     min_lr=given_args.min_learning_rate, verbose=True)

    for epoch in range(EPOCHS):
        print('Epoch: {}'.format(epoch))

        train_start_time = time.time()
        train_loss = train(train_dataloader)
        # The plateau scheduler monitors the training loss (there is no validation split).
        scheduler.step(train_loss)
        train_end_time = time.time()

        if epoch % 10 == 0:
            print('Train time: {:.3f}s. Train loss is {}.'.format(train_end_time - train_start_time, train_loss))
            test_start_time = time.time()
            test(train_dataloader=train_dataloader, test_dataloader=test_dataloader)
            test_end_time = time.time()
            print('Test time: {:.3f}s.'.format(test_end_time - test_start_time))
            print()

    test_start_time = time.time()
    test(train_dataloader=train_dataloader, test_dataloader=test_dataloader)
    test_end_time = time.time()
    print('Test time: {:.3f}s.'.format(test_end_time - test_start_time))
    print()