import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from dcrl.utils.model_utils import layer_init


class ObsActorModel(nn.Module):
    """Actor network mapping an observation dict to a categorical action distribution.

    Pipeline: representation net -> (optional) LSTMCell memory -> MLP of
    Linear+ReLU layers -> linear policy head producing one logit per action.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        recurrent=True,
        hidden_size_list=None,
        rnn_hidden_size=None,
        get_representation_net_func=None,
    ):
        """Build the actor.

        Args:
            obs_space: Mapping-like space; only ``obs_space["obs"]`` is used,
                and it is passed to ``get_representation_net_func``.
            action_space: Object exposing ``.n``, used as the number of
                policy-head outputs (presumably a discrete action space).
            recurrent: If True, an ``nn.LSTMCell`` memory is inserted after
                the representation net.
            hidden_size_list: Widths of the hidden Linear+ReLU layers of the
                policy MLP. ``None`` (the default) means no hidden layers: the
                policy head is applied directly to the embedding.
            rnn_hidden_size: Hidden size of the LSTMCell; required when
                ``recurrent`` is True.
            get_representation_net_func: Callable taking ``obs_space["obs"]``
                and returning ``(net, tensor_type, rep_size)``. Observations
                are cast to ``tensor_type`` before being fed to ``net``, and
                ``rep_size`` must be the width of ``net``'s output.
        """
        super().__init__()

        self.recurrent = recurrent
        # Treat a missing list as "no hidden layers" instead of crashing on len(None).
        self.hidden_size_list = hidden_size_list if hidden_size_list is not None else []
        self.rnn_hidden_size = rnn_hidden_size

        # Define representation net
        self.representation_net, self.tensor_type, self.rep_size = get_representation_net_func(obs_space["obs"])

        self.embedding_size = self.rep_size

        # Define memory: orthogonal weights, zero biases.
        if self.recurrent:
            self.memory_rnn = nn.LSTMCell(self.rep_size, self.rnn_hidden_size)
            for name, param in self.memory_rnn.named_parameters():
                if "bias" in name:
                    nn.init.constant_(param, 0)
                elif "weight" in name:
                    nn.init.orthogonal_(param, 1.0)
            # forward() concatenates the LSTM hidden state onto the representation.
            self.embedding_size += self.rnn_hidden_size

        # Define policy net (indices/state_dict keys match the original layout).
        layers = []
        input_size = self.embedding_size
        for hidden_size in self.hidden_size_list:
            layers.append(layer_init(nn.Linear(input_size, hidden_size)))
            layers.append(nn.ReLU())
            input_size = hidden_size
        self.policy_net = nn.Sequential(*layers)

        # Small init std keeps the initial policy close to uniform.
        self.policy_head = layer_init(nn.Linear(input_size, action_space.n), std=0.01)

    @property
    def memory_size(self):
        """Width of the packed memory tensor: LSTM hidden state ``h`` plus cell state ``c``."""
        return 2 * self.rnn_hidden_size

    def forward(self, obs, memory=None):
        """Compute the action distribution and the next memory.

        Args:
            obs: Dict whose ``"obs"`` entry is the observation tensor; it is
                cast to ``self.tensor_type`` before the representation net.
                The net output is assumed to be 2-D ``(batch, rep_size)``
                (it is concatenated along dim 1) — TODO confirm for all nets.
            memory: In recurrent mode, a tensor ``(batch, memory_size)`` that
                packs ``[h, c]`` along dim 1. ``None`` starts from a zero
                LSTM state. Ignored (and returned unchanged) otherwise.

        Returns:
            ``(dist, memory)``: a ``Categorical`` over actions and the updated
            packed ``[h, c]`` memory (or the input ``memory`` when not
            recurrent).
        """
        x = obs["obs"].type(self.tensor_type)
        embedding = self.representation_net(x)

        if self.recurrent:
            if memory is None:
                # nn.LSTMCell zero-initialises (h, c) when hx is None.
                hidden = None
            else:
                hidden = (memory[:, : self.rnn_hidden_size], memory[:, self.rnn_hidden_size :])
            h, c = self.memory_rnn(embedding, hidden)
            embedding = torch.cat([h, embedding], dim=1)
            memory = torch.cat([h, c], dim=1)

        output = self.policy_net(embedding)

        logits = self.policy_head(output)
        # Categorical normalises raw logits itself; an explicit log_softmax is redundant.
        dist = Categorical(logits=logits)

        return dist, memory
