import numpy as np
import torch
from stable_baselines3.common.distributions import DiagGaussianDistribution, sum_independent_dims

from action_masking.rlsampling.sets import Zonotope
from action_masking.rlsampling.torch_distributions import ZonotopeNormal


class ZonotopeDiagGaussian(DiagGaussianDistribution):
    """Diagonal-Gaussian action distribution backed by a zonotope-aware normal.

    ``sample``, ``mode`` and ``entropy`` are delegated to a ``ZonotopeNormal``
    built from the policy's mean / log-std and one ``Zonotope`` per action.
    ``log_prob`` is evaluated under a plain diagonal
    ``torch.distributions.Normal`` with the same mean and standard deviation.
    """

    def __init__(self, action_dim: int):
        super().__init__(action_dim)
        # Zonotope-aware distribution; used by sample(), mode() and entropy().
        self.distribution: "ZonotopeNormal | None" = None
        # Plain diagonal Gaussian with the same mean/std; used by log_prob().
        self._normal_dist: "torch.distributions.Normal | None" = None

    @staticmethod
    def _build_zonotopes(generators: np.ndarray, centers: np.ndarray, batch_size: int) -> list:
        """Return ``batch_size`` Zonotopes, broadcasting a single zonotope if needed.

        Args:
            generators: Zonotope generators, either for a single zonotope or
                stacked along a leading batch axis (stacked iff ``centers`` is 2-D).
            centers: Zonotope center(s); 1-D for a single zonotope, otherwise
                stacked along a leading batch axis.
            batch_size: Number of actions in the batch.

        Raises:
            ValueError: If more than one zonotope is given but fewer than
                ``batch_size``.
        """
        if centers.ndim == 1:
            # A single, un-batched zonotope: add the leading batch axis.
            centers = centers[np.newaxis, :]
            generators = generators[np.newaxis, :]

        n_zonotopes = centers.shape[0]
        if n_zonotopes == 1:
            # Same zonotope for every action in the batch (previously this
            # raised IndexError for batch_size > 1).
            return [Zonotope(generators[0], centers[0]) for _ in range(batch_size)]
        if n_zonotopes < batch_size:
            raise ValueError(f"Got {n_zonotopes} zonotopes for a batch of {batch_size} actions.")
        return [Zonotope(generators[i], centers[i]) for i in range(batch_size)]

    def proba_distribution(
        self,
        mean_actions: torch.Tensor,
        log_std: torch.Tensor,
        generators: np.ndarray,
        centers: np.ndarray,
    ) -> "ZonotopeDiagGaussian":
        """Set the distribution parameters and return ``self``.

        Args:
            mean_actions: Gaussian mean; 1-D for a single action, otherwise with a
                leading batch dimension.
            log_std: Log standard deviation, broadcast against ``mean_actions``.
            generators: Zonotope generator(s), passed to ``Zonotope``.
            centers: Zonotope center(s), passed to ``Zonotope``. In the batched
                case a 1-D ``centers`` denotes one zonotope shared by the batch.

        Returns:
            ``self``, so calls can be chained as in stable-baselines3.
        """
        # Broadcast log_std to the mean's shape and keep the std strictly positive.
        action_std = (torch.zeros_like(mean_actions) + log_std.exp()).clamp_min(torch.finfo(mean_actions.dtype).eps)
        self._normal_dist = torch.distributions.Normal(mean_actions, action_std)

        # Diagonal matrix built from the standard deviations.
        # NOTE(review): the diagonal holds std, not variance (a std**2 variant was
        # previously commented out); confirm which one ZonotopeNormal expects.
        covariance_matrix = torch.diag_embed(action_std, dim1=-2, dim2=-1)

        if mean_actions.ndim == 1:
            # Single, un-batched action.
            zonotopes = [Zonotope(generators, centers)]
        else:
            zonotopes = self._build_zonotopes(generators, centers, mean_actions.shape[0])

        self.distribution = ZonotopeNormal(mean_actions, covariance_matrix, zonotopes)
        return self

    def sample(self) -> torch.Tensor:
        """Draw actions from the zonotope-aware distribution.

        A 1-D sample gets a leading batch axis of size 1.
        NOTE(review): the base class uses ``rsample`` (reparameterized); this
        uses ``sample``, so gradients presumably do not flow through the
        sampled actions — confirm this is acceptable for the training algorithm.
        """
        sample = self.distribution.sample()
        if sample.ndim == 1:
            sample = sample[np.newaxis, :]
        return sample

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Log-probability of ``actions`` under the plain diagonal Gaussian.

        Per-dimension log-probs are summed with ``sum_independent_dims`` and the
        result is clamped from below at -20.0, presumably to keep very unlikely
        actions from producing extreme losses. Note this is not the
        zonotope-restricted density of ``self.distribution``.
        """
        log_prob = sum_independent_dims(self._normal_dist.log_prob(actions))
        return torch.clamp(log_prob, min=-20.0)

    def mode(self) -> torch.Tensor:
        """Return ``self.distribution.mode`` (used for deterministic actions)."""
        return self.distribution.mode

    def entropy(self) -> torch.Tensor:
        """Return ``self.distribution.entropy()``.

        NOTE(review): stable-baselines3's DiagGaussianDistribution sums
        per-dimension entropies; confirm ZonotopeNormal.entropy() already
        returns one value per sample.
        """
        return self.distribution.entropy()

    def apply_masking(self, action_mask) -> None:
        """No-op; the zonotope constraint is applied in ``proba_distribution``.

        TODO: this method is required by the framework. Consider restructuring
        so that the masking actually happens here.
        """
        pass
