'''Example of VAE on MNIST dataset using CNN

The VAE has a modular design. The encoder, decoder and VAE
are 3 models that share weights. After training the VAE model,
the encoder can be used to generate latent vectors.
The decoder can be used to generate MNIST digits by sampling the
latent vector from a Gaussian distribution with mean=0 and std=1.

# Reference

[1] Kingma, Diederik P., and Max Welling.
"Auto-encoding variational bayes."
https://arxiv.org/abs/1312.6114
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten, Lambda
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras.optimizers import RMSprop
from keras.losses import mse, binary_crossentropy
# from keras.utils import plot_model
from keras import backend as K
from keras.backend.tensorflow_backend import set_session
import tensorflow as tf

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import argparse
import os
import imageio
from scipy.misc import imresize


# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# then z = z_mean + sqrt(var)*eps
def sampling(args):
    """Reparameterization trick: sample z from an isotropic unit Gaussian.

    Unit Gaussian noise is drawn and then scaled by the standard deviation
    and shifted by the mean of Q(z|X).

    # Arguments:
        args (tensor): mean and log of variance of Q(z|X)

    # Returns:
        z (tensor): sampled latent vector
    """

    mean, log_var = args
    # first axis is read from the dynamic shape, second from the static one
    noise_shape = (K.shape(mean)[0], K.int_shape(mean)[1])
    # random_normal defaults to mean=0 and std=1.0
    noise = K.random_normal(shape=noise_shape)
    # exp(0.5 * log(var)) == sqrt(var), i.e. the standard deviation
    std = K.exp(0.5 * log_var)
    return mean + std * noise


def plot_results(models,
                 x_test,
                 batch_size,
                 outpath="vae_mnist",
                 n_goal=8):
    """Plots the latent means and the reconstructions of the test images

    Two files are written into `outpath`: "vae_mean.png", a scatter plot of
    the first two latent dimensions of the encoded test images, and
    "reconstructions.png", the decoded latent samples stacked vertically.
    The average latent distances within and across the two classes are
    printed as well.

    # Arguments:
        models (tuple): encoder and decoder models
        x_test (array): test images; the first `n_goal` entries are treated
            as the "goal" class and all remaining ones as "not goal"
        batch_size (int): prediction batch size
        outpath (string): directory the two images are written to
        n_goal (int): number of leading "goal" images in `x_test`
    """

    encoder, decoder = models

    filename = os.path.join(outpath, "vae_mean.png")
    # encode the test images; the encoder yields (z_mean, z_log_var, z)
    z_mean, _, z_samples = encoder.predict(x_test,
                                           batch_size=batch_size)
    n_samples = z_mean.shape[0]

    # get avg. distance within and out of class
    # column 0: mean distance to the goal samples,
    # column 1: mean distance to the non-goal samples
    # (the within-class mean includes each sample's zero distance to itself)
    dist = np.array([
        [np.mean(np.linalg.norm(z - z_mean[:n_goal], axis=1), axis=0),
         np.mean(np.linalg.norm(z - z_mean[n_goal:], axis=1), axis=0)]
        for z in z_mean])
    avg_dist = [np.mean(dist[:n_goal, 0]), np.mean(dist[:n_goal, 1]),
                np.mean(dist[n_goal:, 0]), np.mean(dist[n_goal:, 1])]
    print("\tgoal, not goal\ngoal     ", avg_dist[0], " ", avg_dist[1])
    print("not goal ", avg_dist[2], " ", avg_dist[3])

    # display a 2D plot of the two classes in the latent space
    fig = plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1],
                c=[1] * n_goal + [0] * (n_samples - n_goal))
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(filename)
    # pyplot keeps figures alive until they are closed explicitly
    plt.close(fig)

    filename = os.path.join(outpath, "reconstructions.png")
    digit_size = 64
    # decode all latent samples in one call instead of one call per sample
    decoded = decoder.predict(z_samples, batch_size=batch_size)
    tiles = decoded.reshape(-1, digit_size, digit_size, 3)

    # stack the reconstructions on top of each other and convert to 8 bit
    img = (np.concatenate(tiles, axis=0)*255.).astype(np.uint8)
    imageio.imsave(filename, img)

def preprocess_img(img, size=(64, 64)):
    """Resizes an image and scales it by 1/255.

    # Arguments:
        img (array): image as accepted by `scipy.misc.imresize`
        size (tuple): target size handed to `imresize`

    # Returns:
        array: the resized image divided by 255. `imresize` is documented
            to return 8 bit data, so the values should lie in [0, 1]
            (not [-1, 1]), which matches the sigmoid output of the decoder.
    """
    # NOTE(review): scipy.misc.imresize was deprecated in SciPy 1.0 and
    # removed in 1.3, so this needs an old SciPy to run.
    resized = imresize(img, size)
    return resized/255.

def build_vae(latent_dim=4, beta=10.):
    """Builds the convolutional encoder, decoder and the combined VAE.

    The encoder maps a 64x64x3 image to `z_mean`, `z_log_var` and a sampled
    latent vector `z`; the decoder maps a latent vector back to a 64x64x3
    image through a sigmoid output layer. The VAE chains both and is
    compiled with RMSprop on the loss

        reconstruction_loss + beta * kl_loss

    which is attached through `add_loss`, so no targets are needed when
    fitting.

    # Arguments:
        latent_dim (int): size of the latent vector
        beta (float): weight of the KL term in the loss

    # Returns:
        tuple: (encoder, decoder, vae, inputs, outputs) where `inputs` is
            the encoder input tensor and `outputs` the VAE output tensor
    """
    # network parameters
    image_size = 64
    input_shape = (image_size, image_size, 3)
    kernel_size = 5
    filters = [16, 32, 32]
    stride = 3

    deconv_input_width = 2
    deconv_input_height = 2
    deconv_input_channels = 32
    deconv_kernels = [5, 6, 6]
    deconv_channels = [32, 16]

    # build encoder model: one strided convolution per entry in `filters`
    inputs = Input(shape=input_shape, name='encoder_input')
    x = inputs
    for n_filters in filters:
        x = Conv2D(filters=n_filters,
                   kernel_size=kernel_size,
                   activation='relu',
                   strides=stride,
                   padding='same')(x)

    # generate latent vector Q(z|X)
    x = Flatten()(x)
    z_mean = Dense(latent_dim, name='z_mean')(x)
    z_log_var = Dense(latent_dim, name='z_log_var')(x)

    # use reparameterization trick to push the sampling out as input
    # note that "output_shape" isn't necessary with the TensorFlow backend
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')(
        [z_mean, z_log_var])

    # instantiate encoder model
    encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

    # build decoder model
    latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
    n_units = (deconv_input_width * deconv_input_height
               * deconv_input_channels)
    x = Dense(n_units, activation='relu')(latent_inputs)
    x = Reshape((deconv_input_width,
                 deconv_input_height,
                 deconv_input_channels))(x)

    # two stride-4 upsampling layers (2 -> 8 -> 32); zip stops after the
    # two entries of deconv_channels, the last kernel is used below
    for n_channels, deconv_kernel in zip(deconv_channels, deconv_kernels):
        x = Conv2DTranspose(filters=n_channels,
                            kernel_size=deconv_kernel,
                            activation='relu',
                            strides=4,
                            padding='same')(x)

    # final stride-2 layer (32 -> 64) producing the 3 colour channels
    outputs = Conv2DTranspose(filters=3,
                              kernel_size=deconv_kernels[-1],
                              activation='sigmoid',
                              strides=2,
                              padding='same',
                              name='decoder_output')(x)
    # instantiate decoder model
    decoder = Model(latent_inputs, outputs, name='decoder')

    # instantiate VAE model: decode the sampled z (third encoder output)
    outputs = decoder(encoder(inputs)[2])
    vae = Model(inputs, outputs, name='vae')

    # VAE loss = mse_loss + beta * kl_loss
    # the mean squared error is rescaled by the number of values per image
    reconstruction_loss = K.mean(mse(inputs, outputs))
    reconstruction_loss *= image_size * image_size * 3
    # KL divergence between Q(z|X) and the unit Gaussian, per sample
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    kl_loss *= beta
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer=RMSprop(lr=1e-3))

    return encoder, decoder, vae, inputs, outputs

if __name__ == '__main__':
    # Script entry point: load the goal / fail images, build the VAE, then
    # either load trained weights (-w) or train from scratch, and finally
    # write the latent-space and reconstruction plots into the log directory.
    parser = argparse.ArgumentParser()
    help_ = "Load h5 model trained weights"
    parser.add_argument("-w", "--weights", help=help_)
    args = parser.parse_args()

    # configure GPU: expose only device 0 and let TF grow its memory on demand
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    set_session(tf.Session(config=config))

    # load data
    ROOT = '.'
    NAME = 'MiniWorld-SimToReal1-v0'
    base_dir = '../data/{}/'.format(NAME)

    # goal images: stored file paths, selected through a saved index array
    imgPaths = np.load(os.path.join(base_dir, 'filepaths.npy'))
    idxs = np.load(os.path.join(base_dir, 'randomized_idxs_success.npy'))
    imgPaths = imgPaths[idxs[:2008]]
    goalImgs = np.array([preprocess_img(imageio.imread(path))
                         for path in imgPaths])

    # fail images: directory listing, selected through a saved index array.
    # os.listdir returns entries in arbitrary order, so the listing is sorted
    # to make the saved indices pick the same files on every run and machine.
    fail_dir = os.path.join(base_dir, 'fail/fail/')
    imgPaths = np.array([os.path.join(fail_dir, path)
                         for path in sorted(os.listdir(fail_dir))])
    idxs = np.load(os.path.join(base_dir, 'randomized_idxs_fail.npy'))
    imgPaths = imgPaths[idxs[:10008]]
    failImgs = np.array([preprocess_img(imageio.imread(path))
                         for path in imgPaths])

    # hold out the last 8 images of each class for testing; plot_results
    # relies on the 8 goal images coming first in x_test
    x_train = np.concatenate((goalImgs[:-8], failImgs[:-8]), axis=0)
    x_test = np.concatenate((goalImgs[-8:], failImgs[-8:]), axis=0)

    # VAE model = encoder + decoder
    encoder, decoder, vae, inputs, outputs = build_vae()
    encoder.summary()
    decoder.summary()
    vae.summary()

    models = (encoder, decoder)
    batch_size = 128
    epochs = 101

    if args.weights:
        # plots are written next to the loaded weights file
        logdir = os.path.dirname(args.weights)
        vae.load_weights(args.weights)
    else:
        import datetime
        datestamp = datetime.datetime.now().strftime('%Y-%m-%d|%H:%M:%S')
        logdir = os.path.join(ROOT, 'experiments',
                              'halgan-{}'.format(NAME), 'betavae', datestamp)
        checkpoint_dir = os.path.join(logdir, 'checkpoints')
        # creates logdir as well; exist_ok avoids a crash if it already exists
        os.makedirs(checkpoint_dir, exist_ok=True)
        # train the autoencoder; the loss is attached to the model, so no
        # targets are passed
        vae.fit(x_train,
                epochs=epochs,
                batch_size=batch_size,
                validation_data=(x_test, None))
        vae.save_weights(os.path.join(checkpoint_dir, 'params_vae.hdf5'))

    plot_results(models, x_test, batch_size=batch_size, outpath=logdir)

