import torch
import torch.nn as nn
from .perception import Perception, MLP
from .encoder import Encoder


class Critics(nn.Module):
    """Ensemble of goal-conditioned critic heads sharing one perception and goal encoder.

    Every head is an ``MLP`` fed with the concatenation
    ``[state embedding, action, encoded goal]`` along the last dimension and
    configured with a final size of 1.

    Args:
        state_dim: Size of the state embedding; used as ``out_dim`` of the
            perception module and as the first argument of the goal encoder.
        act_dim: Size of the action part of each critic's input.
        goal_dim: Size of the encoded-goal part of each critic's input; passed
            as the second argument of the goal encoder.
        visual_perception: If true, observations and goals go through a
            ``Perception`` module; otherwise they are used unchanged.
        n_critics: Number of independent critic heads.
        hidden_size: Passed to the perception module, the goal encoder and
            (repeated ``n_hiddens`` times) to every critic ``MLP``.
        n_hiddens: How many times ``hidden_size`` is repeated in each critic's
            size list (presumably the number of hidden layers -- confirm
            against ``MLP``).
    """

    def __init__(self, state_dim, act_dim, goal_dim,
                 visual_perception=True, n_critics=2, hidden_size=512, n_hiddens=2):
        super().__init__()
        if visual_perception:
            # NOTE(review): the two 128 arguments look like a fixed input
            # resolution -- confirm against Perception's signature.
            self.perception = Perception(128, 128, out_dim=state_dim, hidden_size=hidden_size)
        else:
            # nn.Identity rather than a bare lambda: it is a registered
            # submodule and keeps the whole model picklable (a lambda
            # attribute cannot be pickled). It holds no parameters, so the
            # state_dict is unaffected.
            self.perception = nn.Identity()
        self.goal_encoder = Encoder(state_dim, goal_dim, hidden_size)

        hidden_sizes = [hidden_size] * n_hiddens
        critic_in_dim = state_dim + act_dim + goal_dim
        self.critics = nn.ModuleList(
            [MLP(critic_in_dim, *hidden_sizes, 1) for _ in range(n_critics)]
        )

    def forward(self, obs, act, goal):
        """Evaluate every critic head for ``(obs, act, goal)``.

        ``obs`` and ``goal`` are both passed through the shared perception
        module; the goal embedding is then passed through the goal encoder.

        Returns:
            A list with one output per critic head, in head order.
        """
        emb_state = self.perception(obs)
        emb_goal = self.perception(goal)
        enc_goal = self.goal_encoder(emb_goal)
        return self.get_value_from_embeddings(emb_state, act, enc_goal)

    def encode(self, obs):
        """Return the goal encoding of ``obs`` (perception followed by the goal encoder)."""
        return self.goal_encoder(self.perception(obs))

    def get_value_from_embeddings(self, emb_state, act, enc_goal):
        """Evaluate every critic head on already-computed embeddings.

        Args:
            emb_state: State embedding (output of the perception module).
            act: Action tensor.
            enc_goal: Encoded goal (output of the goal encoder).

        Returns:
            A list with one output per critic head, in head order.
        """
        # All heads consume the same concatenated input.
        x = torch.cat([emb_state, act, enc_goal], dim=-1)
        return [critic(x) for critic in self.critics]