from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import hashlib
import logging
import math
import matplotlib
matplotlib.use('agg')
import os
import pickle
import sys
import warnings

from absl import flags
import graph_nets as gn
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from grevnet_synthetic_data import *
from gnn import *
from graph_data import *
from loss import *
from utils import *

# Silence all Python `warnings` for the whole run.
warnings.filterwarnings("ignore")

# Attention params. These are forwarded to self_attn_gnn() through
# functools.partial when the GRevNet coupling layers are configured.
flags.DEFINE_integer('attn_kv_dim', 20, '')
flags.DEFINE_integer('attn_output_dim', 20, '')
flags.DEFINE_integer('attn_num_heads', 8, '')
flags.DEFINE_bool('attn_multi_proj', True, '')
flags.DEFINE_integer('attn_multi_proj_dim', 160, '')
flags.DEFINE_bool('attn_concat', True, '')
flags.DEFINE_bool('attn_residual', False, '')
flags.DEFINE_bool('attn_layer_norm', False, '')

# Dataset params.
# Directory of pickle files read by dataset_generator(). The '' default is
# not a usable directory, so this must be supplied on the command line.
flags.DEFINE_string('dataset', '', '')

# Training params.
# NOTE: train_batch_size is also the static length of each example's n_node
# vector in the input pipeline, so it must equal the number of graphs stored
# per pickled example.
flags.DEFINE_integer('write_graphs_every_n_steps', 1000, '')
flags.DEFINE_integer('train_batch_size', 32, '')
# random_seed seeds the `random` module in scope; tf_random_seed sets the TF
# graph-level seed.
flags.DEFINE_integer('random_seed', 12345, '')
flags.DEFINE_integer('tf_random_seed', 601904901297, '')
flags.DEFINE_string('logdir', 'test_grevnet_fixed_encoder',
                    'Where to write training files.')
flags.DEFINE_integer('num_train_iters', 100000, '')
flags.DEFINE_integer('log_every_n_steps', 50, '')
flags.DEFINE_integer('summary_every_n_steps', 25, '')
flags.DEFINE_integer('max_checkpoints_to_keep', 5, '')
flags.DEFINE_integer('save_every_n_steps', 5000, '')

# Default schedule: tf.train.exponential_decay of learning_rate driven by
# global_step, using the decay_* settings below.
flags.DEFINE_float('learning_rate', 5e-04, 'Learning rate for optimizer.')
flags.DEFINE_integer('learning_rate_decay_steps', 1000, '')
flags.DEFINE_float('learning_rate_decay_rate', 0.99, '')
flags.DEFINE_bool('learning_rate_decay_staircase', True, '')

# Alternative schedule: when use_fancy_lr_schedule is set, the rate is
# computed each step by get_learning_rate() from learning_rate and the three
# flags below, then fed through a placeholder.
flags.DEFINE_bool('use_fancy_lr_schedule', False, '')
flags.DEFINE_integer('learning_rate_rampup', 1000, '')
flags.DEFINE_integer('learning_rate_hold', 2000, '')
flags.DEFINE_integer('learning_rate_const_multiple', 3, '')

# GRevNet params: number of coupling layers, and (presumably) whether the
# coupling layers share parameters. Both are passed to GRevNet().
flags.DEFINE_integer('num_coupling_layers', 10, '')
flags.DEFINE_bool('weight_sharing', False, '')

# GNN params.
# NOTE(review): residual, use_batch_norm and use_layer_norm are not read
# anywhere in this file (GRevNet is built with use_batch_norm=True
# hard-coded). Presumably an imported module reads them; verify.
# num_layers / latent_dim / bias_init_stddev are passed to make_mlp_model().
flags.DEFINE_bool('residual', False, '')
flags.DEFINE_bool('use_batch_norm', True, '')
flags.DEFINE_bool('use_layer_norm', False, '')
flags.DEFINE_integer('num_layers', 3, 'Num of layers of MLP used in GNN.')
flags.DEFINE_integer('latent_dim', 2048, 'Latent dim of MLP used in GNN.')
flags.DEFINE_float('bias_init_stddev', 0.3,
                   'Used for initializing bias weights in GNN.')

# Node feature params.
# Width of each node's feature vector. It is also the event size of the
# standard-normal prior over node latents.
flags.DEFINE_integer('node_embedding_dim', 14, 'Dimension of node embeddings.')

FLAGS = tf.app.flags.FLAGS
DATASET = FLAGS.dataset
# Fail fast. dataset_generator() lists this directory lazily, inside the
# tf.data pipeline, so a missing or empty --dataset would otherwise surface
# only at the first sess.run(), after the run directories below were created.
if not os.path.isdir(DATASET):
    raise ValueError(
        '--dataset must name an existing directory, got {!r}'.format(DATASET))

# Run artifacts live under $MLPATH, or under the working directory when
# MLPATH is unset or empty.
logdir_prefix = os.environ.get('MLPATH') or '.'
LOGDIR = os.path.join(logdir_prefix, FLAGS.logdir)
# No exist_ok: os.makedirs raises if LOGDIR already exists, so an earlier
# run's directory is never reused.
os.makedirs(LOGDIR)
GRAPH_OUTPUT_DIR = os.path.join(LOGDIR, 'generated_graphs')
os.makedirs(GRAPH_OUTPUT_DIR)

# Logging and print options: log to stdout and to LOGDIR/OUTPUT_LOG.
np.set_printoptions(suppress=True, formatter={'float': '{: 0.3f}'.format})
handlers = [logging.StreamHandler(sys.stdout)]
handlers.append(logging.FileHandler(os.path.join(LOGDIR, 'OUTPUT_LOG')))
logging.basicConfig(level=logging.INFO, handlers=handlers)
logger = logging.getLogger("logger")

tf.random.set_random_seed(FLAGS.tf_random_seed)
# NOTE(review): neither `random` nor `np` is imported explicitly above; they
# must come from one of the star imports. Confirm which `random` (Python's or
# numpy's) this actually seeds.
random.seed(FLAGS.random_seed)


def transform_example(n_node, nodes):
    """Expand one batch of node features into the fields of a GraphsTuple.

    Args:
      n_node: int32 vector of per-graph node counts for the batch.
      nodes: Node features for every graph in the batch. These are assumed to
        be stacked along the first axis; TODO confirm against the pickles.

    Returns:
      A 7-tuple ``(nodes, edges, globals, receivers, senders, n_node,
      n_edge)``. Receivers come before senders, and callers unpack in that
      order.
    """
    # One zero "global" per graph, with the same shape and dtype as n_node.
    graph_globals = tf.zeros_like(n_node)
    sender_idx, receiver_idx = senders_receivers(n_node)
    # Pin both index tensors to rank 1 with an unknown length.
    for index_tensor in (sender_idx, receiver_idx):
        index_tensor.set_shape([None])
    # n_node**2 edges per graph, i.e. one per ordered node pair (self-pairs
    # included). This assumes senders_receivers enumerates exactly those
    # pairs; verify.
    edge_counts = tf.square(n_node)
    # Dummy edge features: one zero per edge, same dtype as the indices.
    edge_features = tf.zeros_like(sender_idx)
    return (nodes, edge_features, graph_globals, receiver_idx, sender_idx,
            n_node, edge_counts)


def dataset_generator(directory):
    """Yield ``(n_node, node_features)`` examples from every file in *directory*.

    Each file is unpickled and must be a mapping with keys ``'n_node'`` and
    ``'node_features'``. For each index ``i`` of ``node_features``, the pair
    ``(n_node[i], node_features[i])`` is yielded.

    Files are visited in sorted name order. ``os.listdir`` order is
    arbitrary, and sorting keeps the example order reproducible across runs
    that use the same seeds.

    Args:
      directory: Path to a directory containing only pickle files.

    Yields:
      Tuples ``(n_node[i], node_features[i])``.

    Raises:
      IndexError: If a file's ``n_node`` is shorter than its
        ``node_features``.
    """
    # NOTE(security): pickle.load can execute arbitrary code. Only point
    # --dataset at trusted files.
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        # Close each file right after loading instead of leaking the handle.
        with open(path, 'rb') as handle:
            data = pickle.load(handle)
        n_node = data['n_node']
        nodes = data['node_features']
        for i, node_features in enumerate(nodes):
            yield n_node[i], node_features


# Input pipeline. Each generator element is one pre-batched example:
# n_node of shape [train_batch_size], plus a node-feature matrix of shape
# [None, node_embedding_dim] (presumably all nodes of the batch). The element
# is expanded into full graph fields by transform_example.
dataset = tf.data.Dataset.from_generator(
    partial(dataset_generator, DATASET),
    output_types=(tf.int32, tf.float32),
    output_shapes=(
        tf.TensorShape([FLAGS.train_batch_size]),
        tf.TensorShape([None, FLAGS.node_embedding_dim]),
    ),
)
dataset = dataset.map(transform_example, num_parallel_calls=10)
dataset = dataset.prefetch(1)
iterator = dataset.make_one_shot_iterator()

# Pull one batch from the input pipeline and wrap it as a graph_nets
# GraphsTuple, which is the training input to GRevNet. (The module-level
# name `globals` shadows the builtin; it is kept for compatibility.)
(nodes, edges, globals, receivers, senders, n_node,
 n_edge) = iterator.get_next()
graphs_tuple = gn.graphs.GraphsTuple(
    n_node=n_node,
    n_edge=n_edge,
    nodes=nodes,
    edges=edges,
    senders=senders,
    receivers=receivers,
    globals=globals,
)

# GNN used inside each coupling layer: a self-attention GNN whose MLPs come
# from make_mlp_model. The second positional argument (presumably the MLP
# output size) is half the node embedding. Floor division keeps it an int:
# under `from __future__ import division`, `/` yields a float (7.0 for the
# default 14, and a non-integral 7.5 for odd sizes).
make_mlp_fn = partial(
    make_mlp_model,
    FLAGS.latent_dim,
    FLAGS.node_embedding_dim // 2,
    FLAGS.num_layers,
    activation=tf.nn.relu,
    l2_regularizer_weight=0.000001,
    bias_init_stddev=FLAGS.bias_init_stddev)
# NOTE: this rebinds the imported `self_attn_gnn` to a configured partial.
self_attn_gnn = partial(
    self_attn_gnn,
    kv_dim=FLAGS.attn_kv_dim,
    output_dim=FLAGS.attn_output_dim,
    make_mlp_fn=make_mlp_fn,
    batch_size=FLAGS.train_batch_size,
    num_heads=FLAGS.attn_num_heads,
    multi_proj_dim=FLAGS.attn_multi_proj_dim,
    concat=FLAGS.attn_concat,
    residual=FLAGS.attn_residual,
    layer_norm=FLAGS.attn_layer_norm)
# NOTE(review): use_batch_norm is hard-coded to True; FLAGS.use_batch_norm is
# not consulted here. Confirm this is intended.
grevnet = GRevNet(
    self_attn_gnn,
    FLAGS.num_coupling_layers,
    FLAGS.node_embedding_dim,
    use_batch_norm=True,
    weight_sharing=FLAGS.weight_sharing)
# Run the flow in the inverse direction on the data batch. This yields the
# latent graph and the log-det-Jacobian term of that map.
grevnet_reverse_output, log_det_jacobian = grevnet(graphs_tuple, inverse=True)
# Per-node L2 norm of the latents (not used by the loss below).
grevnet_output_norm = tf.norm(grevnet_reverse_output.nodes, axis=1)
# Standard-normal prior over each node's latent vector.
mvn = tfd.MultivariateNormalDiag(
    tf.zeros(FLAGS.node_embedding_dim), tf.ones(FLAGS.node_embedding_dim))

# Change of variables: log p(x) = sum over nodes of log N(z; 0, I)
# plus the log-det-Jacobian term.
log_prob_zs = tf.reduce_sum(mvn.log_prob(grevnet_reverse_output.nodes))
log_prob_xs = log_prob_zs + log_det_jacobian
# Minimize the negative log-likelihood of the batch (summed, not averaged).
total_loss = -1 * log_prob_xs

# Optimizer: Adam, driven either by exponential decay of global_step (the
# default) or by a rate fed each step through learning_rate_placeholder
# (--use_fancy_lr_schedule).
global_step = tf.Variable(0, trainable=False, name='global_step')
decaying_learning_rate = tf.train.exponential_decay(
    FLAGS.learning_rate,
    global_step,
    FLAGS.learning_rate_decay_steps,
    FLAGS.learning_rate_decay_rate,
    staircase=FLAGS.learning_rate_decay_staircase)
# Created unconditionally. The training loop feeds it only when the fancy
# schedule is enabled.
learning_rate_placeholder = tf.placeholder(
    tf.float32, shape=[], name='learning_rate')
if FLAGS.use_fancy_lr_schedule:
    learning_rate = learning_rate_placeholder
else:
    learning_rate = decaying_learning_rate

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
# Make each step also run the UPDATE_OPS collection (e.g. batch-norm
# statistics updates).
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    step_op = optimizer.minimize(total_loss, global_step=global_step)

# Sampling graph. Draw one latent vector per node from the prior, give the
# batch the same edge layout as training examples, and run GRevNet in the
# forward (inverse=False) direction.
sample_n_node_placeholder = tf.placeholder(
    tf.int32, shape=[FLAGS.train_batch_size], name="sample_n_node_placeholder")
# A scalar sample_shape (the total node count) gives one latent row per node.
sample_nodes = mvn.sample(
    sample_shape=tf.reduce_sum(sample_n_node_placeholder))
sample_log_prob = mvn.log_prob(sample_nodes)
(sample_nodes, sample_edges, sample_globals, sample_receivers,
 sample_senders, sample_n_node, sample_n_edge) = transform_example(
     sample_n_node_placeholder, sample_nodes)
sample_graphs_tuple = gn.graphs.GraphsTuple(
    nodes=sample_nodes,
    edges=sample_edges,
    globals=sample_globals,
    receivers=sample_receivers,
    senders=sample_senders,
    n_node=sample_n_node,
    n_edge=sample_n_edge,
)

# NOTE(review): the result is used as a graph (.nodes / .n_node) later, so
# the forward call is assumed to return only the graph, unlike the inverse
# call, which returns a (graph, log_det) pair. Verify.
sample_grevnet_top = grevnet(sample_graphs_tuple, inverse=False)
sample_pred_adj = pred_adj(sample_grevnet_top, distance_fn=hacky_sigmoid_l2)

# Scalar training summaries. They are merged and written by the training
# loop every summary_every_n_steps.
tf.summary.scalar('total_loss', total_loss)
tf.summary.scalar('log_prob_xs', log_prob_xs)
tf.summary.scalar('log_prob_zs', log_prob_zs)
tf.summary.scalar('log_det_jacobian', log_det_jacobian)

merged = tf.summary.merge_all()
sess = reset_sess()
train_writer = tf.summary.FileWriter(os.path.join(LOGDIR, 'train'), sess.graph)
eval_writer = tf.summary.FileWriter(os.path.join(LOGDIR, 'test'), sess.graph)

# Record every flag value of this run next to its outputs.
flags_map = tf.app.flags.FLAGS.flag_values_dict()
with open(os.path.join(LOGDIR, 'desc.txt'), 'w') as f:
    for (k, v) in flags_map.items():
        f.write("{}: {}\n".format(k, str(v)))

saver = tf.train.Saver(max_to_keep=FLAGS.max_checkpoints_to_keep)
train_values = {}
# Fetched on every training step.
values_map = {
    "merge": merged,
    "step_op": step_op,
    "total_loss": total_loss,
    "log_prob_zs": log_prob_zs,
    "log_prob_xs": log_prob_xs,
    "log_det_jacobian": log_det_jacobian,
    "graphs_tuple": graphs_tuple,
}
# Fetched when sampled graphs are written.
samples_map = {
    "sample_pred_adj": sample_pred_adj,
    "sample_grevnet_top": sample_grevnet_top,
    "sample_log_prob": sample_log_prob,
    "sample_grevnet_top_nodes": sample_grevnet_top.nodes
}

# Register the fetch tensors as named graph collections, except the
# GraphsTuple entries (presumably excluded because a namedtuple is not a
# single graph element). Compare keys with `!=`: `is not` on string literals
# tests object identity, depends on interning, and is a SyntaxWarning on
# Python 3.8+.
for k, v in values_map.items():
    if k != "graphs_tuple":
        tf.add_to_collection(k, v)
for k, v in samples_map.items():
    if k != "sample_grevnet_top":
        tf.add_to_collection(k, v)

# Main training loop. Every iteration runs one optimizer step on the next
# input batch. At fixed intervals it also writes summaries, logs the loss
# terms, saves a checkpoint, and samples a batch of graphs from the model.
feed_dict = {}
for iteration in range(0, FLAGS.num_train_iters):
    if FLAGS.use_fancy_lr_schedule:
        feed_dict[learning_rate_placeholder] = get_learning_rate(
            iteration, FLAGS.learning_rate, FLAGS.learning_rate_rampup,
            FLAGS.learning_rate_hold, FLAGS.learning_rate_const_multiple)

    train_values = sess.run(values_map, feed_dict=feed_dict)
    if train_writer and (iteration % FLAGS.summary_every_n_steps == 0):
        train_writer.add_summary(train_values['merge'], iteration)
    if iteration % FLAGS.log_every_n_steps == 0:
        logger.info("*" * 100)
        logger.info("iteration num: {}".format(iteration))
        logger.info("total loss: {}".format(train_values["total_loss"]))
        logger.info("log prob zs: {}".format(train_values["log_prob_zs"]))
        logger.info("log det jacobian: {}".format(
            train_values["log_det_jacobian"]))

    # Save model.
    if iteration % FLAGS.save_every_n_steps == 0:
        saver.save(
            sess, os.path.join(LOGDIR, 'checkpoints'), global_step=global_step)

    if iteration % FLAGS.write_graphs_every_n_steps == 0:
        # Sample graphs with the same per-graph node counts as the current
        # training batch. This uses its own feed dict: reassigning
        # `feed_dict` would drop the learning-rate entry and leak the
        # sampling placeholder into later training steps.
        sample_feed_dict = {
            sample_n_node_placeholder: train_values["graphs_tuple"].n_node}
        logger.info("*" * 100)
        logger.info("iteration num: {}".format(iteration))
        logger.info("writing graphs...")
        values = sess.run(samples_map, feed_dict=sample_feed_dict)
        # Distinct names avoid shadowing the `pred_adj` function and the
        # `n_node` input tensor defined earlier in this module.
        sample_n_node_values = values["sample_grevnet_top"].n_node
        pred_adj_values = values["sample_pred_adj"]
        # Threshold edge scores at 0.5 into a {0, 1} matrix of the same dtype.
        adjacency = (pred_adj_values > 0.5).astype(pred_adj_values.dtype)
        # The batch adjacency is sliced into one square diagonal block per
        # graph, using cumulative node counts as block boundaries.
        graphs = []
        start_ind = 0
        for end_ind in np.cumsum(sample_n_node_values):
            graph = adjacency[start_ind:end_ind, start_ind:end_ind]
            # from_numpy_array replaces from_numpy_matrix (removed in
            # networkx 3.0).
            graphs.append(nx.from_numpy_array(graph))
            start_ind = end_ind
        graph_path = os.path.join(
            GRAPH_OUTPUT_DIR, "generated_graphs_iter_{}.p".format(iteration))
        with open(graph_path, 'wb') as graph_file:
            pickle.dump(graphs, graph_file)
        logger.info("done writing graphs")
