from torch import nn, Tensor

from .activation import get_activation_fn


class MLP(nn.Module):
    """Fully connected feed-forward network built from ``config``.

    The network is a stack of ``nn.Linear`` layers mapping
    ``config.dim_in`` -> each entry of ``config.hidden_units`` ->
    ``config.dim_out``. Every layer except the last is followed by the
    activation and then a dropout; the final layer's output is returned
    unchanged (no activation, no dropout).

    Config attributes read:
        dim_in: Input feature size of the first linear layer.
        hidden_units: Sizes of the hidden layers, in order. Any iterable of
            ints is accepted (list, tuple, ...); an empty one yields a single
            ``dim_in -> dim_out`` linear layer.
        dim_out: Output feature size of the last linear layer.
        activation: Passed to ``get_activation_fn``; the returned object is
            called with no arguments to build the activation module.
        dropout: Probability passed to every ``nn.Dropout``.
    """

    def __init__(self, config):
        super().__init__()
        # Unpacking (rather than list concatenation) lets hidden_units be a
        # tuple or any other iterable; `list + tuple` would raise TypeError.
        hidden_units = [config.dim_in, *config.hidden_units, config.dim_out]
        self.hidden_layers = nn.ModuleList(
            nn.Linear(in_features, out_features)
            for in_features, out_features in zip(hidden_units[:-1], hidden_units[1:])
        )
        # One activation instance, applied after every non-final layer.
        # NOTE(review): if the activation has learnable parameters (e.g.
        # nn.PReLU) they are shared across layers — confirm this is intended.
        self.activation = get_activation_fn(config.activation)()
        # One dropout per non-final layer (len(hidden_units) - 2 of them).
        self.dropouts = nn.ModuleList(
            nn.Dropout(config.dropout) for _ in range(len(hidden_units) - 2)
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the network.

        Args:
            x: Input tensor whose last dimension must equal ``config.dim_in``
                (required by the first ``nn.Linear``).

        Returns:
            Output of the final linear layer; its last dimension is
            ``config.dim_out``.
        """
        # The constructor always creates at least one layer, so this unpacking
        # never fails; the non-final layers pair one-to-one with the dropouts.
        *hidden_layers, output_layer = self.hidden_layers
        for layer, dropout in zip(hidden_layers, self.dropouts):
            x = dropout(self.activation(layer(x)))
        return output_layer(x)
