import numpy as np
import torch
from stable_baselines3.common.distributions import DiagGaussianDistribution


class GeneratorDiagGaussian(DiagGaussianDistribution):
    """Diagonal Gaussian in "generator space" with an affine image in action space.

    Samples ``x`` are drawn from an element-wise ``Normal(mean, std)``
    (``generator_dist``).  They are related to action space through the affine
    map ``a = c + G @ x``, where ``c`` (centers) and ``G`` (generators) are
    supplied to :meth:`proba_distribution`.  The image of the Gaussian under
    that map is stored as ``transformed_dist`` and is what :meth:`log_prob`
    evaluates.

    :meth:`sample`, :meth:`mode` and :meth:`entropy` operate in generator
    space; :meth:`log_prob` operates in action space.
    """

    def __init__(self, action_dim: int):
        """Create an unparameterized distribution.

        :param action_dim: forwarded unchanged to ``DiagGaussianDistribution``.
        """
        super().__init__(action_dim)
        # Distribution in generator space
        self.generator_dist: torch.distributions.Normal = None
        # Distribution in action space
        self.transformed_dist: torch.distributions.MultivariateNormal = None

        # Attribute used by the parent class; mirrors the generator-space distribution.
        self.distribution = self.generator_dist

        # Offset (c) and linear part (G) of the affine map a = c + G @ x.
        self.c: torch.Tensor = None
        self.G: torch.Tensor = None

    def proba_distribution(
        self,
        mean_actions: torch.Tensor,
        log_std: torch.Tensor,
        generators: np.ndarray,
        centers: np.ndarray,
    ) -> "GeneratorDiagGaussian":
        """Parameterize the generator-space and action-space distributions.

        :param mean_actions: mean in generator space; a 1-D input is treated
            as a batch of size one for the action-space distribution.
        :param log_std: log standard deviation, broadcast against ``mean_actions``.
        :param generators: linear part ``G`` of the affine map; when ``centers``
            is 1-D, ``generators`` is given a leading batch dimension as well.
        :param centers: offset ``c`` of the affine map.
        :return: ``self``, to allow call chaining.
        """
        action_std = torch.ones_like(mean_actions) * log_std.exp()
        self.generator_dist = torch.distributions.Normal(mean_actions, action_std)

        # The action-space computations below are batched; promote 1-D inputs.
        if mean_actions.ndim == 1:
            mean_actions = mean_actions.unsqueeze(0)
            action_std = action_std.unsqueeze(0)

        self.c = torch.tensor(centers, dtype=mean_actions.dtype, device=mean_actions.device)
        self.G = torch.tensor(generators, dtype=mean_actions.dtype, device=mean_actions.device)
        if self.c.ndim == 1:
            self.c = self.c.unsqueeze(0)
            self.G = self.G.unsqueeze(0)

        # matmul (rather than bmm) broadcasts over the batch dimension, so a
        # single (c, G) pair can be combined with a batch of means.
        mean_actions_tr = self.c + torch.matmul(self.G, mean_actions.unsqueeze(-1)).squeeze(-1)

        # For x ~ N(mu, diag(sigma^2)) the image a = c + G x has covariance
        # G diag(sigma^2) G^T.  MultivariateNormal's second positional argument
        # is a covariance matrix, so the *variance* must be used here, not the
        # standard deviation.
        action_var = action_std.pow(2)
        cov_tr = torch.matmul(
            self.G,
            torch.matmul(torch.diag_embed(action_var), self.G.transpose(-1, -2)),
        )
        self.transformed_dist = torch.distributions.MultivariateNormal(mean_actions_tr, cov_tr)

        self.distribution = self.generator_dist

        return self

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Log-density of generator-space ``actions`` after mapping to action space.

        ``actions`` are mapped through ``a = c + G @ x`` and evaluated under
        ``transformed_dist``.  :meth:`proba_distribution` must have been
        called beforehand.

        :param actions: points in generator space; 1-D input is treated as a
            batch of size one.
        :return: log-probability under the action-space distribution.
        """
        if actions.ndim == 1:
            actions = actions.unsqueeze(0)
        a_tr = self.c + torch.matmul(self.G, actions.unsqueeze(-1)).squeeze(-1)
        return self.transformed_dist.log_prob(a_tr)

    def sample(self) -> torch.Tensor:
        """Draw a reparameterized sample in generator space."""
        return self.generator_dist.rsample()

    def entropy(self) -> torch.Tensor:
        """Return the entropy of the generator-space distribution.

        NOTE(review): this is the element-wise entropy of ``generator_dist``
        and is not summed over the last dimension, while ``log_prob`` is a
        joint log-density in action space -- confirm this is what callers expect.
        """
        return self.generator_dist.entropy()

    def mode(self) -> torch.Tensor:
        """Return the mean of the generator-space distribution."""
        return self.generator_dist.mean

    def apply_masking(self, action_mask) -> None:
        """No-op placeholder; ``action_mask`` is ignored."""
        # TODO: this method is required by the framework. So think about restructuring the code to actually make use of it.
        pass
