import tensorflow as tf
import tensorflow_compression as tfc
# from common.constrained_opt_lib import ConstraintManager
# from common.transforms import QuadAnalysis, QuadSynthesis, HyperAnalysis, HyperSynthesis
# from common.transforms import class_builder as transform_builder  # Just use the default class map.
from common.utils import ClassBuilder
# from common.latent_rvs_lib import UQLatentRV, LatentRVCollection
# from common.latent_rvs_utils import sga_schedule_at_step
# from common.immutabledict import immutabledict
from common.image_utils import mse_psnr
from common import data_lib
from common import tf_schedule as schedule
# from common import profile_utils
from common.custom_metrics import Metrics
from collections import OrderedDict
from ml_collections import ConfigDict
from absl import logging
from rdvae.nn_models import get_activation, make_mlp
import rdvae.tfc_utils as tfc_utils
from rdvae.rdvae_mlp import check_no_decoder, af_transform, softplus_inv_1
from common.utils import get_keras_optimizer
import tensorflow_probability as tfp

# Short aliases for the TensorFlow Probability namespaces used throughout this file.
tfd = tfp.distributions
tfb = tfp.bijectors

# EMPTY_DICT = immutabledict()
# Shared default value for the `transform_config` / `optimizer_config` arguments of `Model`.
# NOTE: this is a plain (mutable) dict shared between calls, so it must never be modified in place.
EMPTY_DICT = {}

# `coding_rank` handed to `tfc.ContinuousBatchedEntropyModel` when latents are quantized.
CODING_RANK = 1

# Parameters for an rd-lambda warmup (a factor applied until this fraction of the scheduled
# steps). They are not read by any active code visible in this file.
HIGHER_LAMBDA_UNTIL = 0.2
HIGHER_LAMBDA_FACTOR = 1.


# Encapsulates model + optimizer.
# It's also possible to inherit from tf.keras.Model, although it might make the model construction
# code more cumbersome (from what I've seen, tf.keras.Model must define build() and call(),
# and these methods must be runnable in graph mode), but can simplify distributed training since
# it's already integrated into Model.fit,
# https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit
# (whereas adding support for distributed training to a custom training loop can take more work).
class Model(tf.Module):
  def __init__(self,
               rd_lambda,
               latent_dim,
               data_dim,
               distort_type,
               scale_lb=None,
               scheduled_num_steps=5000,
               transform_config=EMPTY_DICT,
               optimizer_config=EMPTY_DICT,
               dtype='float32'):
    """Stores the hyperparameters, then builds the optimizer and the transforms.

    Args:
      rd_lambda: weight of the distortion term in the loss (`rate + rd_lambda * distortion`).
      latent_dim: dimensionality of the latent variable.
      data_dim: dimensionality of the data vectors.
      distort_type: name of the distortion measure; `encode_decode` accepts
        'mse', 'sse' and 'half_sse'.
      scale_lb: optional value added to the posterior scale (skipped when falsy).
      scheduled_num_steps: total number of steps handed to the learning-rate schedule.
      transform_config: mapping whose entries become attributes of the model
        (e.g. `posterior_type`, `prior_type`, `encoder_units`); see `_init_transforms`.
      optimizer_config: mapping of optimizer / lr-schedule options; see `_get_optimizer`.
      dtype: dtype used for the networks and the prior.
    """
    super().__init__()

    # Plain hyperparameters.
    self.latent_dim = latent_dim
    self.data_dim = data_dim
    self.distort_type = distort_type
    self.scale_lb = scale_lb
    self._scheduled_num_steps = scheduled_num_steps
    self._rd_lambda = rd_lambda
    self.dtype = dtype

    # Optimizer together with the learning-rate schedule that drives it.
    self._optimizer_config = optimizer_config
    self.optimizer, self._lr_schedule_fn = self._get_optimizer(
      self._optimizer_config, self._scheduled_num_steps)

    # Encoder / decoder networks and the prior.
    self._init_transforms(transform_config)

  def _get_optimizer(self, optimizer_config, scheduled_num_steps):
    """Creates the optimizer and its learning-rate schedule from a config mapping.

    The following keys are consumed here (all optional); every remaining entry is
    forwarded as a keyword argument to the optimizer constructor:
      learning_rate (default 1e-4), reduce_lr_after (0.8), reduce_lr_factor (0.1),
      warmup_steps, warmup_until (0.02; fraction of `scheduled_num_steps`, only used
      when `warmup_steps` is absent), warmup_start_step (0), lr_drop_steps,
      name ('adam').

    Args:
      optimizer_config: mapping of options; it is copied, never mutated.
      scheduled_num_steps: total number of steps the schedule is laid out over.

    Returns:
      A pair `(optimizer, lr_schedule_fn)`.
    """
    config = dict(optimizer_config)  # Work on a copy to avoid mutating the caller's mapping.

    learning_rate = config.pop("learning_rate", 1e-4)
    reduce_lr_after = config.pop("reduce_lr_after", 0.8)
    reduce_lr_factor = config.pop("reduce_lr_factor", 0.1)
    warmup_start_step = config.pop("warmup_start_step", 0)

    # Always consume `warmup_until`, even when an explicit `warmup_steps` takes precedence;
    # otherwise it would be left in `config` and passed on to the optimizer constructor.
    warmup_until = config.pop("warmup_until", 0.02)
    if "warmup_steps" in config:
      warmup_steps = config.pop("warmup_steps")
    else:
      warmup_steps = int(warmup_until * scheduled_num_steps)

    # Arguments common to both schedule flavours.
    schedule_kwargs = dict(
      base_learning_rate=learning_rate,
      total_num_steps=scheduled_num_steps,
      warmup_steps=warmup_steps,
      warmup_start_step=warmup_start_step,
      drop_factor=reduce_lr_factor,
    )
    if "lr_drop_steps" in config:
      # "Multi-drop" lr schedule with explicit steps at which to drop the lr.
      lr_schedule_fn = schedule.CustomDropCompressionSchedule(
        drop_steps=config.pop("lr_drop_steps"), **schedule_kwargs)
    else:
      lr_schedule_fn = schedule.CompressionSchedule(
        drop_after=reduce_lr_after, **schedule_kwargs)

    optimizer_cls = get_keras_optimizer(config.pop('name', 'adam'))
    optimizer = optimizer_cls(learning_rate=lr_schedule_fn, **config)
    return optimizer, lr_schedule_fn

  def _init_transforms(self, transform_config=EMPTY_DICT):
    """Creates the encoder, the optional decoder, and the prior (or its variables).

    Every entry of `transform_config` is first copied onto the instance as an attribute.
    The attributes read afterwards are `posterior_type`, `prior_type`, `encoder_units`,
    `encoder_activation`, `decoder_units`, `decoder_activation`, `ar_activation` and,
    depending on the chosen posterior / prior, `ar_hidden_units`, `iaf_stacks`,
    `maf_stacks`.

    Sets `self.encoder`, `self.decoder` (None when no decoder is used) and `self._prior`.
    For the mixture priors ('gsm_K', 'gmm_K', 'lsm_K', 'lmm_K') `self._prior` stays None
    and the variables `self.logits`, `self.log_scale` and `self.loc` are created instead;
    `prior()` assembles the distribution from them.

    Args:
      transform_config: mapping of transform options (see above).

    Raises:
      ValueError: if `prior_type` is not recognized.
    """
    dtype = self.dtype
    # NOTE: this writes straight into the instance dict, bypassing `__setattr__`; a key that
    # collides with an existing attribute silently overwrites it.
    self.__dict__.update(transform_config)
    posterior_type = self.posterior_type
    latent_dim = self.latent_dim
    ar_activation = get_activation(self.ar_activation, dtype)

    if posterior_type in ('gaussian', 'iaf'):
      # The encoder outputs a loc and a scale per latent dimension. IAF currently uses a base
      # Gaussian distribution conditioned on x, hence the same output size.
      encoder_output_dim = latent_dim * 2
      if posterior_type == 'iaf':
        self._iaf_mades = [
          tfb.AutoregressiveNetwork(params=2, activation=ar_activation,
                                    hidden_units=self.ar_hidden_units)
          for _ in range(self.iaf_stacks)]
    else:
      encoder_output_dim = latent_dim

    # We always require an encoder network in order to produce the variational distribution
    # Q(Y|X). encoder_units = [] gives the minimal network.
    # `list(...)` lets the unit spec be a tuple as well as a list.
    encoder = make_mlp(
      units=list(self.encoder_units) + [encoder_output_dim],
      activation=get_activation(self.encoder_activation, dtype),
      name="encoder",
      input_shape=[self.data_dim],
      dtype=dtype,
    )

    # A decoder network is optional when dim(Y) == dim(X).
    # With decoder_units = [] the code still uses a decoder network mapping from latent_dim to
    # data_dim. To request "no decoder network at all", the convention is decoder_units=[0].
    if check_no_decoder(self.decoder_units):
      decoder = None
      assert self.data_dim == latent_dim, (
        "Running without a decoder requires data_dim == latent_dim; "
        f"got data_dim={self.data_dim}, latent_dim={latent_dim}.")
      logging.info('Not using decoder')
    else:  # decoder_units = [] allowed
      decoder = make_mlp(
        list(self.decoder_units) + [self.data_dim],
        get_activation(self.decoder_activation, dtype),
        "decoder",
        [latent_dim],
        dtype,
      )

    self.encoder = encoder
    self.decoder = decoder

    self._prior = None
    if self.prior_type == "deep":
      self._prior = tfc_utils.MyDeepFactorized(
        batch_shape=[self.latent_dim], dtype=self.dtype)
    elif self.prior_type == 'std_gaussian':  # use 'gmm_1' for gaussian prior with learned mean/scale
      self._prior = tfd.MultivariateNormalDiag(
        loc=tf.zeros([self.latent_dim], dtype=self.dtype),
        scale_diag=tf.ones([self.latent_dim], dtype=self.dtype))
    elif self.prior_type == 'maf':
      # see https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/MaskedAutoregressiveFlow
      # and https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/AutoregressiveNetwork
      self._maf_mades = [
        tfb.AutoregressiveNetwork(params=2, activation=ar_activation,
                                  hidden_units=self.ar_hidden_units)
        for _ in range(self.maf_stacks)]
      base_distribution = tfd.MultivariateNormalDiag(
        loc=tf.zeros([self.latent_dim], dtype=self.dtype),
        scale_diag=tf.ones([self.latent_dim], dtype=self.dtype))
      self._prior = af_transform(base_distribution, self._maf_mades, permute=True, iaf=False)
    elif self.prior_type[:4] in ("gsm_", "gmm_", "lsm_", "lmm_"):  # mixture prior; specified like 'gmm_2'
      # This only implements a scalar mixture for each dimension, and the dimensions themselves
      # are still fully factorized just like tfc.DeepFactorized.
      components = int(self.prior_type[4:])
      shape = (self.latent_dim, components)
      self.logits = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
      self.log_scale = tf.Variable(
        tf.random.normal(shape, mean=2., dtype=self.dtype))
      if self.prior_type[1] == "s":  # scale mixture ('gsm_' / 'lsm_'): all components centred at 0
        self.loc = 0.
      else:
        self.loc = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
    else:
      raise ValueError(f"Unknown prior_type: '{self.prior_type}'.")

  def prior(self, conv_unoise=False):
    """Returns the prior distribution over the latents.

    Args:
      conv_unoise: if True, the prior is wrapped in `tfc_utils.MyUniformNoiseAdapter`
        (convolution with uniform noise, for the NTC compression model).

    Returns:
      The distribution stored in `self._prior` when there is one; otherwise a
      `tfd.MixtureSameFamily` assembled from `self.logits`, `self.loc` and
      `self.log_scale`, with Normal components for 'g*' prior types and Logistic
      components otherwise.

    Raises:
      ValueError: if no prior object exists and `prior_type` is not a mixture spec.
    """
    if self._prior is not None:
      prior = self._prior
    elif self.prior_type[:4] in ("gsm_", "gmm_", "lsm_", "lmm_"):
      component_cls = tfd.Normal if self.prior_type.startswith("g") else tfd.Logistic
      prior = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=self.logits),
        # Despite its name, `log_scale` is mapped to a positive scale with softplus, not exp.
        components_distribution=component_cls(
          loc=self.loc, scale=tf.math.softplus(self.log_scale)),
      )
    else:
      # Without this branch the code below would fail with an UnboundLocalError.
      raise ValueError(f"Unknown prior_type: '{self.prior_type}'.")
    if conv_unoise:  # convolve with uniform noise for NTC compression model
      prior = tfc_utils.MyUniformNoiseAdapter(prior)
    return prior

  @property
  def global_step(self):
    """The optimizer's iteration counter, used as the global training step."""
    step = self.optimizer.iterations
    return step

  @property
  def _scheduled_lr(self):
    """Current learning rate, queried from the optimizer as float32 (logging/debugging only).

    Should equal `self._lr_schedule_fn(self.global_step)`.
    Also see https://github.com/google-research/google-research/blob/bb5e979a2d9389850fda7eb837ef9c8b8ba8244b/vct/src/models.py#672
    """
    current_lr = self.optimizer._decayed_lr(tf.float32)
    return current_lr

  @property
  def _scheduled_rd_lambda(self):
    """Returns the rd-lambda for the current step as a tensor.

    No schedule is applied at the moment: the value is simply `self._rd_lambda`.
    Based on https://github.com/google-research/google-research/blob/master/vct/src/models.py#L400
    """
    # A lambda warmup (a constant factor on rd_lambda during the first part of training, only
    # for rd_lambda <= 0.01) used to be applied here and is currently disabled.
    return tf.convert_to_tensor(self._rd_lambda)

  def encode_decode(self, x, training):
    """Runs one stochastic encode/decode pass and computes the rate-distortion loss.

    Args:
      x: batch of data; axis 0 is treated as the batch axis
        (presumably [batch_size, data_dim] -- TODO confirm with callers).
      training: only consulted for the 'uniform' posterior. If True, uniform noise in
        [-0.5, 0.5) is added around the encoder output; otherwise the encoder output is
        quantized by a `tfc.ContinuousBatchedEntropyModel`.

    Returns:
      A pair `(loss, metrics)` where `loss = rate + rd_lambda * distortion` (both terms
      averaged over the batch) and `metrics` is a `Metrics` object recording `loss`,
      `rate`, `distortion` and `scheduled_lr`.

    Raises:
      NotImplementedError: for an unknown `posterior_type` or `distort_type`.
    """
    if self.posterior_type in ('gaussian', 'iaf'):
      encoder_res = self.encoder(x)
      qy_loc = encoder_res[..., :self.latent_dim]
      # `softplus_inv_1` is presumably softplus^{-1}(1), so that a raw output of 0 maps to a
      # unit scale -- TODO confirm against rdvae.rdvae_mlp.
      qy_scale = tf.nn.softplus(encoder_res[..., self.latent_dim:] + softplus_inv_1)
      if self.scale_lb:
        qy_scale = qy_scale + self.scale_lb
      y_dist = tfd.MultivariateNormalDiag(loc=qy_loc, scale_diag=qy_scale, name="q_y")
      if self.posterior_type == 'iaf':
        y_dist = af_transform(y_dist, self._iaf_mades, permute=True, iaf=True)

      # Y ~ Q(Y|X); batch_size by latent_dim. Not using IWAE as it doesn't exactly correspond
      # to our R-D objective here.
      y_tilde = y_dist.sample()
      log_q_tilde = y_dist.log_prob(y_tilde)  # [batch_size]
      prior = self.prior(conv_unoise=False)
    elif self.posterior_type == 'uniform':
      encoder_res = self.encoder(x)
      prior = self.prior(conv_unoise=True)  # Balle VAE
      if not training:
        # Hard quantization; do it the proper/fancy way (possibly with smart offset) using tfc.
        entropy_model = tfc.ContinuousBatchedEntropyModel(
          prior, coding_rank=CODING_RANK, compression=False)
        y_tilde = entropy_model.quantize(encoder_res)
      else:
        y_dist = tfd.Uniform(low=encoder_res - 0.5, high=encoder_res + 0.5, name="q_y")
        y_tilde = y_dist.sample()  # Y ~ Q(Y|X); batch_size by latent_dim
      log_q_tilde = 0.  # The unit-width uniform density is 1, i.e. log q = 0.
    else:
      raise NotImplementedError(f'unknown posterior_type={self.posterior_type}')

    log_prior = prior.log_prob(y_tilde)
    if self.prior_type not in ('maf', 'std_gaussian'):
      # Fully factorized priors return one log-density per latent dimension; sum them to
      # get one number per example. The 'maf' and 'std_gaussian' priors are joint
      # distributions over the whole latent vector and already return [batch_size];
      # summing those over the last axis would wrongly add up the batch.
      log_prior = tf.reduce_sum(log_prior, axis=-1)
    rates = log_q_tilde - log_prior

    # "No decoder" is encoded as `self.decoder is None`; the latents then are the reconstruction.
    x_hat = self.decoder(y_tilde) if self.decoder is not None else y_tilde

    # Compute losses.
    axes_except_batch = list(range(1, len(x.shape)))
    squared_error = tf.math.squared_difference(x, x_hat)
    if self.distort_type == 'mse':
      distortions = tf.reduce_mean(squared_error, axis=axes_except_batch)
    elif self.distort_type == 'sse':
      distortions = tf.reduce_sum(squared_error, axis=axes_except_batch)
    elif self.distort_type == 'half_sse':
      distortions = 0.5 * tf.reduce_sum(squared_error, axis=axes_except_batch)
    else:
      raise NotImplementedError(f'unknown distort_type={self.distort_type}')
    distortion = tf.reduce_mean(distortions)  # Avg across batch.
    rate = tf.reduce_mean(rates)  # Avg across batch.
    loss = rate + self._rd_lambda * distortion

    metrics = Metrics.make()
    metrics.record_scalars(
      dict(loss=loss, rate=rate, distortion=distortion, scheduled_lr=self._scheduled_lr))
    return loss, metrics

  def train_step(self, image_batch):
    """Performs one optimization step on `image_batch` and returns the recorded metrics."""
    # Forward pass under the tape so the loss can be differentiated.
    with tf.GradientTape() as tape:
      loss, metrics = self.encode_decode(image_batch, training=True)

    # Backward pass and parameter update over all trainable variables of the module.
    variables = self.trainable_variables
    grads = tape.gradient(loss, variables)
    grads_and_vars = zip(grads, variables)
    self.optimizer.apply_gradients(grads_and_vars)
    return metrics

  # def test_step(self, image_batch):
  def validation_step(self, image_batch, training=False) -> Metrics:
    """Evaluates `image_batch` without updating any parameters; returns the metrics only."""
    _, metrics = self.encode_decode(image_batch, training=training)
    return metrics

  def evaluate(self, images) -> Metrics:
    """
    Used for getting final results. This is a generator: it yields one `Metrics`
    object per evaluated item.
    If a tensor is provided, it is split along its first axis and each resulting
    batch-size-1 tensor is evaluated in order. Otherwise, we assume the caller has
    passed in an iterable of tensors (we do not verify that each one has batch size 1).
    :param images: a `tf.Tensor` batch, or an iterable of tensors.
    :return: yields the metrics of each item, computed with training=False.
    """
    if isinstance(images, tf.Tensor):
      # One piece per example along axis 0; each piece keeps a leading batch axis of size 1.
      images = tf.split(images, images.shape[0])

    for single in images:
      _, metrics = self.encode_decode(single, training=False)
      yield metrics

  def sample(self, num_samples):
    """
    Draw samples from the compression model: sample latents from the prior and, if the
    model has a decoder, pass them through it.
    :param num_samples: int
    :return: a [num_samples, data_shape] tensor.
    """
    if self.prior_type == 'deep' and self.posterior_type == 'uniform':
      # Sample from the discretized prior (as would be in actual entropy coding).
      # The prior is not actually convolved with unoise here; the adapter is only used to get
      # quantized samples.
      prior = self.prior(conv_unoise=True)
      samples = prior.sample(num_samples, quantized=True)
    else:
      samples = self.prior(conv_unoise=False).sample(num_samples)
    # "No decoder" is encoded as `self.decoder is None`, so test for exactly that rather than
    # relying on the truthiness of the network object.
    if self.decoder is not None:
      samples = self.decoder(samples)
    return samples

  def create_images(self, num_samples=1000, coords_to_scatter=(0, 1), title=None):
    """
    Plot img metrics. Here we simply visualize samples from the model as a 2-D scatter plot.
    :param num_samples: number of model samples to draw and plot.
    :param coords_to_scatter: pair (i, j) of sample coordinates used as the x / y axes.
    :param title: optional plot title.
    :return: a dict {'samples': img}, where img is the figure rendered by
      `plot_utils.fig_to_np_arr`.
    """
    import matplotlib.pyplot as plt
    import plot_utils

    i, j = coords_to_scatter
    # Sample before creating the figure, so a sampling failure leaves no figure behind.
    model_samples = self.sample(num_samples)

    fig, ax = plt.subplots()
    try:
      ax.scatter(model_samples[:, i], model_samples[:, j], marker='.', alpha=0.3, label=r'$\nu$')
      ax.legend()
      ax.set_aspect('equal')
      if title:
        ax.set_title(title)
      img = plot_utils.fig_to_np_arr(fig)
    finally:
      # pyplot keeps every open figure alive in a global registry; always release it, even
      # if plotting or rasterization raised.
      plt.close(fig)

    return {'samples': img}
