from copy import deepcopy
from typing import List

import numpy as np
import scipy
import torch

from action_masking.rlsampling.integration.geometric import geometric_integration_gaussian
from action_masking.rlsampling.sampling import BaseZonoRandomWalkSampler, GaussianRDHRSampler
from action_masking.rlsampling.sets.zonotope import Zonotope, zonotope_contains_batch
from action_masking.rlsampling.torch_distributions.convex_set_normal import ConvexSetNormal


class ZonotopeNormal(ConvexSetNormal):
    """Gaussian distribution paired with one zonotope per batch element.

    For every batch element a zonotope and a random-walk sampler
    (``SamplingAlgorithm``) are kept; ``sample`` draws points with those
    samplers. ``log_prob`` currently returns the clamped *unnormalized*
    Gaussian log-density (see the note in ``log_prob``).
    """

    def __init__(
        self,
        loc: torch.Tensor,
        covariance_matrix: torch.Tensor,
        zonotopes: List[Zonotope] = None,
        generator: np.ndarray = None,
        center: np.ndarray = None,
        SamplingAlgorithm: BaseZonoRandomWalkSampler = GaussianRDHRSampler,
        validate_args: bool = False,
    ):
        """Create the distribution.

        Args:
            loc: Mean, either ``(d,)`` or batched ``(B, d)``.
            covariance_matrix: Covariance, either ``(d, d)`` or ``(B, d, d)``.
            zonotopes: One zonotope per batch element. Takes precedence over
                ``generator``/``center``.
            generator: Generator matrices indexed along the first axis (one per
                batch element); used only when ``zonotopes`` is None.
            center: Centers indexed along the first axis; used only when
                ``zonotopes`` is None.
            SamplingAlgorithm: Sampler class, instantiated once per zonotope.
            validate_args: Forwarded to the parent distribution.

        Raises:
            AttributeError: If neither ``zonotopes`` nor both ``generator`` and
                ``center`` are provided.
        """
        if zonotopes is None:
            if generator is None or center is None:
                raise AttributeError("Either zonotope or generator and center must be provided.")

            # ``self.batch_shape`` does not exist yet (it is set by
            # ``super().__init__``), so take the batch size from ``loc``.
            n_sets = loc.shape[0] if loc.dim() > 1 else 1
            sets = [Zonotope(generator[i], center[i]) for i in range(n_sets)]
        else:
            sets = zonotopes

        super().__init__(loc, covariance_matrix, sets, validate_args=validate_args)
        self._sets: List[Zonotope] = sets

        # Add batch dimension
        if loc.dim() == 1:
            loc = loc.unsqueeze(0)
        if covariance_matrix.dim() == 2:
            covariance_matrix = covariance_matrix.unsqueeze(0)

        # NumPy copies (batched) used by the zonotope / sampler code.
        self.loc_np = loc.cpu().detach().numpy()
        self.cov_mat_np = covariance_matrix.cpu().detach().numpy()

        self._samplers = [
            SamplingAlgorithm(self._sets[i], mean=self.loc_np[i, :, np.newaxis], cov=self.cov_mat_np[i])
            for i in range(len(self._sets))
        ]

    @property
    def mode(self) -> torch.Tensor:
        """Return, per batch element, the mean or a point pulled inside the zonotope.

        If the mean lies in its zonotope it is returned unchanged. Otherwise a
        boundary point ``bp`` is requested from the zonotope for the direction
        ``mean - c`` starting at the center ``c``, and the result is
        ``c + 0.95 * (bp - c)``, i.e. 95% of the way from ``c`` to ``bp``.

        Returns:
            Tensor with a leading batch dimension (added if ``loc`` is 1-D).
        """
        # ``clone`` instead of ``deepcopy``: deepcopy raises on non-leaf tensors
        # (e.g. network outputs), and a deep-copied leaf that requires grad
        # cannot be modified in place below.
        mode = self.loc.clone()
        if mode.ndim == 1:
            mode = mode.unsqueeze(0)
        contains = zonotope_contains_batch(self._sets, self.loc_np)

        for i, zono in enumerate(self._sets):
            if contains[i]:
                continue
            center = zono.c
            bp = np.asarray(
                zono.boundary_point(direction=self.loc_np[i, :, np.newaxis] - center, point=center)
            )
            inner_point = center + 0.95 * (bp - center)
            mode[i] = torch.as_tensor(inner_point, dtype=mode.dtype, device=mode.device).squeeze(-1)

        return mode

    def sample(self, sample_shape=torch.Size()):
        """Draw samples with the per-zonotope random-walk samplers.

        For every zonotope an initial walk of ``d**3 - 1`` steps is run
        (presumably as burn-in), then each sample is taken after a further
        ``d`` steps from the previous point.

        Args:
            sample_shape: ``()`` or a one-element shape ``(n,)``.

        Returns:
            With a single zonotope: ``(d,)`` for one sample, else ``(n, d)``.
            With several zonotopes: ``self._extended_shape(sample_shape)``
            (``(#Z, d)`` or ``(n, #Z, d)``).

        Raises:
            ValueError: If the samplers fail more than 10 times in total during
                this call.
        """
        shape = self._extended_shape(sample_shape)

        assert len(sample_shape) < 2  # Simplification

        n_samples = 1 if not sample_shape else sample_shape[0]
        max_failures = 10
        failures = 0  # shared across all zonotopes of this call

        def sample_one_zono(zono_idx: int, n_samples: int) -> torch.Tensor:
            nonlocal failures
            sampler = self._samplers[zono_idx]
            while True:
                try:
                    samples = np.zeros((n_samples, self.d))

                    p = sampler.sample(walk_length=self.d**3 - 1)
                    for i in range(n_samples):
                        p = sampler.sample(walk_length=self.d, start_point=p)
                        samples[i] = p[:, 0]

                    if n_samples == 1:
                        samples = samples[0]

                    return torch.tensor(samples)
                except ValueError as e:
                    print(f"Error in sample: {e}")
                    failures += 1
                    if failures > max_failures:
                        raise ValueError(e) from e
                    # otherwise retry the whole walk for this zonotope

        # return (n, d)
        if len(self._sets) == 1:
            return sample_one_zono(0, n_samples)

        # return (n, #Z, d) or (#Z, d)
        samples = torch.zeros(shape)
        for z_idx in range(len(self._sets)):
            if sample_shape:
                samples[:, z_idx] = sample_one_zono(z_idx, n_samples)
            else:
                samples[z_idx] = sample_one_zono(z_idx, n_samples)

        return samples

    def log_prob(self, value: torch.Tensor):
        """Return the Gaussian log-density of ``value``, clamped below at -10.

        NOTE: this is the *unnormalized* Gaussian log-density. It neither
        subtracts ``log`` of the zonotope mass (``compute_normalizing_constant``)
        nor assigns the floor value to points outside the zonotope. Both steps
        used to sit here as unreachable code after an early ``return``; add
        them back deliberately if a truncated density is required.
        """
        min_eps_log = -10.0

        normal_log_probs = super().normal_log_prob(value)
        return torch.clamp(normal_log_probs, min=min_eps_log)

    def compute_normalizing_constant(self) -> torch.Tensor:
        """Return (and cache) the Gaussian mass inside each zonotope.

        Uses ``geometric_integration_gaussian`` with ``abs_error=1e-3``. When
        the integration raises for a zonotope, a warning is printed and that
        entry is left at 1.

        Returns:
            Tensor of shape ``(#Z,)``.
        """
        if self._normalizing_constant is not None:
            return self._normalizing_constant

        assert (
            self.d <= 4
        ), "The current (geometric) integration technique cannot handle dimensions higher than 4 in reasonable time."

        normalizing_constant = torch.ones(len(self._sets))

        for i, zono in enumerate(self._sets):
            dist = scipy.stats.multivariate_normal(self.loc_np[i], self.cov_mat_np[i])
            try:
                normalizing_constant[i] = torch.tensor(
                    geometric_integration_gaussian(zono, dist, abs_error=1e-3)
                )
            except Exception:  # best effort; do not swallow KeyboardInterrupt/SystemExit
                print("Error in geometric integration. Setting value to 1.")

        # Cache only once every entry has been computed.
        self._normalizing_constant = normalizing_constant
        return self._normalizing_constant

    def get_actions(self, deterministic: bool = False) -> torch.Tensor:
        """Return one action per batch element.

        Args:
            deterministic: If True, use the mean when it lies in its zonotope,
                otherwise the zonotope's boundary point in the direction
                ``mean - c``. If False, draw a random sample.

        Returns:
            ``(B, d)`` when deterministic; otherwise the result of
            ``self.sample()``.
        """
        if not deterministic:
            # One sample per batch element. Passing ``batch_shape`` as
            # ``sample_shape`` would draw B samples for every zonotope.
            return self.sample()

        n_batch = len(self._sets)
        loc = self.loc if self.loc.dim() > 1 else self.loc.unsqueeze(0)
        actions = torch.zeros(n_batch, self.d)

        for i, zono in enumerate(self._sets):
            mean = self.loc_np[i, :, np.newaxis]

            if zono.contains_point(mean):
                actions[i] = loc[i]
            else:
                # TODO: This is only accurate for a diagonal cov.
                direction = mean - zono._c
                point = zono.boundary_point(direction)
                actions[i] = torch.tensor(point[:, 0])

        return actions


# def NormalizingConstantFunction(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, loc: torch.tensor, cov_mat: torch.tensor, sets: List[Zonotope]):
#         n, d = input.shape
#         assert d <= 4
#         ctx.save_for_backward(input)

#         loc_np = loc.cpu().detach().numpy()
#         cov_mat_np = cov_mat.cpu().detach().numpy()

#         nc = torch.ones(n)

#         for i in range(n):
#             dist = scipy.stats.multivariate_normal(loc_np[i], cov_mat_np[i])
#             try:
#                 nc[i] = torch.tensor(
#                     geometric_integration_gaussian(sets[i], dist, abs_error=1e-3)
#                 )
#             except:
#                 print("Error in geometric integration. Setting value to 1.")

#         return nc

#     @staticmethod
#     def backward(ctx, grad_output):
#         loc, cov_mat = ctx.saved_tensors

        # Nevermind! You cannot compute that, since the function which you would
        # have to integrate has a multivariate output!

        # gard_cov 



