from typing import Union

import numpy as np

from omegaconf import ListConfig

import torch
import torch.nn as nn
import torch.nn.functional as F

from control.base_networks import create_mlp

# fmt: off
import logging
# NOTE(review): basicConfig runs at import time and configures the *root*
# logger at INFO level (it is a no-op if the root logger already has
# handlers). Importing this module therefore affects process-wide logging.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# fmt: on

# Default torch device: CUDA when torch reports it as available, otherwise CPU.
# NOTE(review): not referenced by the code visible in this section; modules are
# not moved to this device automatically here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DiscreteValueNetwork(nn.Module):
    """MLP value-function approximator with a single scalar output.

    The underlying network is built by ``create_mlp`` with ``output_dim=1``
    and stored as ``self.value_network``.
    """

    def __init__(
        self,
        input_dim: Union[int, ListConfig],
        hidden_dims: Union[list, None] = None,
        activation: str = "relu",
        output_activation: str = "none",
        use_batch_norm: bool = False,
        ortho_init: bool = False,
        *args,
        **kwargs,
    ):
        """Build the value network.

        Args:
            input_dim (int | ListConfig): Input dimension of the state; passed
                unchanged to ``create_mlp``.
            hidden_dims (list, optional): List of hidden dimensions.
                Defaults to [128, 64] when omitted or ``None``.
            activation (str, optional): Activation function for hidden layers.
                Defaults to "relu".
            output_activation (str, optional): Activation function for output layer.
                Defaults to "none" which yields linear activation.
            use_batch_norm (bool, optional): Passed unchanged to ``create_mlp``.
                Defaults to False.
            ortho_init (bool, optional): If True, run ``init_weights`` on every
                submodule after construction. Defaults to False.
            *args: Forwarded to ``nn.Module.__init__``.
            **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Build a fresh list per instance instead of sharing one mutable
        # default object across every call.
        if hidden_dims is None:
            hidden_dims = [128, 64]

        self.value_network = create_mlp(
            input_dim=input_dim,
            output_dim=1,  # one scalar value per input
            hidden_dims=hidden_dims,
            activation=activation,
            output_activation=output_activation,
            use_batch_norm=use_batch_norm,
        )

        if ortho_init:
            # nn.Module.apply visits every submodule (and self) recursively.
            self.apply(self.init_weights)

    def init_weights(self, m):
        """Orthogonally initialise a linear layer and zero its bias.

        Intended as the callback for ``nn.Module.apply``; modules that are
        not ``nn.Linear`` are left untouched.

        Args:
            m (nn.Module): Module to (possibly) initialise in place.
        """
        if not isinstance(m, nn.Linear):
            return

        # NOTE: the gain is always the ReLU gain, independent of the
        # ``activation`` chosen in ``__init__``.
        nn.init.orthogonal_(m.weight, nn.init.calculate_gain("relu"))

        # Layers created with ``bias=False`` have ``bias is None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the value network and return its output.

        Args:
            x (torch.Tensor): Network input; presumably (batch, input_dim) —
                the expected shape is determined by ``create_mlp``.

        Returns:
            torch.Tensor: Output of ``self.value_network`` for ``x``.
        """
        return self.value_network(x)
