import torch.nn as nn
import torch.nn.functional as F
import torch
import math
import re

# ---------------------------------------------------------------------------
# CNN hyper-parameters
# ---------------------------------------------------------------------------
# Side length of the square input image assumed when CNN computes the size of
# its flattened feature map, and the number of input channels of conv1.
DIM_INPUT_CNN = 32
CHANNEL_INPUT_CNN = 3
# First convolution layer (conv1).
OUT_CHANNEL_CONV1_CNN = 16
KERNEL_SIZE_CONV1_CNN = 3
STRIDE_CONV1_CNN = 1
PADDING_CONV1_CNN = 1
# Second convolution layer (conv2); its input channels are conv1's outputs.
OUT_CHANNEL_CONV2_CNN = 16
KERNEL_SIZE_CONV2_CNN = 3
STRIDE_CONV2_CNN = 1
PADDING_CONV2_CNN = 1
# Max-pooling layer shared by both convolution stages.
KERNEL_SIZE_MAXPOOLING2D_CNN = 2
STRIDE_MAXPOOLING2D_CNN = 2
# Width of the hidden linear layer and of the final output layer.
DIM_LINEAR1_CNN = 100
DIM_OUTPUT_CNN = 10

# ---------------------------------------------------------------------------
# LSTM hyper-parameters
# ---------------------------------------------------------------------------
# Feature size per time step fed to nn.LSTM.
DIM_INPUT_RNN = 28
# Hidden size of the single-layer LSTM.
DIM_HIDDEN_LSTM_RNN = 128
# Width of the linear layer between the LSTM and the output layer.
DIM_HIDDEN_LINEAR_RNN = 64
DIM_OUTPUT_RNN = 10

# ---------------------------------------------------------------------------
# MLP hyper-parameters
# ---------------------------------------------------------------------------
# MLP.forward flattens every sample to this many features.
DIM_INPUT_MLP = 784
DIM_HIDDEN_MLP = 100
DIM_OUTPUT_MLP = 47


class CNN(nn.Module):
    """Small convolutional classifier.

    Architecture: conv1 -> ReLU -> max-pool -> conv2 -> ReLU -> max-pool ->
    flatten -> linear1 -> ReLU -> linear2.

    Every sub-module whose name starts with ``conv`` or ``linear`` also gets
    a set of non-trainable buffers (``v_w``, ``v_b``, ``eta_theta``,
    ``c_theta``, ``z_theta``, ``n_w``, ``n_b``).  They are only created here
    and never read by this class; presumably an external optimiser/sampler
    consumes them (TODO confirm against the training code).
    """

    def __init__(self, N, eta_theta0, c_theta0, gamma_theta0):
        """Build the layers and attach the per-layer state buffers.

        Args:
            N: Multiplier applied to ``eta_theta0`` when initialising the
                ``eta_theta`` buffers (presumably the dataset size -- TODO
                confirm).
            eta_theta0: Base value of the ``eta_theta`` buffers.
            c_theta0: Initial value of the ``c_theta`` and ``z_theta``
                buffers.
            gamma_theta0: Unused.  It was only ever consumed for sub-modules
                whose name starts with ``lstm`` and this network has none;
                the parameter is kept so the constructor signature is
                unchanged.
        """
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=CHANNEL_INPUT_CNN,
                               out_channels=OUT_CHANNEL_CONV1_CNN,
                               kernel_size=KERNEL_SIZE_CONV1_CNN,
                               stride=STRIDE_CONV1_CNN,
                               padding=PADDING_CONV1_CNN)
        self.pool = nn.MaxPool2d(kernel_size=KERNEL_SIZE_MAXPOOLING2D_CNN,
                                 stride=STRIDE_MAXPOOLING2D_CNN)
        self.conv2 = nn.Conv2d(in_channels=OUT_CHANNEL_CONV1_CNN,
                               out_channels=OUT_CHANNEL_CONV2_CNN,
                               kernel_size=KERNEL_SIZE_CONV2_CNN,
                               stride=STRIDE_CONV2_CNN,
                               padding=PADDING_CONV2_CNN)

        # Spatial side length after each conv + pool stage; needed to size
        # the first linear layer and to flatten in forward().
        self.outputdim1 = self._stage_output_size(
            DIM_INPUT_CNN, KERNEL_SIZE_CONV1_CNN, STRIDE_CONV1_CNN, PADDING_CONV1_CNN)
        self.outputdim2 = self._stage_output_size(
            self.outputdim1, KERNEL_SIZE_CONV2_CNN, STRIDE_CONV2_CNN, PADDING_CONV2_CNN)
        self.linear1 = nn.Linear(OUT_CHANNEL_CONV2_CNN * self.outputdim2 * self.outputdim2,
                                 DIM_LINEAR1_CNN)
        self.linear2 = nn.Linear(DIM_LINEAR1_CNN, DIM_OUTPUT_CNN)
        self.outputdim = DIM_OUTPUT_CNN

        self.N = N
        # Name patterns selecting which sub-modules receive state buffers.
        # pattern2 matches nothing in this network but is kept as an
        # attribute for parity with the sibling model classes.
        self.pattern1 = re.compile(r'linear|conv')
        self.pattern2 = re.compile(r'lstm')

        for name, module in self._modules.items():
            if not self.pattern1.match(name):
                # e.g. the pooling layer: it has no weight/bias to track.
                continue
            weight_shape = module.weight.shape
            bias_shape = module.bias.shape
            module.register_buffer('v_w', torch.zeros(weight_shape))
            module.register_buffer('v_b', torch.zeros(bias_shape))
            module.register_buffer('eta_theta', torch.Tensor([eta_theta0 * self.N]))
            module.register_buffer('c_theta', torch.Tensor([c_theta0]))
            # Clone so z_theta starts equal to c_theta without sharing its
            # storage; otherwise an in-place update of one would silently
            # change the other (and only until the module is moved/cast).
            module.register_buffer('z_theta', module.c_theta.clone())
            module.register_buffer('n_w', torch.zeros(weight_shape))
            module.register_buffer('n_b', torch.zeros(bias_shape))

    @staticmethod
    def _stage_output_size(size, kernel_size, stride, padding):
        """Return the side length after one convolution followed by the max-pool."""
        conv_size = (size + 2 * padding - kernel_size) // stride + 1
        return (conv_size - KERNEL_SIZE_MAXPOOLING2D_CNN) // STRIDE_MAXPOOLING2D_CNN + 1

    def forward(self, x):
        """Return the ``DIM_OUTPUT_CNN`` un-normalised scores for each sample.

        Args:
            x: Image batch accepted by ``conv1``; the flatten step assumes a
                spatial size of ``DIM_INPUT_CNN`` x ``DIM_INPUT_CNN``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten each sample's feature map for the fully connected layers.
        x = x.view(-1, OUT_CHANNEL_CONV2_CNN * self.outputdim2 * self.outputdim2)
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x


class LSTM(nn.Module):
    """Single-layer LSTM classifier followed by two linear layers.

    Architecture: lstm -> final hidden state -> ReLU -> linear1 -> ReLU ->
    linear2.

    Sub-modules whose name starts with ``linear``/``conv`` or ``lstm`` also
    get a set of non-trainable buffers (``v_*``, ``n_*``, ``eta_theta``,
    ``c_theta``, ``z_theta`` and, for the LSTM only, ``gamma_theta``).  They
    are only created here and never read by this class; presumably an
    external optimiser/sampler consumes them (TODO confirm against the
    training code).
    """

    def __init__(self, N, eta_theta0, c_theta0, gamma_theta0):
        """Build the layers and attach the per-layer state buffers.

        Args:
            N: Multiplier applied to ``eta_theta0`` when initialising the
                ``eta_theta`` buffers (presumably the dataset size -- TODO
                confirm).
            eta_theta0: Base value of the ``eta_theta`` buffers.
            c_theta0: Initial value of the ``c_theta`` and ``z_theta``
                buffers.
            gamma_theta0: Initial value of the ``gamma_theta`` buffer that is
                registered on the LSTM sub-module only.
        """
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(DIM_INPUT_RNN, DIM_HIDDEN_LSTM_RNN, 1)
        self.linear1 = nn.Linear(DIM_HIDDEN_LSTM_RNN, DIM_HIDDEN_LINEAR_RNN)
        self.linear2 = nn.Linear(DIM_HIDDEN_LINEAR_RNN, DIM_OUTPUT_RNN)
        self.outputdim = DIM_OUTPUT_RNN

        self.N = N
        # Name patterns selecting which sub-modules receive state buffers.
        self.pattern1 = re.compile(r'linear|conv')
        self.pattern2 = re.compile(r'lstm')

        for name, module in self._modules.items():
            # Map buffer-name suffix -> shape of the parameter it shadows.
            if self.pattern1.match(name):
                is_lstm = False
                shapes = (('w', module.weight.shape),
                          ('b', module.bias.shape))
            elif self.pattern2.match(name):
                is_lstm = True
                shapes = (('wih', module.weight_ih_l0.shape),
                          ('bih', module.bias_ih_l0.shape),
                          ('whh', module.weight_hh_l0.shape),
                          ('bhh', module.bias_hh_l0.shape))
            else:
                continue

            for suffix, shape in shapes:
                module.register_buffer('v_' + suffix, torch.zeros(shape))
            module.register_buffer('eta_theta', torch.Tensor([eta_theta0 * self.N]))
            module.register_buffer('c_theta', torch.Tensor([c_theta0]))
            if is_lstm:
                module.register_buffer('gamma_theta', torch.Tensor([gamma_theta0]))
            # Clone so z_theta starts equal to c_theta without sharing its
            # storage; otherwise an in-place update of one would silently
            # change the other (and only until the module is moved/cast).
            module.register_buffer('z_theta', module.c_theta.clone())
            for suffix, shape in shapes:
                module.register_buffer('n_' + suffix, torch.zeros(shape))

    def forward(self, x):
        """Return the ``DIM_OUTPUT_RNN`` un-normalised scores for each sample.

        Args:
            x: Batch-first input.  After singleton non-batch axes are removed
                it must be ``(batch, seq_len, DIM_INPUT_RNN)``, because it is
                transposed to the ``(seq_len, batch, feature)`` layout that
                ``nn.LSTM`` expects by default.

        Returns:
            Tensor of shape ``(batch, DIM_OUTPUT_RNN)``.
        """
        # Drop singleton axes (e.g. a channel axis) but never the batch axis:
        # a dimension-less squeeze would also remove it when batch == 1.
        if x.dim() > 3:
            x = x.reshape(x.size(0), *[d for d in x.shape[1:] if d != 1])
        x = torch.transpose(x, 0, 1)
        batch_size = x.size(1)
        # Explicit zero initial states on the same device/dtype as the input.
        h0 = x.new_zeros(1, batch_size, DIM_HIDDEN_LSTM_RNN)
        c0 = x.new_zeros(1, batch_size, DIM_HIDDEN_LSTM_RNN)
        _, (h_n, _) = self.lstm(x, (h0, c0))
        # h_n is (num_layers=1, batch, hidden); remove only the layer axis so
        # a batch of one keeps its batch dimension.
        x = h_n.squeeze(0)
        x = F.relu(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x


class MLP(nn.Module):
    """Two-layer perceptron classifier: linear1 -> ReLU -> linear2.

    Every sub-module whose name starts with ``linear`` or ``conv`` also gets
    a set of non-trainable buffers (``v_w``, ``v_b``, ``eta_theta``,
    ``c_theta``, ``z_theta``, ``n_w``, ``n_b``).  They are only created here
    and never read by this class; presumably an external optimiser/sampler
    consumes them (TODO confirm against the training code).
    """

    def __init__(self, N, eta_theta0, c_theta0, gamma_theta0):
        """Build the layers and attach the per-layer state buffers.

        Args:
            N: Multiplier applied to ``eta_theta0`` when initialising the
                ``eta_theta`` buffers (presumably the dataset size -- TODO
                confirm).
            eta_theta0: Base value of the ``eta_theta`` buffers.
            c_theta0: Initial value of the ``c_theta`` and ``z_theta``
                buffers.
            gamma_theta0: Unused.  It was only ever consumed for sub-modules
                whose name starts with ``lstm`` and this network has none;
                the parameter is kept so the constructor signature is
                unchanged.
        """
        super(MLP, self).__init__()
        self.linear1 = nn.Linear(DIM_INPUT_MLP, DIM_HIDDEN_MLP)
        self.linear2 = nn.Linear(DIM_HIDDEN_MLP, DIM_OUTPUT_MLP)
        self.outputdim = DIM_OUTPUT_MLP

        self.N = N
        # Name patterns selecting which sub-modules receive state buffers.
        # pattern2 matches nothing in this network but is kept as an
        # attribute for parity with the sibling model classes.
        self.pattern1 = re.compile(r'linear|conv')
        self.pattern2 = re.compile(r'lstm')

        for name, module in self._modules.items():
            if not self.pattern1.match(name):
                continue
            weight_shape = module.weight.shape
            bias_shape = module.bias.shape
            module.register_buffer('v_w', torch.zeros(weight_shape))
            module.register_buffer('v_b', torch.zeros(bias_shape))
            module.register_buffer('eta_theta', torch.Tensor([eta_theta0 * self.N]))
            module.register_buffer('c_theta', torch.Tensor([c_theta0]))
            # Clone so z_theta starts equal to c_theta without sharing its
            # storage; otherwise an in-place update of one would silently
            # change the other (and only until the module is moved/cast).
            module.register_buffer('z_theta', module.c_theta.clone())
            module.register_buffer('n_w', torch.zeros(weight_shape))
            module.register_buffer('n_b', torch.zeros(bias_shape))

    def forward(self, x):
        """Return the ``DIM_OUTPUT_MLP`` un-normalised scores for each sample.

        Args:
            x: Input whose elements can be regrouped into rows of
                ``DIM_INPUT_MLP`` features; it is flattened accordingly.
        """
        x = x.view(-1, DIM_INPUT_MLP)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
