from __future__ import print_function

import argparse
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from CBoW_non_segment_sum import CBoW


def get_data(data_path):
    """Load the graph arrays stored in one ``.npz`` archive.

    The key names of the archive are printed as a quick sanity check.

    Args:
        data_path: path of an ``.npz`` file containing the keys
            ``adjacent_matrix_list``, ``distance_matrix_list``,
            ``bond_attribute_matrix_list``, ``node_attribute_matrix_list``
            and ``label_name``.

    Returns:
        A 5-tuple ``(adjacent_matrix_list, distance_matrix_list,
        bond_attribute_matrix_list, node_attribute_matrix_list, label_name)``
        with the arrays read from the archive.

    Raises:
        KeyError: if one of the expected keys is missing from the archive.
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open;
    # the context manager closes it once every array has been read into memory.
    with np.load(data_path) as data:
        print(data.keys())
        adjacent_matrix_list = data['adjacent_matrix_list']
        distance_matrix_list = data['distance_matrix_list']
        bond_attribute_matrix_list = data['bond_attribute_matrix_list']
        node_attribute_matrix_list = data['node_attribute_matrix_list']
        label_name = data['label_name']

    return adjacent_matrix_list, distance_matrix_list, bond_attribute_matrix_list,\
           node_attribute_matrix_list, label_name


class GraphDataset(Dataset):
    """Dataset wrapper that serves the rows of a NumPy array as torch tensors."""

    def __init__(self, X_data):
        """Keep a reference to ``X_data`` and print its shape."""
        self.X_data = X_data
        print('data size: ', self.X_data.shape)

    def __len__(self):
        """Return the number of samples, i.e. ``len`` of the wrapped array."""
        sample_count = len(self.X_data)
        return sample_count

    def __getitem__(self, idx):
        """Return sample ``idx`` converted with ``torch.from_numpy``."""
        return torch.from_numpy(self.X_data[idx])


def get_path_representation(adjacent_matrix, tilde_node_attribute_matrix, random_dimension,
                            max_atom_num=None):
    """Build the 1-gram to 6-gram path features of one graph.

    An n-gram is a sequence of n node indices ``i < j < k < ...`` in strictly
    increasing order where every consecutive pair has a non-zero entry in
    ``adjacent_matrix``. Its contribution is the element-wise product of the
    attribute rows of its nodes; all contributions of the same length are
    summed into one vector.

    Args:
        adjacent_matrix: matrix indexed as ``adjacent_matrix[a][b]``; a
            non-zero entry means nodes ``a`` and ``b`` are connected.
        tilde_node_attribute_matrix: 2-D array with one attribute row per
            node; each row must broadcast to ``(random_dimension,)``.
        random_dimension: length of every n-gram vector.
        max_atom_num: number of node slots to scan. When omitted, the
            module-level ``max_atom_num`` is used if it exists (this is what
            the script entry point defines), otherwise the number of rows of
            ``tilde_node_attribute_matrix``.

    Returns:
        Array of shape ``(6, random_dimension)``; row ``n - 1`` holds the
        summed n-gram vector.
    """
    gram_count = 6
    if max_atom_num is None:
        # Previously this name was read as a bare global, which raised
        # NameError whenever the function was used outside the script run.
        max_atom_num = globals().get('max_atom_num', len(tilde_node_attribute_matrix))

    # A node whose attribute row sums to zero is never used as a path start.
    node_valid_count = tilde_node_attribute_matrix.sum(axis=1)
    grams = np.zeros((gram_count, random_dimension))

    def extend(last, product, depth):
        # `product` is the element-wise product of the rows along the current
        # path of `depth` nodes that ends at node `last`.
        grams[depth - 1] += product
        if depth == gram_count:
            return
        # Only larger indices are tried, so each path is visited exactly once
        # and in lexicographic order.
        for nxt in range(last + 1, max_atom_num):
            if adjacent_matrix[last][nxt] == 0:
                continue
            extend(nxt, product * tilde_node_attribute_matrix[nxt], depth + 1)

    for start in range(max_atom_num):
        if node_valid_count[start] == 0:
            continue
        extend(start, tilde_node_attribute_matrix[start], 1)

    return grams


if __name__ == '__main__':
    # Script entry point: for every embedding size, load the matching
    # pretrained CBoW weights, embed the node attributes of each of the five
    # data folds, build the n-gram path features per molecule and write one
    # compressed .npz per fold.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='delaney')
    parser.add_argument('--running_index', type=int, default=0)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()
    mode = args.mode
    running_index = args.running_index
    seed = args.seed

    use_cuda = torch.cuda.is_available()

    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)

    fold_num = 5
    random_dimension_list = [50, 100]
    feature_num = 42
    # Keep max_atom_num as a module-level name: get_path_representation looks
    # it up at module scope.
    if mode in ['hiv'] or 'pcba' in mode or 'clintox' in mode:
        max_atom_num = 100
    else:
        max_atom_num = 55

    segmentation_list = [range(0, 10), range(10, 17), range(17, 24), range(24, 30), range(30, 36),
                         range(36, 38), range(38, 40), range(40, 42)]
    # The segments have different lengths; NumPy >= 1.24 refuses to build a
    # ragged array unless dtype=object is requested explicitly.
    segmentation_list = np.array(segmentation_list, dtype=object)
    segmentation_num = len(segmentation_list)

    test_list = [running_index]
    # A list (not a lazy filter object) so the message below shows the indices
    # on Python 3 as well.
    train_list = [fold for fold in range(fold_num) if fold not in test_list]
    print('training list: {}\ttest list: {}'.format(train_list, test_list))

    for random_dimension in random_dimension_list:
        model = CBoW(feature_num=feature_num, embedding_dim=random_dimension,
                     task_num=segmentation_num, task_size_list=segmentation_list)

        weight_file = '{}/{}/{}_CBoW_non_segment.pt'.format(mode, running_index, random_dimension)
        print('weight file is {}'.format(weight_file))
        # Map every storage to CPU first so weights saved on a GPU still load
        # on a CPU-only host; the model is moved to the GPU right below.
        model.load_state_dict(torch.load(weight_file, map_location=lambda storage, loc: storage))
        if use_cuda:
            model.cuda()
        model.eval()

        start_time = time.time()
        for fold_index in range(fold_num):
            # NOTE(review): input is read from '../datasets' while the output
            # below goes to './datasets' -- confirm this is intentional.
            data_path = '../datasets/{}/{}_graph.npz'.format(mode, fold_index)
            adjacent_matrix_list, distance_matrix_list, bond_attribute_matrix_list, \
                node_attribute_matrix_list, label_name = get_data(data_path)
            dataset = GraphDataset(X_data=node_attribute_matrix_list)
            # shuffle=False keeps the embedded rows aligned with the molecules.
            dataloader = DataLoader(dataset, batch_size=128, shuffle=False)

            X_embed = []
            for x_data in dataloader:
                x_data = Variable(x_data).float()
                if use_cuda:
                    x_data = x_data.cuda()

                x_embedded = model.embeddings(x_data)
                if use_cuda:
                    x_embedded = x_embedded.cpu()
                X_embed.extend(x_embedded.data.numpy())

            embedded_node_attribute_matrix_list = np.array(X_embed)
            print('embedded embedded_node_attribute_matrix_list: ', embedded_node_attribute_matrix_list.shape)

            molecule_num = adjacent_matrix_list.shape[0]

            # One stacked n-gram feature array per molecule.
            random_projected_list = np.array([
                get_path_representation(adjacent_matrix_list[index],
                                        embedded_node_attribute_matrix_list[index],
                                        random_dimension)
                for index in range(molecule_num)
            ])
            print('random_projected_list\t', random_projected_list.shape)

            out_file_path = './datasets/{}/{}/{}_grammed_cbow_{}_graph'.format(
                mode, running_index, fold_index, random_dimension)
            np.savez_compressed(out_file_path,
                                adjacent_matrix_list=adjacent_matrix_list,
                                distance_matrix_list=distance_matrix_list,
                                node_attribute_matrix_list=embedded_node_attribute_matrix_list,
                                random_projected_list=random_projected_list,
                                label_name=label_name)

            print()
        end_time = time.time()
        processing_time = end_time - start_time
        print('For random dimension as {}, the processing time is {}.'.format(random_dimension, processing_time))
        print()
        print()
        print()
