from __future__ import print_function

import argparse
import pandas as pd
import csv
import numpy as np
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import math
import sys
sys.path.insert(0, '../graph_methods')
sys.path.insert(0, '../src')
from function import read_merged_data, extract_feature_and_label
from dataloader import *
from util import *


def tensor_to_variable(x):
    """Cast ``x`` to float, move it to the GPU when one is available, and wrap it.

    Returns a ``Variable`` (on PyTorch >= 0.4 this is simply a float tensor).
    """
    on_gpu = torch.cuda.is_available()
    placed = x.cuda() if on_gpu else x
    return Variable(placed.float())


def variable_to_numpy(x):
    """Return the values of tensor/Variable ``x`` as a NumPy array.

    The tensor is first copied back to host memory when CUDA is available,
    and ``.data`` is used so tensors that require grad can be converted.
    """
    host_tensor = x.cpu() if torch.cuda.is_available() else x
    return host_tensor.data.numpy()


class DeepNNModel(nn.Module):
    """Fully connected feed-forward network used for regression.

    Layer widths are ``[input_dimension] + layers + [output_dimension]``.
    Every Linear layer except the last is followed by ReLU and then
    BatchNorm1d; the final Linear layer has no activation.

    conf keys read: 'input_dimension', 'output_dimension', 'layers'
    (any sequence of hidden-layer widths).
    """

    def __init__(self, conf):
        super(DeepNNModel, self).__init__()

        self.input_dimension = conf['input_dimension']
        self.output_dimension = conf['output_dimension']
        # list(...) builds a fresh list (conf is never mutated) and also
        # accepts tuples or other sequences of hidden widths.
        self.layers = [self.input_dimension] + list(conf['layers']) + [self.output_dimension]
        self.layers_num = len(self.layers)

        self.fc_layer = nn.Sequential()
        for layer_idx, (in_features, out_features) in enumerate(zip(self.layers[:-1], self.layers[1:])):
            # Module names are kept verbatim: they form the state_dict keys of
            # any previously saved weights.
            self.fc_layer.add_module('linear {}'.format(layer_idx), nn.Linear(in_features, out_features))
            # There are layers_num - 1 Linear layers; only the last (index
            # layers_num - 2) is left without activation/normalisation.
            if layer_idx < self.layers_num - 2:
                self.fc_layer.add_module('activation {}'.format(layer_idx), nn.ReLU())
                self.fc_layer.add_module('batchnorm {}'.format(layer_idx), nn.BatchNorm1d(out_features))

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the raw outputs."""
        return self.fc_layer(x)

    def loss_(self, y_predicted, y_actual, size_average=True):
        """Mean squared error between predictions and targets.

        Args:
            y_predicted: model outputs.
            y_actual: target values.
            size_average: True averages the squared errors over all elements;
                False sums them.

        Returns:
            A 0-dim loss tensor.
        """
        # The ``size_average`` constructor argument of MSELoss is deprecated;
        # map the legacy flag onto the equivalent ``reduction`` value.
        reduction = 'mean' if size_average else 'sum'
        criterion = nn.MSELoss(reduction=reduction)
        return criterion(y_predicted, y_actual)


class SingleRegression:
    """Train and evaluate a DeepNNModel on a regression target.

    conf keys read here: 'random_seed', 'learning_rate', 'l2_weight_decay',
    'batch_size', 'epoch' (plus the keys DeepNNModel reads).

    kwargs:
        train_dataset: required; wrapped in a shuffling DataLoader.
        test_dataset: optional (may be omitted or None); wrapped in a
            non-shuffling DataLoader when given.
    Both datasets are iterated as ``(features, label)`` pairs.
    """

    def __init__(self, conf, **kwargs):
        self.conf = conf
        self.model = DeepNNModel(self.conf)
        print(self.model)

        torch.manual_seed(conf['random_seed'])
        if torch.cuda.is_available():
            self.model.cuda()
        # Re-initialise weights after seeding so the initialisation is reproducible.
        self.model.apply(self.weights_init)

        self.lr = conf['learning_rate']
        self.weight_decay = conf['l2_weight_decay']
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        self.batch_size = conf['batch_size']
        self.epoch = conf['epoch']

        train_dataset = kwargs['train_dataset']
        test_dataset = kwargs.get('test_dataset')
        self.train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True)
        # Keep test_dataloader as None when there is no test set, so the
        # evaluation code skips it instead of failing while iterating a
        # DataLoader built around None.
        if test_dataset is None:
            self.test_dataloader = None
        else:
            self.test_dataloader = torch.utils.data.DataLoader(
                test_dataset, batch_size=self.batch_size, shuffle=False)

    def weights_init(self, m):
        """Initialise one module in place (used with ``nn.Module.apply``).

        Linear: weights ~ N(0, 0.02), bias 0.
        BatchNorm: weights ~ N(1, 0.02), bias 0.
        Other module types are left untouched.
        """
        classname = m.__class__.__name__
        if 'Linear' in classname:
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.fill_(0)
        elif 'BatchNorm' in classname:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def make_prediction(self, data_loader):
        """Run the model in eval mode over ``data_loader``.

        Returns:
            ``(labels, predictions)`` as NumPy arrays, or ``(None, None)``
            when ``data_loader`` is None.
        """
        self.model.eval()
        if data_loader is None:
            return None, None

        y_label_list = []
        y_pred_list = []
        # Inference only: skip autograd graph construction to save memory.
        with torch.no_grad():
            for X_data, y_label in data_loader:
                X_data = tensor_to_variable(X_data)
                y_label = tensor_to_variable(y_label)
                y_pred = self.model(X_data)

                y_label_list.extend(variable_to_numpy(y_label))
                y_pred_list.extend(variable_to_numpy(y_pred))

        return np.array(y_label_list), np.array(y_pred_list)

    def train(self, data_loader):
        """Run one training epoch over ``data_loader``.

        Returns:
            The summed squared error over the epoch divided by the number of
            samples in the dataset.
        """
        self.model.train()
        total_loss = 0.0
        for X_data, y_label in data_loader:
            X_data = tensor_to_variable(X_data)
            y_label = tensor_to_variable(y_label)
            self.optimizer.zero_grad()

            y_pred = self.model(X_data)
            # Summed (not averaged) per batch, so dividing by the dataset size
            # below is exact even when the last batch is smaller.
            # NOTE(review): if labels are shape (B,) while predictions are
            # (B, 1), MSELoss broadcasts to (B, B) -- confirm the dataset's
            # label shape.
            loss = self.model.loss_(y_predicted=y_pred, y_actual=y_label, size_average=False)
            # .item() replaces the pre-0.4 idiom ``loss.data[0]``, which raises
            # IndexError on 0-dim tensors in current PyTorch.
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()

        total_loss /= len(data_loader.dataset)
        return total_loss

    def _evaluate_and_report(self):
        """Predict on the train (and, if present, test) loaders and report via output_regression_result_no_binary."""
        y_train, y_pred_on_train = self.make_prediction(self.train_dataloader)
        # make_prediction returns (None, None) when test_dataloader is None.
        y_test, y_pred_on_test = self.make_prediction(self.test_dataloader)

        output_regression_result_no_binary(y_train=y_train, y_pred_on_train=y_pred_on_train,
                                           y_val=None, y_pred_on_val=None,
                                           y_test=y_test, y_pred_on_test=y_pred_on_test)

    def train_and_predict(self, weight_file):
        """Train for ``conf['epoch']`` epochs, reporting every 10 epochs and at the end.

        Args:
            weight_file: accepted for interface compatibility; currently unused.
        """
        # TODO(review): weight_file is ignored; call
        # self.save_model(self.model, weight_file) here if the trained weights
        # are meant to be persisted.
        for e in range(self.epoch):
            print('Epoch: {}'.format(e))
            train_loss = self.train(self.train_dataloader)
            print('Train loss: {}'.format(train_loss))

            if e % 10 == 0:
                self._evaluate_and_report()

        self._evaluate_and_report()

    def eval_with_existing(self, weight_file):
        """Report metrics for the model currently held in memory.

        Args:
            weight_file: accepted for interface compatibility; currently unused.
        """
        # TODO(review): despite the name, no weights are loaded from
        # weight_file; call self.load_model(weight_file) first if evaluation of
        # saved weights is intended.
        self._evaluate_and_report()

    def save_model(self, model, weight_file):
        """Save ``model``'s state_dict to ``weight_file`` with ``torch.save``."""
        # nn.Module has no save_weights() (that is a Keras API); persist the
        # state_dict instead.
        torch.save(model.state_dict(), weight_file)

    def load_model(self, weight_file):
        """Load a state_dict written by ``save_model`` into ``self.model`` and return it."""
        # Map GPU-saved tensors onto the CPU when no GPU is present.
        # Security: torch.load unpickles data; only load trusted files.
        map_location = None if torch.cuda.is_available() else 'cpu'
        state_dict = torch.load(weight_file, map_location=map_location)
        self.model.load_state_dict(state_dict)
        return self.model


def demo_single_regression(file_list=None, weight_file=None):
    """Train and evaluate a single-target regressor on the configured label.

    Uses ``file_list[0]`` as the test set and ``file_list[1:5]`` as the
    training set.

    Args:
        file_list: sequence of data file paths. Defaults to the module-level
            ``file_list`` that the ``__main__`` block sets up.
        weight_file: forwarded to ``SingleRegression.train_and_predict``.
            Defaults to the module-level ``weight_file`` set by ``__main__``.

    Raises:
        ValueError: if no file list is passed and no module-level one exists.
    """
    # Fall back to the globals defined by the __main__ block so the original
    # zero-argument call keeps working.
    module_globals = globals()
    if file_list is None:
        file_list = module_globals.get('file_list')
    if weight_file is None:
        weight_file = module_globals.get('weight_file')
    if file_list is None:
        raise ValueError('file_list must be passed or defined at module level')

    conf = {
        'input_dimension': 512,
        'output_dimension': 1,
        'layers': [2000],
        'batch_size': 100,
        'epoch': 100,
        'learning_rate': 0.001,
        'l2_weight_decay': 0.0001,

        'random_seed': 1337,
        'label_name_list': ['delaney']
    }
    label_name_list = conf['label_name_list']
    print('label_name_list ', label_name_list)

    test_index = slice(0, 1)
    train_index = slice(1, 5)
    train_file_list = file_list[train_index]
    test_file_list = file_list[test_index]
    print('train files ', train_file_list)
    print('test files ', test_file_list)

    # Use the label list from conf (previously hard-coded separately) so the
    # two cannot drift apart.
    train_dataset = FingerprintsDataSet(train_file_list,
                                        feature_name='Fingerprints',
                                        label_name_list=label_name_list)
    test_dataset = FingerprintsDataSet(test_file_list,
                                       feature_name='Fingerprints',
                                       label_name_list=label_name_list)
    print('done data preparation')

    print(conf['label_name_list'])
    kwargs = {'train_dataset': train_dataset, 'test_dataset': test_dataset}
    task = SingleRegression(conf=conf, **kwargs)
    task.train_and_predict(weight_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight_file', action='store', dest='weight_file', required=True)
    parser.add_argument('--mode', action='store', dest='mode', required=False, default='single_regression')
    given_args = parser.parse_args()
    # weight_file and file_list are module-level globals that
    # demo_single_regression() reads when called without arguments.
    weight_file = given_args.weight_file
    mode = given_args.mode

    if mode == 'single_regression':
        # K data files: ../datasets/delaney/0.csv.gz ... {K-1}.csv.gz
        K = 5
        directory = '../datasets/delaney/{}.csv.gz'
        file_list = [directory.format(i) for i in range(K)]

        demo_single_regression()
    else:
        # An unknown mode used to do nothing and exit with status 0.
        parser.error('unknown --mode {!r}; supported modes: single_regression'.format(mode))