import functools

import torch
import torch.nn as nn
import torchvision

from src.simplex_layers import StandardConv, StandardLinear, SimplexLinear, SimplexConv


class CifarNet(nn.Module):
    """Small CNN with three conv layers followed by two linear layers.

    ``net_fn`` in this file selects this model for the ``"CIFAR10"`` dataset.
    The first conv expects 3 input channels and ``fc1`` expects
    ``64 * 3 * 3`` flattened features (presumably 3x32x32 inputs; confirm
    against the data pipeline).

    Args:
        num_points: Forwarded as ``num_endpoints`` to every Simplex* layer.
        network_type: ``"FC_LAST"`` or ``"POINT"`` build the conv layers and
            ``fc1`` from the Standard* classes; any other value builds them
            from the Simplex* classes. ``fc2`` is always a ``SimplexLinear``.
        num_classes: Number of output features of the final layer.
        seed: Seed handed to the layers (constructor argument for Simplex*
            layers, ``.seed(seed)`` call for the others, see note below).
    """

    def __init__(self, num_points, network_type, num_classes, seed):
        super().__init__()
        standard_body = network_type in ("FC_LAST", "POINT")
        if standard_body:
            make_conv = StandardConv
        else:
            make_conv = functools.partial(SimplexConv, num_endpoints=num_points, seed=seed)

        self.num_classes = num_classes
        self.max_pool = nn.MaxPool2d(2, 2)
        self.leaky_relu = torch.nn.LeakyReLU()
        self.flatten = nn.Flatten()

        # (attribute name, in_channels, out_channels, kernel_size, padding)
        conv_specs = (
            ("conv1", 3, 16, 5, 0),
            ("conv2", 16, 32, 5, 1),
            ("conv3", 32, 64, 3, 1),
        )
        # NOTE(review): `.seed(seed)` is applied to every conv layer here,
        # including the SimplexConv variant, whereas FemnistCNN in this file
        # only calls it on StandardConv. Confirm SimplexConv exposes `.seed`.
        for attr_name, c_in, c_out, k_size, pad in conv_specs:
            conv = make_conv(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=k_size,
                padding=pad,
                stride=1,
                bias=False,
            )
            setattr(self, attr_name, conv.seed(seed))

        hidden_in, hidden_out = 64 * 3 * 3, 128
        if standard_body:
            self.fc1 = StandardLinear(in_features=hidden_in, out_features=hidden_out, bias=False).seed(seed)
        else:
            self.fc1 = SimplexLinear(
                num_endpoints=num_points,
                seed=seed,
                in_features=hidden_in,
                out_features=hidden_out,
                bias=False,
            )
        self.fc2 = SimplexLinear(
            num_endpoints=num_points,
            seed=seed,
            in_features=hidden_out,
            out_features=num_classes,
            bias=False,
        )

    def forward(self, x):
        """Apply conv -> LeakyReLU -> max-pool three times, then fc1 and fc2."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.max_pool(self.leaky_relu(conv(x)))
        flat = self.flatten(x)
        hidden = self.leaky_relu(self.fc1(flat))
        return self.fc2(hidden)

class FemnistCNN(nn.Module):
    """Small CNN with two conv layers and a single SimplexLinear classifier.

    ``net_fn`` in this file selects this model for the ``"FEMNIST"`` dataset.
    The first conv expects 1 input channel and the classifier expects 512
    flattened features (presumably 1x28x28 inputs; confirm against the data
    pipeline).

    Args:
        num_points: Forwarded as ``num_endpoints`` to every Simplex* layer.
        network_type: ``"FC_LAST"`` or ``"POINT"`` build the conv layers from
            ``StandardConv`` (followed by ``.seed(seed)``); any other value
            builds them from ``SimplexConv`` with ``seed`` passed to the
            constructor.
        num_classes: Number of output features of the classifier.
        seed: Seed handed to the layers.
    """

    def __init__(self, num_points, network_type, num_classes, seed):
        super().__init__()
        standard_body = network_type in ("FC_LAST", "POINT")

        self.num_classes = num_classes
        self.max_pool = nn.MaxPool2d(2, 2)
        self.relu = torch.nn.ReLU()
        self.flatten = nn.Flatten()

        conv1_kwargs = dict(in_channels=1, out_channels=16, kernel_size=5, padding=0, stride=1, bias=False)
        conv2_kwargs = dict(in_channels=16, out_channels=32, kernel_size=5, padding=0, stride=1, bias=False)
        if standard_body:
            # Standard layers are seeded through their `.seed` method.
            self.conv1 = StandardConv(**conv1_kwargs).seed(seed)
            self.conv2 = StandardConv(**conv2_kwargs).seed(seed)
        else:
            # Simplex layers take the seed as a constructor argument.
            self.conv1 = SimplexConv(num_endpoints=num_points, seed=seed, **conv1_kwargs)
            self.conv2 = SimplexConv(num_endpoints=num_points, seed=seed, **conv2_kwargs)

        self.fc = SimplexLinear(
            num_endpoints=num_points,
            seed=seed,
            in_features=512,
            out_features=num_classes,
            bias=False,
        )

    def forward(self, x):
        """Apply conv -> ReLU -> max-pool twice, flatten, then the classifier."""
        for conv in (self.conv1, self.conv2):
            x = self.max_pool(self.relu(conv(x)))
        return self.fc(self.flatten(x))

def pretrained_resnet18(num_points, in_features, num_classes, seed):
    """Return a torchvision ResNet-18 whose ``fc`` head is a SimplexLinear.

    The backbone is loaded with the ``"IMAGENET1K_V1"`` weights; only the
    final ``fc`` attribute is replaced.

    Args:
        num_points: Forwarded as ``num_endpoints`` to the new head.
        in_features: Input feature count of the new head.
        num_classes: Output feature count of the new head.
        seed: Seed passed to the new head's constructor.
    """
    backbone = torchvision.models.resnet18(weights="IMAGENET1K_V1")
    head = SimplexLinear(
        num_endpoints=num_points,
        seed=seed,
        in_features=in_features,
        out_features=num_classes,
        bias=False,
    )
    backbone.fc = head
    return backbone

def pretrained_squeezenet1_0(num_points, num_classes, seed, num_channels):
    """Return a torchvision SqueezeNet 1.0 with a SimplexConv classifier.

    The backbone is loaded with the ``"IMAGENET1K_V1"`` weights and its
    ``classifier`` is replaced by Dropout -> SimplexConv (1x1) -> ReLU ->
    adaptive average pooling to 1x1.

    Args:
        num_points: Forwarded as ``num_endpoints`` to the SimplexConv.
        num_classes: Output channel count of the SimplexConv.
        seed: Seed passed to the SimplexConv constructor.
        num_channels: When equal to 1, the first feature layer is swapped for
            a newly constructed single-channel ``nn.Conv2d``; any other value
            leaves the pretrained first layer untouched.
    """
    net = torchvision.models.squeezenet1_0(weights="IMAGENET1K_V1")

    if num_channels == 1:
        # The replacement stem is a fresh nn.Conv2d: it carries no pretrained
        # weights and `seed` is not passed to it.
        net.features[0] = nn.Conv2d(1, 96, kernel_size=7, stride=2)

    head_layers = [
        nn.Dropout(p=0.5, inplace=False),
        SimplexConv(
            num_endpoints=num_points,
            seed=seed,
            in_channels=512,
            out_channels=num_classes,
            kernel_size=1,
            bias=False,
        ),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net


def net_fn(dataset_name, num_classes, network_arch, network_type, num_points, seed, device):
    """Build the requested network and move it to ``device``.

    Supported combinations:
        * ``network_arch == "ScratchSimpleCNN"`` with ``dataset_name`` equal
          to ``"FEMNIST"`` (``FemnistCNN``) or ``"CIFAR10"`` (``CifarNet``).
        * ``network_arch == "PretrainedResNet18"`` (any dataset name); the
          replacement head is built with ``in_features=512``.
        * ``network_arch == "PretrainedSqueezeNet"`` (any dataset name); a
          single input channel is used for ``"FEMNIST"``, three otherwise.

    Args:
        dataset_name: Name of the dataset the network is built for.
        num_classes: Number of output classes.
        network_arch: Architecture identifier (see above).
        network_type: Forwarded to the scratch CNNs; unused by the
            pretrained architectures.
        num_points: Forwarded to the model constructors.
        seed: Forwarded to the model constructors.
        device: Target passed to ``.to(device)`` on the built model.

    Returns:
        The constructed model, already moved to ``device``.

    Raises:
        NotImplementedError: If ``network_arch`` is unknown, or if
            ``"ScratchSimpleCNN"`` is requested for a dataset other than
            ``"FEMNIST"`` / ``"CIFAR10"``.
    """
    if network_arch == "ScratchSimpleCNN":
        if dataset_name == "FEMNIST":
            return FemnistCNN(num_points, network_type, num_classes, seed).to(device)
        if dataset_name == "CIFAR10":
            return CifarNet(num_points, network_type, num_classes, seed).to(device)
        # Unsupported dataset for this architecture: fall through to the
        # error below, which names both offending values.
    elif network_arch == "PretrainedResNet18":
        return pretrained_resnet18(num_points, 512, num_classes, seed).to(device)
    elif network_arch == "PretrainedSqueezeNet":
        num_channels = 1 if dataset_name == "FEMNIST" else 3
        return pretrained_squeezenet1_0(num_points, num_classes, seed, num_channels).to(device)

    # Keep the original message prefix so existing log searches still match.
    raise NotImplementedError(
        f"Network not implemented: network_arch={network_arch!r}, dataset_name={dataset_name!r}"
    )