from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
from itertools import product, permutations
import pickle
import random

from sklearn.manifold import spectral_embedding
from sklearn import datasets
import graph_rnn.create_graphs as cg
import graph_nets as gn
import networkx as nx
import numpy as np
import tensorflow as tf

from utils import *

# Attribute key that graph_nets' `utils_np` conversion helpers use for
# features stored on networkx nodes/edges/graphs. NOTE(review): edge and graph
# features in `convert_nx_repr` are set with the literal keyword `features`,
# which assumes this key equals "features" — confirm for the graph_nets version.
FEATURES = gn.utils_np.GRAPH_NX_FEATURES_KEY
# Position of the attribute dict inside the `(node, data)` tuples yielded by
# `graph.nodes(data=True)`.
DICT_IND = 1


# Positional encoding features.
def get_angles(pos, i, d_model):
    """Return positional-encoding angles `pos / 10000**(2*(i//2)/d_model)`.

    `pos` and `i` are combined with NumPy broadcasting, so a column of
    positions and a row of dimension indices yield a (positions, dims) grid.
    """
    exponent = (2 * (i // 2)) / np.float32(d_model)
    return pos * (1 / np.power(10000, exponent))


def positional_encoding(position, d_model):
    """Build a sinusoidal positional-encoding table.

    Args:
      position: number of positions (rows) to encode.
      d_model: encoding width (columns).

    Returns:
      float32 array of shape (1, position, d_model). The sines of the
      even-indexed angle columns come first, followed by the cosines of the
      odd-indexed ones (concatenated, not interleaved).
    """
    positions = np.arange(position)[:, np.newaxis]
    dims = np.arange(d_model)[np.newaxis, :]
    angles = get_angles(positions, dims, d_model)
    table = np.concatenate(
        [np.sin(angles[:, ::2]), np.cos(angles[:, 1::2])], axis=-1)
    return np.expand_dims(table, 0).astype(np.float32)


def add_positional_encoding_features(graph, num_components=5):
    """Assign each node a randomly permuted row of a positional-encoding table.

    Builds a (num_nodes, num_components) sinusoidal table, shuffles its rows,
    and stores one row per node under the node's `FEATURES` attribute.
    Modifies `graph` in place.

    Args:
      graph: networkx graph whose nodes receive features.
      num_components: feature width per node.
    """
    num_nodes = nx.number_of_nodes(graph)
    # Drop only the leading batch axis. `np.squeeze` would also collapse a
    # length-1 node or feature axis (single-node graph, or
    # num_components == 1), after which each row would be a scalar.
    pe = positional_encoding(num_nodes, num_components)[0]
    pe = np.random.permutation(pe)  # shuffles rows (axis 0)
    # Iterate the graph's own nodes so labels need not be exactly 0..n-1.
    for node, row in zip(graph.nodes, pe):
        graph.nodes[node][FEATURES] = row


# Laplacian features.
def add_laplacian_features(graph, num_components=5):
    """Set each node's features to its spectral embedding.

    The embedding comes from sklearn's `spectral_embedding` applied to the
    graph's adjacency matrix, cast to float32. Rows are matched to nodes in
    `graph.nodes` iteration order. Modifies `graph` in place.
    """
    embedding = spectral_embedding(
        nx.adjacency_matrix(graph), n_components=num_components)
    embedding = embedding.astype(np.float32)
    for (_, attrs), row in zip(graph.nodes(data=True), embedding):
        attrs[FEATURES] = row


# Zero features.
def add_zero_features(graph, num_components=5):
    """Give every node an all-zero float32 vector of length `num_components`.

    Each node receives its own freshly allocated array. Modifies `graph` in
    place.
    """
    for _, attrs in graph.nodes(data=True):
        attrs[FEATURES] = np.zeros(num_components, dtype=np.float32)


# Gaussian noise features.
def add_gaussian_noise_features(graph, num_components=5, scale=1.0):
    """Give every node i.i.d. zero-mean Gaussian float32 features.

    One `np.random.normal(scale=scale, size=num_components)` draw is made per
    node, in `graph.nodes` order. Modifies `graph` in place.
    """
    for _, attrs in graph.nodes(data=True):
        noise = np.random.normal(scale=scale, size=num_components)
        attrs[FEATURES] = noise.astype(np.float32)


def add_adj_gaussian_features(graph, num_components=50, scale=1.0):
    """Give nodes Gaussian features correlated through the adjacency matrix.

    Draws `H ~ N(0, scale**2)` of shape (n_node, num_components) and computes
    `L @ H`. Here `L` is the Cholesky factor of the adjacency matrix, shifted
    by a multiple of the identity so that it is positive definite. Row i of
    the result (float32) is stored as the features of the i-th node in
    `graph.nodes` order. Modifies `graph` in place; empty graphs are left
    untouched.

    Args:
      graph: networkx graph. Cholesky only reads the lower triangle, so the
        adjacency matrix is assumed symmetric (e.g. an undirected graph
        converted with `to_directed`).
      num_components: feature width per node.
      scale: standard deviation of the Gaussian draws.
    """
    n_node = nx.number_of_nodes(graph)
    if n_node == 0:
        return
    # `to_numpy_array` exists from networkx 2.0 on; `to_numpy_matrix` was
    # removed in networkx 3.0.
    adj = nx.to_numpy_array(graph)
    H = np.random.normal(scale=scale, size=(n_node, num_components))
    # `eigvalsh` returns real eigenvalues for the symmetric (lower-triangle)
    # matrix that `cholesky` sees. `eig` may return complex values and also
    # computed eigenvectors that were never used.
    e_min = np.min(np.linalg.eigvalsh(adj))
    # Shift so the smallest eigenvalue becomes at least 1e-4.
    shift = max(-e_min, 0.0) + 1e-4
    L = np.linalg.cholesky(adj + shift * np.eye(n_node))
    H = L.dot(H).astype(np.float32)
    for (_, attrs), row in zip(graph.nodes(data=True), H):
        attrs[FEATURES] = row


# Convert a networkx graph (e.g. a grid graph, whose nodes are positional
# tuples) into a DiGraph whose nodes are consecutive integers. Edge and global
# features are set to 0, and every node gets a self-loop edge.
def convert_nx_repr(graph, add_node_features_fn):
    """Relabel `graph` with consecutive integer node ids and attach features.

    Nodes are numbered 0, 1, ... in `graph.nodes` iteration order. The result
    is a new `nx.DiGraph` with graph-level `features=0`, a self-loop on every
    node, and one edge per edge of `graph`; every edge has `features=0`.
    Original node and edge attributes are not copied. Finally
    `add_node_features_fn` is called on the new graph.

    Args:
      graph: source networkx graph (only read, not modified).
      add_node_features_fn: callable that adds node features to a graph in
        place.

    Returns:
      The new relabelled `nx.DiGraph`.
    """
    relabelled = nx.DiGraph(features=0)
    id_of = {}
    for new_id, old_node in enumerate(graph.nodes):
        id_of[old_node] = new_id
        relabelled.add_node(new_id)
        relabelled.add_edge(new_id, new_id, features=0)

    for src, dst, _ in graph.edges(data=True):
        relabelled.add_edge(id_of[src], id_of[dst], features=0)

    add_node_features_fn(relabelled)
    return relabelled


def generate_grid_graphs(dims_container):
    """Return one directed grid graph per argument tuple in `dims_container`.

    Each element (e.g. `(rows, cols)`) is unpacked into `nx.grid_2d_graph`.
    """
    return [nx.grid_2d_graph(*dims, create_using=nx.DiGraph)
            for dims in dims_container]


def preprocess_networkx_graphs(graphs, add_node_features_fn):
    """Return `convert_nx_repr(g, add_node_features_fn)` for each graph, in order."""
    return [convert_nx_repr(graph, add_node_features_fn) for graph in graphs]


class GraphDataset(object):
    """Train/test graph lists served as graph_nets feed dicts.

    Training batches are taken sequentially from `train_set`, which is
    reshuffled in place each time the cursor wraps back to the start. Node
    features of every graph in a training batch are regenerated with
    `add_node_features_fn` before the batch is converted.
    """

    def __init__(self, train_set, test_set, add_node_features_fn):
        # NOTE: `train_set` is shuffled in place, so the caller's list changes.
        self.train_set = train_set
        self.test_set = test_set
        self.index = 0  # position of the next training graph to serve
        self.add_node_features_fn = add_node_features_fn

    def get_next_train_batch(self, batch_size, graph_phs):
        """Return `{graph_phs: GraphsTuple}` for the next `batch_size` graphs.

        A batch may wrap past the end of `train_set`; the list is then
        reshuffled before the remaining graphs are taken.
        """
        batch = []
        for _ in range(batch_size):
            if self.index == 0:
                random.shuffle(self.train_set)
            batch.append(self.train_set[self.index])
            self.index = (self.index + 1) % len(self.train_set)
        for g in batch:
            self.add_node_features_fn(g)
        return {graph_phs: gn.utils_np.networkxs_to_graphs_tuple(batch)}

    def get_test_set(self, graph_phs, num_samples=20):
        """Return `{graph_phs: GraphsTuple}` for `num_samples` test graphs.

        Graphs are sampled uniformly *with replacement* (`random.choices`), so
        duplicates are possible. Their node features are not regenerated.
        """
        sample = random.choices(self.test_set, k=num_samples)
        return {graph_phs: gn.utils_np.networkxs_to_graphs_tuple(sample)}


def prod(r):
    """Return every ordered pair `(a, b)` with both elements drawn from `r`.

    Same result and order as `list(itertools.product(r, repeat=2))`. `r` is
    consumed once, so one-shot iterators work.
    """
    values = tuple(r)
    return [(a, b) for a in values for b in values]


def perm(r):
    """Return every ordered pair of elements of `r` taken from distinct positions.

    Same result and order as `list(itertools.permutations(r, 2))`. Pairs are
    distinguished by position, not by value.
    """
    values = tuple(r)
    return [(a, b)
            for i, a in enumerate(values)
            for j, b in enumerate(values)
            if i != j]


def get_grid_dataset_small(add_features_fn, batch_size):
    """Dataset of 2x2 and 3x3 grids, half of `batch_size` (truncated) of each.

    The same list is used as both the training and the test set.
    """
    half = int(batch_size / 2)
    dims = [(2, 2)] * half + [(3, 3)] * half
    graphs = preprocess_networkx_graphs(generate_grid_graphs(dims),
                                        add_features_fn)
    return GraphDataset(graphs, graphs, add_features_fn)


def get_grid_dataset_single(graph_dim_tuple, add_features_fn, batch_size):
    """Dataset of `batch_size` grids that all have shape `graph_dim_tuple`.

    The same list is used as both the training and the test set.
    """
    dims = [graph_dim_tuple] * batch_size
    graphs = preprocess_networkx_graphs(generate_grid_graphs(dims),
                                        add_features_fn)
    return GraphDataset(graphs, graphs, add_features_fn)


def get_grid_dataset_all(add_features_fn):
    """Dataset of every grid with both sides in [10, 20); train set == test set."""
    grids = generate_grid_graphs(prod(range(10, 20)))
    graphs = preprocess_networkx_graphs(grids, add_features_fn)
    return GraphDataset(graphs, graphs, add_features_fn)


def get_grid_dataset_train_even_test_odd(add_features_fn):
    """Train on grids with even side lengths, test on grids with odd ones.

    Training uses every (rows, cols) pair with both sides in {10, 12, ..., 18};
    testing uses every pair with both sides in {11, 13, ..., 19}.
    """
    train_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(prod(range(10, 20, 2))), add_features_fn)
    # `generate_grid_graphs` needs (rows, cols) pairs. Passing the bare range
    # (as before) made `grid_2d_graph(*11)` raise TypeError.
    test_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(prod(range(11, 20, 2))), add_features_fn)
    dataset = GraphDataset(train_graphs, test_graphs, add_features_fn)
    return dataset


def get_grid_dataset_train_odd_test_even(add_features_fn):
    """Train on grids with odd side lengths, test on grids with even ones.

    Test grids have both sides in {10, 12, ..., 18}; train grids have both
    sides in {11, 13, ..., 19}. The test graphs are built first.
    """
    even_sides = range(10, 20, 2)
    odd_sides = range(11, 20, 2)
    test_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(prod(even_sides)), add_features_fn)
    train_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(prod(odd_sides)), add_features_fn)
    return GraphDataset(train_graphs, test_graphs, add_features_fn)


def get_grid_dataset_split(add_features_fn, train_percent=0.8):
    """Randomly split all grids with sides in [10, 20) into train and test.

    After an in-place shuffle, the first `int(train_percent * n)` graphs form
    the training set and the rest form the test set.
    """
    graphs = preprocess_networkx_graphs(
        generate_grid_graphs(prod(range(10, 20))), add_features_fn)
    random.shuffle(graphs)
    cut = int(train_percent * len(graphs))
    return GraphDataset(graphs[:cut], graphs[cut:], add_features_fn)


def get_grid_dataset_all_test_larger(add_features_fn):
    """Train on grids with sides in [10, 20); test on larger ones, sides in [20, 25)."""
    train_dims = prod(range(10, 20))
    test_dims = prod(range(20, 25))
    train_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(train_dims), add_features_fn)
    test_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(test_dims), add_features_fn)
    return GraphDataset(train_graphs, test_graphs, add_features_fn)


def get_grid_dataset_all_test_smaller(add_features_fn):
    """Train on grids with sides in [10, 20); test on smaller ones, sides in [5, 10)."""
    train_dims = prod(range(10, 20))
    test_dims = prod(range(5, 10))
    train_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(train_dims), add_features_fn)
    test_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(test_dims), add_features_fn)
    return GraphDataset(train_graphs, test_graphs, add_features_fn)


def get_grid_dataset_all_test_square(add_features_fn):
    """Train on non-square grids and test on square grids, sides in [10, 20).

    Training uses every (rows, cols) with rows != cols; testing uses every
    (side, side).
    """
    sides = range(10, 20)
    train_graphs = preprocess_networkx_graphs(
        generate_grid_graphs(perm(sides)), add_features_fn)
    test_graphs = preprocess_networkx_graphs(
        generate_grid_graphs([(side, side) for side in sides]),
        add_features_fn)
    return GraphDataset(train_graphs, test_graphs, add_features_fn)


# Grid graphs.
def get_grid_large():
    """Return `nx.grid_2d_graph(rows, cols)` for all rows, cols in [10, 20).

    Rows vary slowest: (10, 10), (10, 11), ..., (19, 19).
    """
    return [nx.grid_2d_graph(rows, cols)
            for rows in range(10, 20)
            for cols in range(10, 20)]


def get_grid_small():
    """Return `nx.grid_2d_graph(rows, cols)` for rows in [2, 5), cols in [2, 6).

    Rows vary slowest: (2, 2), (2, 3), ..., (4, 5).
    """
    return [nx.grid_2d_graph(rows, cols)
            for rows in range(2, 5)
            for cols in range(2, 6)]


# Community graphs.


def community(c=2, k=20, p_path=0.1, p_edge=0.3):
    """Generate a random community graph derived from a caveman graph.

    Starts from `nx.caveman_graph(c, k)`, drops each edge whose endpoints lie
    in the same block of k ids with probability `1 - p_edge`, and then adds
    `max(ceil(p_path * k), 1)` random links between node ids [0, k) and
    [k, 2k). Only the largest connected component is kept.

    Args:
      c: number of communities (assumed >= 2; links are only added between
        the first two).
      k: nodes per community.
      p_path: scales the number of inter-community links added.
      p_edge: probability of keeping each intra-community edge.

    Returns:
      The largest connected component, converted with `to_directed()`.
    """
    path_count = max(int(np.ceil(p_path * k)), 1)
    G = nx.caveman_graph(c, k)
    # Drop each intra-community edge with probability 1 - p_edge. One
    # `np.random.rand()` is drawn per edge (it is evaluated before the
    # community test), so the RNG stream does not depend on the edge set.
    p_drop = 1 - p_edge
    for (u, v) in list(G.edges()):
        if np.random.rand() < p_drop and ((u < k and v < k) or
                                           (u >= k and v >= k)):
            G.remove_edge(u, v)
    # Add path_count random links between community 0 and community 1.
    for _ in range(path_count):
        u = np.random.randint(0, k)
        v = np.random.randint(k, k * 2)
        G.add_edge(u, v)
    # `nx.connected_component_subgraphs` was removed in networkx 2.4. This is
    # the equivalent (first largest component, copied) and works on 2.x/3.x.
    largest = max(nx.connected_components(G), key=len)
    G = G.subgraph(largest).copy()
    return G.to_directed()


def get_community_large():
    """Return 510 two-community graphs: 10 for each community size in [30, 80].

    All are generated with `p_edge=0.3`, ordered by size, then sample.
    """
    return [community(2, size, p_edge=0.3)
            for size in range(30, 81)
            for _ in range(10)]


def get_community_small():
    """Return 100 two-community graphs: 20 for each community size in [6, 10].

    All are generated with `p_edge=0.8`, ordered by size, then sample.
    """
    return [community(2, size, p_edge=0.8)
            for size in range(6, 11)
            for _ in range(20)]


def get_large_community_dataset_split(add_node_features_fn, train_percent=0.8):
    """Randomly split the `get_community_large()` graphs into train and test.

    Args:
      add_node_features_fn: callable adding node features to a graph in place.
      train_percent: fraction of the shuffled graphs used for training.

    Returns:
      A `GraphDataset` whose training set is the first
      `int(train_percent * n)` shuffled graphs and whose test set is the rest.
    """
    # `preprocess_networkx_graphs` already applies `add_node_features_fn` to
    # every graph, so no second featurization pass is needed (one would only
    # overwrite the features, and is expensive for spectral features).
    graphs = preprocess_networkx_graphs(get_community_large(),
                                        add_node_features_fn)
    random.shuffle(graphs)
    index = int(train_percent * len(graphs))
    train_graphs = graphs[:index]
    test_graphs = graphs[index:]
    return GraphDataset(train_graphs, test_graphs, add_node_features_fn)


def get_small_community_dataset_split(add_node_features_fn, train_percent=0.8):
    """Randomly split the `get_community_small()` graphs into train and test.

    After an in-place shuffle, the first `int(train_percent * n)` graphs form
    the training set and the rest form the test set.
    """
    graphs = preprocess_networkx_graphs(get_community_small(),
                                        add_node_features_fn)
    random.shuffle(graphs)
    cut = int(train_percent * len(graphs))
    return GraphDataset(graphs[:cut], graphs[cut:], add_node_features_fn)


def get_ego_large():
    """Load graph_rnn's 'citeseer' graphs and convert each to a directed graph.

    The list returned by `cg.create` is updated in place and returned.
    """
    graphs = cg.create('citeseer')
    graphs[:] = [g.to_directed() for g in graphs]
    return graphs


def get_ego_small():
    """Load graph_rnn's 'citeseer_small' graphs and convert each to directed.

    The list returned by `cg.create` is updated in place and returned.
    """
    graphs = cg.create('citeseer_small')
    graphs[:] = [g.to_directed() for g in graphs]
    return graphs


def get_large_ego_dataset_split(add_node_features_fn, train_percent=0.8):
    """Randomly split the `get_ego_large()` graphs into train and test.

    After an in-place shuffle, the first `int(train_percent * n)` graphs form
    the training set and the rest form the test set.
    """
    graphs = preprocess_networkx_graphs(get_ego_large(), add_node_features_fn)
    random.shuffle(graphs)
    cut = int(train_percent * len(graphs))
    return GraphDataset(graphs[:cut], graphs[cut:], add_node_features_fn)


def get_small_ego_dataset_split(add_node_features_fn, train_percent=0.8):
    """Randomly split the `get_ego_small()` graphs into train and test.

    After an in-place shuffle, the first `int(train_percent * n)` graphs form
    the training set and the rest form the test set.
    """
    graphs = preprocess_networkx_graphs(get_ego_small(), add_node_features_fn)
    random.shuffle(graphs)
    cut = int(train_percent * len(graphs))
    return GraphDataset(graphs[:cut], graphs[cut:], add_node_features_fn)


def load_graph_rnn_dataset(filename, add_node_features_fn):
    """Load a pickled sequence of graphs and split it 80/20 into train/test.

    The first 80% of the stored sequence (no shuffling) is the training set
    and the remainder is the test set. Each graph is converted with
    `to_directed()` and passed through `preprocess_networkx_graphs`.

    Args:
      filename: path to a pickle file holding a sequence of networkx graphs.
      add_node_features_fn: callable adding node features to a graph in place.

    Returns:
      A `GraphDataset`.
    """
    # SECURITY: pickle.load can execute arbitrary code; load trusted files only.
    # The `with` block closes the file handle that a bare `open(...)` leaks.
    with open(filename, 'rb') as f:
        graphs = pickle.load(f)
    split = int(0.8 * len(graphs))
    train_graphs = [g.to_directed() for g in graphs[:split]]
    test_graphs = [g.to_directed() for g in graphs[split:]]
    # Test graphs are featurized before train graphs; feature functions may
    # draw from the global RNG, so this order is kept fixed.
    test_graphs = preprocess_networkx_graphs(test_graphs, add_node_features_fn)
    train_graphs = preprocess_networkx_graphs(train_graphs,
                                              add_node_features_fn)
    return GraphDataset(train_graphs, test_graphs, add_node_features_fn)


class GrevnetGraphDataset(object):
    """Train/test graph lists whose training batches are returned directly.

    Uses the same sampling scheme as `GraphDataset`, except that
    `get_next_train_batch` returns the GraphsTuple itself rather than a feed
    dict. `get_test_set` still returns a feed dict keyed by `graph_phs`.
    """

    def __init__(self, train_set, test_set, add_node_features_fn):
        # NOTE: `train_set` is shuffled in place, so the caller's list changes.
        self.train_set = train_set
        self.test_set = test_set
        self.index = 0  # position of the next training graph to serve
        self.add_node_features_fn = add_node_features_fn

    def get_next_train_batch(self, batch_size):
        """Return a GraphsTuple holding the next `batch_size` training graphs.

        Cycles through `train_set`, reshuffling it in place whenever the
        cursor wraps to the start. Node features of every batch graph are
        regenerated with `add_node_features_fn` before conversion.
        """
        batch = []
        for _ in range(batch_size):
            if self.index == 0:
                random.shuffle(self.train_set)
            batch.append(self.train_set[self.index])
            self.index = (self.index + 1) % len(self.train_set)
        for g in batch:
            self.add_node_features_fn(g)
        return gn.utils_np.networkxs_to_graphs_tuple(batch)

    def get_test_set(self, graph_phs, num_samples=20):
        """Return `{graph_phs: GraphsTuple}` for `num_samples` test graphs.

        Graphs are sampled uniformly *with replacement* (`random.choices`), so
        duplicates are possible. Their node features are not regenerated.
        """
        sample = random.choices(self.test_set, k=num_samples)
        return {graph_phs: gn.utils_np.networkxs_to_graphs_tuple(sample)}


def load_grevnet_graph_rnn_dataset(filename, add_node_features_fn):
    """Load a pickled sequence of graphs into a `GrevnetGraphDataset`.

    The first 80% of the stored sequence (no shuffling) is the training set
    and the remainder is the test set. Each graph is converted with
    `to_directed()` and passed through `preprocess_networkx_graphs`.

    Args:
      filename: path to a pickle file holding a sequence of networkx graphs.
      add_node_features_fn: callable adding node features to a graph in place.

    Returns:
      A `GrevnetGraphDataset`.
    """
    # SECURITY: pickle.load can execute arbitrary code; load trusted files only.
    # The `with` block closes the file handle that a bare `open(...)` leaks.
    with open(filename, 'rb') as f:
        graphs = pickle.load(f)
    split = int(0.8 * len(graphs))
    train_graphs = [g.to_directed() for g in graphs[:split]]
    test_graphs = [g.to_directed() for g in graphs[split:]]
    # Test graphs are featurized before train graphs; feature functions may
    # draw from the global RNG, so this order is kept fixed.
    test_graphs = preprocess_networkx_graphs(test_graphs, add_node_features_fn)
    train_graphs = preprocess_networkx_graphs(train_graphs,
                                              add_node_features_fn)
    return GrevnetGraphDataset(train_graphs, test_graphs,
                               add_node_features_fn)
