from __future__ import print_function

import sys
sys.path.insert(0, '../src')
sys.path.insert(0, '../graph_methods')

import argparse
import pandas as pd
import csv
import numpy as np
import json
import sys
import random
from util import output_classification_result
from function import *
from graph_kernel import *
import os
import time


def get_data(data_path_list, test_data_path):
    """Load and stack the graph features stored in several data files.

    Each path is passed to ``extract_feature``. Only the adjacency matrices,
    node-attribute matrices and labels are kept. All of them are concatenated
    along axis 0 in the order of ``data_path_list``.

    Args:
        data_path_list: paths to load, in the order they should be stacked.
        test_data_path: the path (which must appear in ``data_path_list``)
            whose samples form the test split.

    Returns:
        A tuple ``(adjacent_matrix, node_labels, labels, split_index)``:
        - ``node_labels`` holds, per graph and node, the string-cast
          attribute values joined into one string label.
        - ``split_index`` is the number of samples loaded before
          ``test_data_path``. If the path occurs more than once, the last
          occurrence wins.

    Raises:
        ValueError: if ``test_data_path`` is not in ``data_path_list``.
            The check runs before anything is loaded.
    """
    # Fail fast: otherwise split_index would never be bound and the caller
    # would get an UnboundLocalError after all files had been loaded.
    if test_data_path not in data_path_list:
        raise ValueError('test_data_path {!r} is not in data_path_list'.format(test_data_path))

    adjacent_matrix_, node_attribute_matrix_, label_ = [], [], []
    count = 0
    split_index = None

    for data_path in data_path_list:
        if data_path == test_data_path:
            print('testing on {}'.format(test_data_path))
            split_index = count
        # extract_feature returns five items; distance and bond-attribute
        # matrices are not needed for the WL kernel.
        adjacent_matrix_list, _, _, node_attribute_matrix_list, label_name = \
            extract_feature(data_path)
        adjacent_matrix_.append(adjacent_matrix_list)
        node_attribute_matrix_.append(node_attribute_matrix_list)
        label_.append(label_name)
        count += len(label_name)
        print(label_name.shape)

    adjacent_matrix_ = np.concatenate(adjacent_matrix_, axis=0)
    node_attribute_matrix_ = np.concatenate(node_attribute_matrix_, axis=0).astype(str)
    label_ = np.concatenate(label_, axis=0)

    # Collapse each node's attribute values into a single string label.
    # Presumably the input is (n_graphs, n_nodes, n_features); confirm
    # against extract_feature.
    neo_node_attribute_matrix_ = np.array(
        [[''.join(node) for node in graph] for graph in node_attribute_matrix_])
    print(neo_node_attribute_matrix_.shape)

    return adjacent_matrix_, neo_node_attribute_matrix_, label_, split_index


if __name__ == '__main__':
    # Number of pre-split folds stored per MUV target.
    K = 5

    parser = argparse.ArgumentParser()
    parser.add_argument('--target_name', action='store', dest='target_name',
                        type=str, required=False, default='MUV-466')
    # Restrict to valid fold indices so a bad value is rejected at parse time
    # instead of failing later with an opaque list.remove() ValueError.
    parser.add_argument('--running_index', dest='running_index', action='store', type=int, default=0,
                        choices=range(K))
    # Only used to name the output directory; the WL kernel is always computed.
    parser.add_argument('--model', action='store', dest='model',
                        type=str, required=False, default='wl_kernel')
    # Passed to weisleifer_lehman_graph_kernel; also part of the output filename.
    parser.add_argument('--h', type=int, default=1)
    given_args = parser.parse_args()
    running_index = given_args.running_index
    target_name = given_args.target_name
    model = given_args.model
    h = given_args.h

    # Place the held-out fold last so every sample before split_index is
    # training data and everything from split_index on is test data.
    data_path_template = '../datasets/muv/{}/{}_graph.npz'
    test_data_path = data_path_template.format(target_name, running_index)
    data_path_list = [data_path_template.format(target_name, i) for i in range(K) if i != running_index]
    data_path_list.append(test_data_path)

    adjacent_matrix_list, node_attribute_matrix_list, label_name, split_index = get_data(data_path_list, test_data_path)
    label_name = reshape_data_into_2_dim(label_name)
    print('adjacent_matrix_list:\t\t', adjacent_matrix_list.shape)
    print('node_attribute_matrix_list:\t', node_attribute_matrix_list.shape)
    print('label_name:\t\t\t', label_name.shape)
    print()

    wl_kernel = weisleifer_lehman_graph_kernel(adjacent_matrix_list, node_attribute_matrix_list, h)
    # Rows select the samples being scored. Columns are always the training
    # samples, as a precomputed-kernel SVC requires.
    kernel_train, kernel_test = wl_kernel[:split_index, :split_index], wl_kernel[split_index:, :split_index]
    y_train, y_test = label_name[:split_index], label_name[split_index:]
    print(wl_kernel.shape)
    print('train kernel size: {}\ttrain label size: {}'.format(kernel_train.shape, y_train.shape))
    print('test kernel size: {}\ttest label size: {}'.format(kernel_test.shape, y_test.shape))
    print()

    start_time = time.time()
    svm_clf = SVC(kernel='precomputed', probability=True)
    # SVC expects a 1-D target. reshape_data_into_2_dim presumably yields an
    # (N, 1) column (confirm), so flatten it for fitting only. The 2-D
    # y_train is still what gets reported below.
    svm_clf.fit(kernel_train, np.ravel(y_train))
    end_time = time.time()
    print('Training time: {}'.format(end_time - start_time))

    # Column 1 of predict_proba is the probability of the second class in
    # svm_clf.classes_ (the positive class for 0/1 labels).
    y_pred_on_train = svm_clf.predict_proba(kernel_train)[:, 1]
    y_pred_on_test = svm_clf.predict_proba(kernel_test)[:, 1]
    end_time = time.time()
    # Measured from start_time, so this includes training and prediction.
    print('Running time: {}'.format(end_time - start_time))

    output_classification_result(y_train=y_train, y_pred_on_train=y_pred_on_train,
                                 y_val=None, y_pred_on_val=None,
                                 y_test=y_test, y_pred_on_test=y_pred_on_test,
                                 EF_ratio_list=[])

    # Save straight to the final location. The previous save-to-CWD-then-
    # rename approach left a stray file whenever the output directory was
    # missing.
    output_dir = '../output/{}/{}'.format(running_index, model)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    np.savez(os.path.join(output_dir, '{}_{}.npz'.format(target_name, h)), y_test=y_test, y_pred=y_pred_on_test)
