import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from .distributions import DiscreteMixLogistic
from .perception import Perception, MLP
from .encoder import Encoder
from .skill_decoder import _DiscreteLogHead

# Module-level bounds. Judging by the names these are presumably clamp limits
# for a predicted log standard deviation -- TODO confirm against the importing
# code; nothing in the visible part of this file reads them.
MIN_LOG_STD = -20.0
MAX_LOG_STD = 2.0


class DeterministicActor(nn.Module):
    """Goal-conditioned actor that outputs a single tanh-squashed action.

    The observation and the goal are embedded by the same ``perception``
    module; the goal embedding is additionally passed through
    ``goal_encoder``.  The state embedding and the encoded goal are
    concatenated along the last axis and fed to an MLP whose output is
    squashed with ``tanh``.

    Args:
        state_dim: Size of the state embedding (``out_dim`` of the perception
            module and input size of the goal encoder).
        act_dim: Size of the action vector produced by the MLP.
        goal_dim: Size of the encoded goal.
        visual_perception: If true, inputs go through a ``Perception``
            network; otherwise they are passed through unchanged and must
            already be ``state_dim``-sized embeddings.
        hidden_size: Width of the hidden layers (shared by the perception
            network, the goal encoder and the MLP).
        n_hiddens: Number of hidden layers in the MLP.
    """

    def __init__(self, state_dim, act_dim, goal_dim,
                 visual_perception=True, hidden_size=512, n_hiddens=2):
        super().__init__()
        # nn.Identity instead of a bare lambda: a lambda attribute cannot be
        # pickled (breaking e.g. torch.save(model)), while nn.Identity is a
        # parameter-free module, so state_dict keys are unaffected.
        # NOTE(review): the two 128s are presumably the input image
        # height/width expected by Perception -- confirm.
        if visual_perception:
            self.perception = Perception(
                128, 128, out_dim=state_dim, hidden_size=hidden_size)
        else:
            self.perception = nn.Identity()
        self.goal_encoder = Encoder(state_dim, goal_dim, hidden_size)

        hidden_sizes = [hidden_size] * n_hiddens
        self.layers = MLP(state_dim + goal_dim, *hidden_sizes, act_dim)

    def forward(self, obs, goal):
        """Return the action for ``obs`` conditioned on the raw ``goal``."""
        emb_state = self.perception(obs)
        emb_goal = self.perception(goal)
        enc_goal = self.goal_encoder(emb_goal)
        return self.get_pi_from_embeddings(emb_state, enc_goal)

    def encode(self, obs):
        """Embed ``obs`` with the perception module and encode it as a goal."""
        return self.goal_encoder(self.perception(obs))

    def get_pi_from_embeddings(self, emb_state, enc_goal):
        """Return the tanh-squashed action from precomputed embeddings.

        Args:
            emb_state: State embedding (output of ``perception``).
            enc_goal: Encoded goal (output of ``encode``).
        """
        policy_input = torch.cat([emb_state, enc_goal], dim=-1)
        return torch.tanh(self.layers(policy_input))


class ManipulationActor(nn.Module):
    """Goal-conditioned actor with separate end-effector and gripper heads.

    The first ``act_dim - 1`` action dimensions (end-effector) are modelled
    by a ``DiscreteMixLogistic`` distribution parameterised by ``fc_eef``;
    the last dimension (gripper) is modelled by a ``Bernoulli`` whose logit
    comes from ``fc_gripper``.  Gripper actions are encoded as ``+1`` / ``-1``
    in the action vector.

    Args:
        state_dim: Size of the state embedding.
        act_dim: Total action size (end-effector dims plus one gripper dim).
        goal_dim: Size of the encoded goal.
        visual_perception: If true, inputs go through a ``Perception``
            network; otherwise they are passed through unchanged and must
            already be ``state_dim``-sized embeddings.
        n_mixtures: Number of mixture components passed to the
            end-effector head and distribution.
        hidden_size: Width of the goal encoder and of the shared trunk.
    """

    def __init__(self, state_dim, act_dim, goal_dim, visual_perception=True, n_mixtures=10, hidden_size=1024):
        super().__init__()
        # nn.Identity instead of a bare lambda: a lambda attribute cannot be
        # pickled (breaking e.g. torch.save(model)), while nn.Identity is a
        # parameter-free module, so state_dict keys are unaffected.
        # NOTE(review): the two 128s are presumably the input image
        # height/width expected by Perception -- confirm.
        if visual_perception:
            self.perception = Perception(128, 128, out_dim=state_dim)
        else:
            self.perception = nn.Identity()
        self.goal_encoder = Encoder(state_dim, goal_dim, hidden_size=hidden_size)
        # Shared trunk feeding both action heads.
        self.layers = nn.Sequential(
            nn.Linear(state_dim + goal_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.fc_eef = _DiscreteLogHead(hidden_size, act_dim - 1, n_mixtures)
        self.fc_gripper = nn.Linear(hidden_size, 1)
        self.n_mixtures = n_mixtures
        self.act_dim = act_dim

    def forward(self, obs, goal):
        """Return the action for ``obs`` conditioned on the raw ``goal``."""
        emb_state = self.perception(obs)
        emb_goal = self.perception(goal)
        enc_goal = self.goal_encoder(emb_goal)
        return self.get_pi_from_embeddings(emb_state, enc_goal)

    def encode(self, obs):
        """Embed ``obs`` with the perception module and encode it as a goal."""
        return self.goal_encoder(self.perception(obs))

    def _dists_from_embeddings(self, emb_state, enc_goal):
        """Run the shared trunk and return ``(eef_dist, gripper_dist)``."""
        features = self.layers(torch.cat([emb_state, enc_goal], dim=-1))
        return self.output_layer(features)

    def get_pi_from_embeddings(self, emb_state, enc_goal):
        """Return the greedy action from precomputed embeddings.

        The end-effector part is the mean of the mixture distribution; the
        gripper part is ``+1`` where the Bernoulli logit is positive and
        ``-1`` otherwise.  Both parts are concatenated along the last axis.
        """
        eef_dist, gripper_dist = self._dists_from_embeddings(emb_state, enc_goal)
        eef = eef_dist.mean
        logits = gripper_dist.logits
        # Build +/-1 from the logits tensor itself so the result follows the
        # model's dtype and device rather than the global default dtype.
        ones = torch.ones_like(logits)
        gripper = torch.where(logits > 0, ones, -ones)
        return torch.cat([eef, gripper], dim=-1)

    def get_log_prob_from_embeddings(self, emb_state, act, enc_goal):
        """Return the log-probability of ``act`` under the policy.

        Args:
            emb_state: State embedding (output of ``perception``).
            act: Actions whose last axis has ``act_dim`` entries; the final
                entry is the gripper command (exactly ``1`` means target 1,
                anything else means target 0 for the Bernoulli).
            enc_goal: Encoded goal (output of ``encode``).

        Returns:
            The mean end-effector log-probability over the last axis plus
            the mean gripper log-probability over the last axis.
        """
        eef, gripper = torch.split(act, (self.act_dim - 1, 1), dim=-1)
        eef_dist, gripper_dist = self._dists_from_embeddings(emb_state, enc_goal)
        # Map the +/-1 gripper command onto {1, 0} Bernoulli targets, in the
        # dtype of the logits so log_prob sees matching operands.
        gripper_target = (gripper == 1.).to(gripper_dist.logits.dtype)
        logp_eef = eef_dist.log_prob(eef)
        logp_gripper = gripper_dist.log_prob(gripper_target)
        return logp_eef.mean(-1) + logp_gripper.mean(-1)

    def output_layer(self, x):
        """Build the end-effector and gripper distributions from features ``x``."""
        mu, ln_scale, logit_prob = self.fc_eef(x)
        eef_dist = DiscreteMixLogistic(mu, ln_scale, logit_prob, self.n_mixtures)
        gripper_dist = Bernoulli(logits=self.fc_gripper(x))
        return eef_dist, gripper_dist

    
