from spaghettini import quick_register

from torch import nn
import torch.nn.functional as F


@quick_register
class TwoLayerFC(nn.Module):
    """Fully connected network with a single hidden layer.

    Computes ``final_activation(fc2(activation(fc1(x))))`` where ``fc1`` and
    ``fc2`` are ``nn.Linear`` layers.

    Args:
        num_inputs: Number of input features expected by ``fc1`` (after any
            flattening performed in ``forward``).
        num_hidden: Width of the hidden layer.
        num_outputs: Number of output features produced by ``fc2``.
        activation: Callable applied to the hidden-layer pre-activations.
        final_activation: Callable applied to the output of ``fc2``; defaults
            to the identity.

    NOTE(review): the default ``final_activation`` is a lambda, which the
    standard ``pickle`` module cannot serialize, so pickling the whole module
    object (e.g. ``torch.save(model)``) fails unless another callable is
    passed. Saving ``state_dict()`` only stores the ``fc1``/``fc2`` parameters
    and is unaffected.
    """

    def __init__(self, num_inputs=256, num_hidden=1000, num_outputs=784, activation=F.relu,
                 final_activation=lambda x: x):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_outputs = num_outputs
        self.activation = activation
        self.final_activation = final_activation

        self.fc1 = nn.Linear(self.num_inputs, self.num_hidden)
        self.fc2 = nn.Linear(self.num_hidden, self.num_outputs)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Input tensor. If it has more than two dimensions, every
                dimension after the first is flattened, so the first dimension
                is treated as the batch and the flattened size must equal
                ``num_inputs``.

        Returns:
            A tuple ``(output, extras)`` where ``extras`` is always an empty
            dict (presumably kept so callers can treat this model like others
            that return auxiliary values — confirm against callers).
        """
        if x.dim() > 2:
            # ``flatten`` (unlike ``view``) also accepts non-contiguous inputs
            # such as transposed/permuted tensors, and empty batches, where
            # ``view((0, -1))`` cannot infer the size of the -1 dimension.
            x = x.flatten(start_dim=1)
        hidden = self.activation(self.fc1(x))
        out = self.final_activation(self.fc2(hidden))

        return out, {}
