from typing import Optional, cast

import torch
import torch.nn.functional as F
from torch import nn

from .encoders import EncoderWithActionState
from abc import ABCMeta, abstractmethod


class DropEnergy(nn.Module, metaclass=ABCMeta):  # type: ignore
    """Abstract interface for energy modules over a transition and an embedding.

    Concrete subclasses must implement both ``forward`` and
    ``compute_error``; the class cannot be instantiated until they do.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor, act: torch.Tensor, nx: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the module output for ``(x, act, nx, e)``; defined by subclasses."""

    def __call__(self, x: torch.Tensor, act: torch.Tensor, nx: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Invoke the module through ``nn.Module.__call__`` with a typed result.

        The override exists only to narrow the return type for static
        checkers; the four arguments are forwarded positionally unchanged.
        """
        output = super().__call__(x, act, nx, e)
        return cast(torch.Tensor, output)

    @abstractmethod
    def compute_error(
        self, x: torch.Tensor, action: torch.Tensor, nx: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Return the training error for the given inputs; defined by subclasses."""

class DropEnergyFunction(DropEnergy, nn.Module):  # type: ignore
    """Energy head that scores an encoded transition together with an embedding.

    The encoder turns ``(x, action, nx)`` into a feature tensor, the
    embedding ``e`` is concatenated to it along dimension 1, and a small
    fully connected stack maps the result to a single value per row.
    """

    _encoder: EncoderWithActionState
    _action_size: int
    _embedding_size: int

    def __init__(self, encoder: EncoderWithActionState, embedding_size: int):
        """Build the head on top of ``encoder``.

        Args:
            encoder: module called as ``encoder(x, action, nx)``; must expose
                ``action_size`` and ``get_feature_size()``.
            embedding_size: width of the embedding concatenated to the
                encoder features before the fully connected stack.
        """
        super().__init__()
        self._encoder = encoder
        self._action_size = encoder.action_size
        self._embedding_size = embedding_size
        # NOTE(review): there is no activation between the last two Linear
        # layers, so together they form a single affine map. Left unchanged
        # on purpose: inserting one would change the outputs and shift the
        # Sequential indices used as state_dict keys (`_fc.2`, `_fc.3`).
        self._fc = nn.Sequential(
            nn.Linear(encoder.get_feature_size() + self._embedding_size, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.Linear(512, 1),
        )

    def forward(self, x: torch.Tensor, action: torch.Tensor, nx: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the predicted value for each row of the batch.

        Args:
            x: observations passed to the encoder.
            action: actions passed to the encoder.
            nx: next observations passed to the encoder.
            e: embedding concatenated to the encoder output on dim 1
                (assumes shape ``(batch, embedding_size)`` - TODO confirm).

        Returns:
            Output of the fully connected stack; its last dimension is 1.
        """
        h = self._encoder(x, action, nx)
        return cast(torch.Tensor, self._fc(torch.cat([h, e], dim=1)))

    def compute_error(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        next_observations: torch.Tensor,
        embeddings: torch.Tensor,
        target: torch.Tensor,
        reduction: str = "mean",
    ) -> torch.Tensor:
        """Return the MSE between the predicted value and ``target``.

        Args:
            observations: observations passed to the encoder.
            actions: actions passed to the encoder.
            next_observations: next observations passed to the encoder.
            embeddings: embeddings concatenated to the encoder features.
            target: regression target for the prediction.
            reduction: forwarded to ``F.mse_loss`` (``"none"``, ``"mean"``
                or ``"sum"``).

        Returns:
            The loss tensor produced by ``F.mse_loss``.
        """
        # Go through __call__ rather than forward() so that registered
        # forward hooks are not silently skipped.
        value = self(observations, actions, next_observations, embeddings)
        # A target holding the same number of elements in a different shape
        # (e.g. (batch,) vs (batch, 1)) would be broadcast by mse_loss into a
        # (batch, batch) comparison and yield a wrong loss; align it instead.
        # Any other mismatch is left for mse_loss to handle as before.
        if target.shape != value.shape and target.numel() == value.numel():
            target = target.reshape(value.shape)
        return F.mse_loss(value, target, reduction=reduction)

    @property
    def action_size(self) -> int:
        """Action size reported by the wrapped encoder at construction time."""
        return self._action_size

    @property
    def encoder(self) -> EncoderWithActionState:
        """The wrapped encoder module."""
        return self._encoder
