import numpy as np

import torch.nn as nn
import torch.nn.functional as F

# fmt: off
import logging
# NOTE(review): basicConfig at import time configures the *root* logger (if it
# has no handlers yet) for every program that imports this module — a
# library-side side effect. Consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; not referenced in the code
# visible here.
logger = logging.getLogger(__name__)
# fmt: on


def get_activation(activation: str):
    """Get activation function.

    Args:
        activation (str): Activation function name. One of ``"relu"``,
            ``"tanh"``, ``"sigmoid"``, ``"softmax"``, ``"softplus"``,
            ``"softsign"``, ``"selu"``, ``"elu"``, ``"leaky_relu"``,
            ``"none"`` or ``"linear"`` (the last two both mean identity).

    Raises:
        NotImplementedError: Activation function not implemented.

    Returns:
        nn.Module: A new activation module instance (a fresh object on every
        call, so modules are never shared between layers).
    """
    # Map names to module *factories* (not instances) so each call returns a
    # new module.
    factories = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        # Explicit dim: implicit-dim Softmax is deprecated and warns. dim=-1 is
        # identical to the implicit choice for 1-D and 2-D (batch, features)
        # inputs.
        "softmax": lambda: nn.Softmax(dim=-1),
        "softplus": nn.Softplus,
        "softsign": nn.Softsign,
        "selu": nn.SELU,
        "elu": nn.ELU,
        "leaky_relu": nn.LeakyReLU,
        "none": nn.Identity,
        # "linear" is the default output activation used by the model builders
        # in this module; treat it as an alias for identity.
        "linear": nn.Identity,
    }

    # Explicit raise instead of `assert`: asserts vanish under `python -O`, and
    # the documented contract is NotImplementedError.
    factory = factories.get(activation) if isinstance(activation, str) else None
    if factory is None:
        raise NotImplementedError(f"Activation {activation} not implemented.")
    return factory()


def create_mlp(
    input_dim: int,
    output_dim: int,
    hidden_dims: list,
    activation: str = "relu",
    output_activation: str = "linear",
    use_batch_norm: bool = False,
):
    """Create a multi-layer perceptron.

    Each hidden layer is ``Linear -> activation [-> BatchNorm1d]``. When
    ``use_batch_norm`` is False, a single ``LayerNorm`` is inserted after the
    activation of the *last* hidden layer only. The output layer is
    ``Linear -> output_activation``.

    Args:
        input_dim (int): Input dimension.
        output_dim (int): Output dimension.
        hidden_dims (list): Hidden layer widths; any iterable of ints (list,
            tuple, ...). May be empty, which yields a single Linear layer.
        activation (str, optional): Defaults to "relu". Activation function for
            hidden layers.
        output_activation (str, optional): Defaults to "linear". Activation
            function for the output layer; "linear" means identity.
        use_batch_norm (bool, optional): Defaults to False. Add BatchNorm1d
            after every hidden activation instead of the single trailing
            LayerNorm.

    Returns:
        nn.Sequential: Multi-layer perceptron.
    """
    # "linear" is this function's spelling of identity; translate it to the
    # "none" name so the default argument is always accepted by get_activation.
    if output_activation == "linear":
        output_activation = "none"
    if activation == "linear":
        activation = "none"

    # Build a fresh list so tuples (or other iterables) work for hidden_dims.
    dims = [int(input_dim), *(int(d) for d in hidden_dims), int(output_dim)]
    num_hidden = len(dims) - 2

    layers = []
    for i, (in_features, out_features) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(in_features, out_features))
        if i < num_hidden:
            layers.append(get_activation(activation))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(out_features))
        # LayerNorm only after the last hidden layer, and only without BN.
        if not use_batch_norm and i == num_hidden - 1:
            layers.append(nn.LayerNorm(out_features))

    layers.append(get_activation(output_activation))
    return nn.Sequential(*layers)


def create_cnn(
    input_dim: tuple,
    hidden_dims: list,
    activation: str = "relu",
    output_activation: str = "linear",
):
    """Create a convolutional feature extractor.

    Each stage is ``Conv2d(3x3, stride 1, padding="same") [-> activation] ->
    MaxPool2d(2, 2)``. The activation is omitted after the last convolution.
    The stack ends with ``Flatten -> output_activation``.

    Args:
        input_dim (tuple): Input shape ``(channels, height, width)``; any
            indexable sequence (tuple, list, torch.Size) works.
        hidden_dims (list): Output channels of each conv stage (list or tuple).
        activation (str, optional): Defaults to "relu". Activation between
            conv stages.
        output_activation (str, optional): Defaults to "linear". Activation
            applied after Flatten; "linear" means identity.

    Returns:
        tuple[nn.Sequential, list[int]]: The network and the spatial size
        ``[height, width]`` of the feature map just before ``Flatten``. The
        flattened feature count is ``hidden_dims[-1] * height * width`` (or
        ``channels * height * width`` when ``hidden_dims`` is empty).
    """
    # "linear" is this function's spelling of identity; translate it to the
    # "none" name so the default argument is always accepted by get_activation.
    if output_activation == "linear":
        output_activation = "none"
    if activation == "linear":
        activation = "none"

    in_channels = int(input_dim[0])
    height, width = int(input_dim[1]), int(input_dim[2])

    # Fresh list so tuples (or other iterables) work for hidden_dims.
    channels = [in_channels, *(int(c) for c in hidden_dims)]
    num_stages = len(channels) - 1

    layers = []
    for i in range(num_stages):
        # padding="same" keeps the spatial size unchanged.
        layers.append(
            nn.Conv2d(channels[i], channels[i + 1], kernel_size=3, stride=1, padding="same")
        )
        if i < num_stages - 1:
            layers.append(get_activation(activation))

        layers.append(
            nn.MaxPool2d(
                kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False
            )
        )
        # With ceil_mode=False, MaxPool2d(2, 2) outputs floor((L - 2) / 2) + 1,
        # i.e. L // 2. (ceil(L / 2) would over-count for odd L.)
        height, width = height // 2, width // 2

    layers.append(nn.Flatten())
    layers.append(get_activation(output_activation))

    return nn.Sequential(*layers), [height, width]
