import math
import torch
import torch.nn as nn

class RandomFeature(nn.Module):
    """Two-layer random-feature model computing ``relu(x @ W.T) @ a``.

    ``W`` (shape ``(m, d)``) is drawn once from a standard normal and is not
    an ``nn.Parameter``, so it is not returned by ``parameters()`` and an
    optimizer built from them will not update it. Only the output weights
    ``a`` (shape ``(m,)``) are trainable.

    Args:
        d: input dimension; the last dimension of ``x`` in ``forward``.
        m: number of random features.
    """

    def __init__(self, d, m):
        super().__init__()
        # Output weights start at zero. ``randn(m) * 0.0`` (rather than
        # ``zeros(m)``) is kept so the global RNG advances exactly as before
        # and ``W`` below is identical under a fixed seed.
        self.a = nn.Parameter(torch.randn(m) * 0.0)
        # Frozen random features with unscaled N(0, 1) entries (no 1/sqrt(d)
        # factor is applied). Registered as a buffer so .to()/.cuda()/.double()
        # move and cast it together with ``a``. persistent=False keeps it out
        # of state_dict, so checkpoints saved before this change (which never
        # contained ``W``) still load with strict=True.
        self.register_buffer("W", torch.randn(m, d), persistent=False)
        self.m = m

    def forward(self, x):
        """Return ``relu(x @ W.T) @ a``; ``x``'s last dimension must equal ``d``."""
        return torch.relu(x @ self.W.t()) @ self.a

def build_rfm(d=10, m=2000):
    """Construct a ``RandomFeature`` model.

    Args:
        d: input dimension.
        m: number of random features.

    Returns:
        A freshly initialised ``RandomFeature(d, m)``.
    """
    model = RandomFeature(d=d, m=m)
    return model


def build_linear_net(d=100, m=50):
    """Return a deep *linear* network (no activations) with a scalar output.

    Four ``nn.Linear`` layers with widths d -> m -> m -> m -> 1, using
    PyTorch's default initialisation.

    Args:
        d: input dimension.
        m: hidden width.
    """
    widths = [d, m, m, m, 1]
    layers = [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
    return nn.Sequential(*layers)

def build_fcn(m, d=784):
    """Build a fully-connected ReLU network with three hidden layers and a scalar output.

    ``nn.Flatten`` (default ``start_dim=1``) flattens every dimension after
    the first, so any input whose non-leading dimensions multiply to ``d``
    is accepted.

    Args:
        m: hidden width of each of the three hidden layers.
        d: flattened input size; defaults to 784 (= 28 * 28).

    Returns:
        An ``nn.Sequential`` mapping ``(batch, ...)`` to ``(batch, 1)``.
        Every ``nn.Linear`` weight is drawn from N(0, 2 / fan_in) (He init),
        and every bias is zero.
    """
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(d, m),
        nn.ReLU(),
        nn.Linear(m, m),
        nn.ReLU(),
        nn.Linear(m, m),
        nn.ReLU(),
        nn.Linear(m, 1),
    )
    # Re-initialise in place without recording autograd history. Iterate the
    # Sequential directly instead of the private ``_modules`` dict.
    with torch.no_grad():
        for layer in net:
            if isinstance(layer, nn.Linear):
                fan_in = layer.in_features
                layer.weight.normal_(0.0, math.sqrt(2.0 / fan_in))
                layer.bias.zero_()
    return net

def build_cnn(m, in_channels=1, input_size=28):
    """Build a 4-conv CNN with two 2x2 average pools and a scalar linear read-out.

    Args:
        m: base channel width (layers use m, 2m, 2m, m channels).
        in_channels: number of input channels; defaults to 1.
        input_size: spatial side length of the input. Assumes square inputs
            (H == W == input_size); the default 28 gives the original
            7x7 read-out.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, in_channels, input_size,
        input_size)`` to ``(batch, 1)``. Conv weights are drawn from
        N(0, 1 / fan_in) with fan_in = kh * kw * in_channels, and conv biases
        are zero. The final ``nn.Linear`` is zero-initialised, so the network
        outputs exactly 0 at initialisation.

    Raises:
        ValueError: if ``input_size`` is too small to survive two 2x pools.
    """
    # The 3x3 convs use padding=1 and keep the spatial size; each
    # AvgPool2d(kernel_size=2) halves it (flooring).
    pooled = input_size // 2 // 2
    if pooled < 1:
        raise ValueError(f"input_size={input_size} is too small for two 2x2 pools")

    net = nn.Sequential(
        nn.Conv2d(in_channels, m, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(m, 2 * m, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AvgPool2d(kernel_size=2),
        nn.Conv2d(2 * m, 2 * m, kernel_size=3, padding=1),
        nn.ReLU(),
        # NOTE(review): no ReLU follows this conv, so it composes linearly with
        # the pooling and read-out below. Confirm this is intended.
        nn.Conv2d(2 * m, m, kernel_size=3, padding=1),
        nn.AvgPool2d(kernel_size=2),
        nn.Flatten(),
        nn.Linear(m * pooled * pooled, 1),
    )

    # Re-initialise in place without recording autograd history. Iterate the
    # Sequential directly instead of the private ``_modules`` dict.
    with torch.no_grad():
        for layer in net:
            if isinstance(layer, nn.Conv2d):
                fan_in = layer.kernel_size[0] * layer.kernel_size[1] * layer.in_channels
                layer.weight.normal_(0.0, math.sqrt(1.0 / fan_in))
                layer.bias.zero_()
            elif isinstance(layer, nn.Linear):
                layer.weight.zero_()
                layer.bias.zero_()
    return net
