from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import warnings

from absl import flags
import tensorflow as tf

from graph_data import *
from utils import *

# Silence every warning for the whole process (presumably to hide TF/NumPy
# deprecation noise in the script output).
warnings.filterwarnings("ignore")

flags.DEFINE_string('ckpt_dir', '',
                    'Directory to restore the latest checkpoint from.')
flags.DEFINE_integer('random_seed', 12345,
                     'Seed for the Python `random` module.')
flags.DEFINE_integer('tf_random_seed', 601904901297,
                     'Graph-level TensorFlow random seed.')

# Input example params.
flags.DEFINE_string('dataset', 'grid_single',
                    'Dataset name; must be a key of DATASET_MAP.')
flags.DEFINE_integer('node_embedding_dim', 14, 'Dimension of node embeddings.')
# The accepted values are exactly the keys of NODE_FEATURES_MAP.
flags.DEFINE_string('node_features', 'gaussian',
                    'Can be laplacian, gaussian, zeros, or positional.')
flags.DEFINE_float('gaussian_scale', 0.3,
                   'Scale to use for random Gaussian features.')
flags.DEFINE_string('graph_dim', '10,10',
                    'Comma-separated grid dimensions, e.g. "10,10".')

# NOTE(review): flags are defined through absl but read through
# tf.app.flags.FLAGS. This presumably relies on the TF 1.x wrapper parsing
# sys.argv lazily on first attribute access, because the module-level reads
# below happen without app.run(). Confirm this before switching to absl's
# flags.FLAGS directly.
FLAGS = tf.app.flags.FLAGS

# Logging and print options.
np.set_printoptions(suppress=True, formatter={'float': '{: 0.3f}'.format})
tf.random.set_random_seed(FLAGS.tf_random_seed)
random.seed(FLAGS.random_seed)

# Maps each --node_features value to the feature helper (from the
# star-imported project modules), pre-bound to the configured embedding size.
NODE_FEATURES_MAP = {
    'laplacian':
    partial(add_laplacian_features, num_components=FLAGS.node_embedding_dim),
    'gaussian':
    partial(
        add_gaussian_noise_features,
        num_components=FLAGS.node_embedding_dim,
        scale=FLAGS.gaussian_scale),
    'zeros':
    partial(add_zero_features, num_components=FLAGS.node_embedding_dim),
    'positional':
    partial(
        add_positional_encoding_features,
        num_components=FLAGS.node_embedding_dim),
}
# Fail with the list of valid choices instead of a bare KeyError.
if FLAGS.node_features not in NODE_FEATURES_MAP:
    raise ValueError(
        "Unknown --node_features {!r}; expected one of {}.".format(
            FLAGS.node_features, sorted(NODE_FEATURES_MAP)))
add_node_features_fn = NODE_FEATURES_MAP[FLAGS.node_features]

# "10,10" -> (10, 10); int() tolerates surrounding whitespace.
GRAPH_DIM = tuple(int(x) for x in FLAGS.graph_dim.split(','))
# Each value is a zero-argument callable (all arguments pre-bound by partial).
# The script calls it and reads `.test_set` from the result. Only
# 'grid_single' consumes GRAPH_DIM.
DATASET_MAP = {
    'grid_all':
    partial(get_grid_dataset_all, add_node_features_fn),
    'grid_split':
    partial(get_grid_dataset_split, add_node_features_fn),
    'grid_train_even_test_odd':
    partial(get_grid_dataset_train_even_test_odd, add_node_features_fn),
    'grid_train_odd_test_even':
    partial(get_grid_dataset_train_odd_test_even, add_node_features_fn),
    'grid_test_larger':
    partial(get_grid_dataset_all_test_larger, add_node_features_fn),
    'grid_test_smaller':
    partial(get_grid_dataset_all_test_smaller, add_node_features_fn),
    'grid_test_square':
    partial(get_grid_dataset_all_test_square, add_node_features_fn),
    'grid_single':
    partial(get_grid_dataset_single, GRAPH_DIM, add_node_features_fn)
}


def make_grid_example(graph_dim):
    """Build a one-graph GraphsTuple for a directed 2-D grid.

    Args:
      graph_dim: grid dimensions, unpacked into ``nx.grid_2d_graph``
        (e.g. ``GRAPH_DIM`` parsed from ``--graph_dim``).

    Returns:
      The result of ``gn.utils_np.networkxs_to_graphs_tuple`` applied to a
      one-element list holding the converted grid graph. Node features are
      attached by the module-level ``add_node_features_fn``.
    """
    grid = convert_nx_repr(
        nx.grid_2d_graph(*graph_dim, create_using=nx.DiGraph),
        add_node_features_fn)
    return gn.utils_np.networkxs_to_graphs_tuple([grid])


# Fail with the list of valid choices instead of a bare KeyError.
if FLAGS.dataset not in DATASET_MAP:
    raise ValueError(
        "Unknown --dataset {!r}; expected one of {}.".format(
            FLAGS.dataset, sorted(DATASET_MAP)))
# NOTE(review): the loaded test set is never fed to the model. The grid built
# by make_grid_example(GRAPH_DIM) below is what actually gets evaluated. The
# load is kept in case something relies on it (e.g. RNG state consumed by the
# feature helpers); confirm whether the test set was meant to be used.
dataset = DATASET_MAP[FLAGS.dataset]()
dataset = dataset.test_set
input_graph_phs = make_grid_example(GRAPH_DIM)

# Feed the example into the placeholders of the restored graph. These tensor
# names must match the ones used when the model graph was exported.
feed_dict = {
    "true_graph_phs/{}:0".format(field): getattr(input_graph_phs, field)
    for field in ("nodes", "edges", "receivers", "senders", "globals",
                  "n_node", "n_edge")
}
feed_dict["is_training:0"] = False

sess = reset_sess()
latest_checkpoint = tf.train.latest_checkpoint(FLAGS.ckpt_dir)
# latest_checkpoint returns None when nothing is found. Without this check
# the failure would surface as a confusing attempt to load "None.meta".
if latest_checkpoint is None:
    raise ValueError(
        "No checkpoint found in --ckpt_dir={!r}.".format(FLAGS.ckpt_dir))
saver = tf.train.import_meta_graph("{}.meta".format(latest_checkpoint))
saver.restore(sess, latest_checkpoint)

values_map = {
    'pred_adj': tf.get_collection('pred_adj')[0],
    'num_incorrect': tf.get_collection('num_incorrect')[0]
}
values = sess.run(values_map, feed_dict=feed_dict)
pred_adj = values['pred_adj']
num_incorrect = values['num_incorrect']
print("num_incorrect {}".format(num_incorrect))
# Assumes pred_adj is a square (N, N) matrix, as required by
# nx.from_numpy_array below. Zero the diagonal so no self-loops are drawn.
pred_adj *= 1 - np.eye(np.shape(pred_adj)[0])
# Binarize at 0.5 while keeping pred_adj's dtype (same as np.where with
# ones_like/zeros_like).
pred_adj = (pred_adj > 0.5).astype(pred_adj.dtype)
# from_numpy_array replaces from_numpy_matrix, which NetworkX 3.0 removed.
graph = nx.from_numpy_array(pred_adj)
visualize_graph(graph, 'graph.png')
