#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: saliency.py
# Created Date: Tuesday, October 6th 2020, 9:57:21 pm
# Author: Chirag Raman
#
# Copyright (c) 2020 Chirag Raman
###


import typing
from collections import OrderedDict
from typing import Tuple

import numpy as np
from torch.distributions import Distribution, Normal

from data.types import Seq2SeqSamples

# Ordered mapping from an (observed_start, offset) key to a scalar score.
MAP_TYPE = typing.OrderedDict[typing.Tuple[int, int], float]

def sequence_saliency(inputs: Seq2SeqSamples, predictions: Normal,
                      prior: typing.Optional[Distribution] = None
                      ) -> Tuple[MAP_TYPE, MAP_TYPE, MAP_TYPE]:
    """ Compute the closed-form saliency of observed sequences

    For every observed sequence, the entropy of its predicted future
    distribution is averaged into a single scalar. The squared
    finite-difference gradient of these entropies, ordered by observed start
    time, is used as the saliency signal. It is optionally combined with a
    prior log-probability.

    Args:
        inputs      --  Samples providing ``observed_start`` and ``offset``
                        tensors with one entry per sequence.
        predictions --  Normal distribution over future predictions. Its
                        ``loc[0]`` / ``scale[0]`` are split along dim 1 into
                        one slice per sequence (presumably the batch
                        dimension; confirm against the caller).
        prior       --  Optional distribution whose ``log_prob`` is evaluated
                        at each sequence's observed start time.

    Returns a tuple of the entropy map, the saliency map ignoring the prior,
    and the saliency map incorporating the prior. All three maps are keyed by
    ``(observed_start, offset)`` and ordered by observed start time.

    Raises:
        ValueError  --  If fewer than two sequences are supplied, since a
                        finite-difference gradient needs at least two points.
    """
    # Local import: the file only imports ``torch.distributions`` at top level.
    import torch

    future_means = predictions.loc[0]
    future_scales = predictions.scale[0]
    # zip stops at the shortest iterable, pairing the i-th observed start and
    # offset with the i-th dim-1 slice of the predicted means/scales.
    seqs = zip(
        inputs.observed_start.detach().cpu(),
        inputs.offset.detach().cpu(),
        future_means.split(1, dim=1),
        future_scales.split(1, dim=1)
    )
    entropies = {}
    for (obs_start, offset, fut_mean, fut_std) in seqs:
        fut = Normal(loc=fut_mean.squeeze(1), scale=fut_std.squeeze(1))
        # Collapse the element-wise entropies into one scalar per sequence
        fut_entropy = fut.entropy().mean().detach().cpu()
        entropies[(int(obs_start), int(offset))] = float(fut_entropy)

    if len(entropies) < 2:
        raise ValueError(
            "sequence_saliency needs at least two sequences to compute an "
            f"entropy gradient, got {len(entropies)}"
        )

    # Sort entropy values by observed start time. sorted() is stable, so
    # entries sharing a start time keep their original relative order.
    entropies = OrderedDict(sorted(entropies.items(), key=lambda x: x[0][0]))

    # Finite-difference gradient of entropy along the sorted order.
    # NOTE: np.gradient assumes unit spacing between neighbouring entries;
    # it does not use the actual observed start times as coordinates.
    gradients = np.gradient(list(entropies.values()))
    # Since this is a 1d problem, J'(x)J(x) is 1x1, therefore its determinant
    # is the square of the gradient itself.
    sq_gradients = np.square(gradients)

    # The simpler saliency for a sequence x ignoring the prior is
    # S(x) := gradient_x^2
    saliency_map_noprior = OrderedDict(zip(entropies.keys(), sq_gradients))

    # The saliency for a sequence x incorporating the prior is
    # S(x) = -log prior(start_x) + 0.5 log(gradient_x^2)
    # A zero gradient yields -inf here (numpy warns about a divide by zero).
    # NOTE(review): the prior is evaluated at the observed start time; if it
    # is meant to be a prior over offsets, pass ``offset`` instead. Confirm.
    saliency_map = OrderedDict()
    for (obs_start, offset), sq_grad in saliency_map_noprior.items():
        prior_log_prob = 0.0
        if prior is not None:
            # Distribution.log_prob expects a Tensor. With argument
            # validation enabled (the default in recent torch releases), a
            # plain Python int is rejected.
            prior_log_prob = float(prior.log_prob(torch.as_tensor(obs_start)))
        saliency_map[(obs_start, offset)] = float(
            -prior_log_prob + 0.5 * np.log(sq_grad))

    return entropies, saliency_map_noprior, saliency_map