# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.


import common_utils


class CategoricalSampler:
    """Categorical sampler that returns samples in one-hot encoding."""

    def __init__(self, greedy):
        """Initialize the sampler.

        Args:
            greedy: if truthy, `sample` picks the highest-probability action
                    in each row instead of drawing from the distribution.
        """
        self.greedy = greedy

    def clamp_prob(self, probs, min_prob):
        """Clamp probabilities element-wise into [min_prob, 1 - min_prob].

        Args:
            probs: tensor of probabilities
            min_prob: lower bound; the upper bound is 1 - min_prob
        """
        return probs.clamp(min_prob, 1 - min_prob)

    def sample(self, probs):
        """Sample an action from a categorical distribution given probabilities.

        Args:
            probs: probabilities returned by model forward
                   [batch, num_actions]
        return: result of common_utils.one_hot on the chosen indices
                (presumably [batch, num_actions]; see that helper)
        """
        if self.greedy:
            # max over dim 1 returns (values, indices); keep the indices.
            samples = probs.max(1, keepdim=True)[1]
        else:
            # Draw one action index per row.
            samples = probs.multinomial(1)
        samples = common_utils.one_hot(samples, probs.size(1))
        return samples

    def get_log_prob(self, probs, samples):
        """Compute the log probability of the given samples.

        Args:
            probs: probabilities returned by model forward
                   [batch, num_actions]
            samples: samples returned by this sampler given probs
                   [batch, num_actions]
        return: log_p: [batch]
        """
        common_utils.assert_eq(samples.size(), probs.size())
        # Multiplying by the sample mask and summing over dim 1 picks out
        # the probability of the sampled action in each row.
        p = (probs * samples.float()).sum(1)
        return p.log()

    def get_entropy(self, p):
        """Compute the entropy of each row of p.

        Entries that are exactly zero contribute zero (the limit of
        x * log(x) as x -> 0) rather than producing NaN from 0 * log(0).

        Args:
            p: probabilities [batch, num_actions]
        return: ent: [batch]
        """
        # Take the log of a copy where zeros are replaced by 1 so that
        # log() is finite there; the product p * log(1) is then exactly 0
        # and no NaN/inf leaks into either the value or the gradient.
        safe_p = p.masked_fill(p == 0, 1.0)
        ent = -(p * safe_p.log()).sum(1)
        return ent
