from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.modules.utils import _pair

from .layers import Conv2d, MinMax, Scale, Sequential


class BasicRes(nn.Module):
    """Residual block whose branch is two 3x3 convolutions.

    The residual branch is a ``Sequential`` of, in construction order:
    ``Conv2d`` -> ``activation`` -> ``Conv2d`` -> ``Scale(init_scale=0.0)``.
    The block output is the input plus the output of that branch.

    Args:
        planes (int): the number of input/output channels of both convs.
        input_size (int): accepted for signature parity with other blocks;
            not used by this block.
        activation (nn.Module, optional): the module placed between the two
            convolutions. Defaults to None, in which case a new
            ``MinMax(dim=1)`` is created for this block.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 planes: int,
                 input_size: int,
                 activation: Optional[nn.Module] = None,
                 **kwargs) -> None:
        super().__init__()

        # A module instance used as a default argument is created once, at
        # function-definition time, and would then be shared by every block
        # built with the default. Create a fresh one per instance instead.
        if activation is None:
            activation = MinMax(dim=1)

        # NOTE(review): init_scale=0.0 presumably makes the branch output
        # zero at initialization -- confirm against Scale's implementation.
        self.residual = Sequential(
            Conv2d(planes, planes, bias=True, kernel_size=3), activation,
            Conv2d(planes, planes, bias=True, kernel_size=3),
            Scale(init_scale=0.0))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x + residual(x)``."""
        out = self.residual(x)
        return x + out

    def lipschitz(self) -> Tensor:
        """Return ``1 + residual.lipschitz()``.

        The skip connection contributes 1 and the residual branch
        contributes whatever ``self.residual.lipschitz()`` reports.
        """
        lc = 1 + self.residual.lipschitz()
        return lc


class LinearRes(nn.Module):
    """Linear Residual Block for 2D (BCHW) inputs.

    The block is a single 3x3 convolution (stride 1, padding 1) whose kernel
    is ``identity + depth**-0.5 * affine * weight``, where ``identity`` is
    the kernel that maps every channel to itself through the centre tap.

    Args:
        planes (int): the number of input/output channels.
        input_size (int): the input size, i.e., H and W. It is expanded with
            ``_pair``, so an ``(H, W)`` pair is accepted as well.
        depth (int): the number of linear residual blocks used. The learned
            part of the kernel is multiplied by ``depth ** -0.5``.
        use_affine (bool): if true, add an affine layer after the weight layer.
            Defaults to True.
        num_lc_iter (int): the number of power-iteration steps performed by
            each call to :meth:`lipschitz`. Defaults to 10.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 planes: int,
                 input_size: int,
                 depth: int,
                 use_affine: bool = True,
                 num_lc_iter: int = 10,
                 **kwargs) -> None:
        super().__init__()
        self.planes = planes
        # lipschitz() reads this attribute, so it must exist on every
        # instance rather than being assigned from outside.
        self.num_lc_iter = num_lc_iter

        weight = torch.randn(planes, planes, 3, 3) / (planes * 9)
        self.weight = nn.Parameter(weight)

        if use_affine:
            # One learnable factor per output channel, shaped to broadcast
            # over the (out, in, kH, kW) kernel.
            affine = torch.ones(planes)
            affine = affine.reshape(-1, 1, 1, 1)
            self.affine = nn.Parameter(affine)
        else:
            self.affine = 1.0
        self.bias = nn.Parameter(torch.zeros(planes))

        # Identity kernel: eye(planes) at the centre tap, zeros elsewhere.
        identity = torch.zeros(planes, planes, 3, 3)
        identity[:, :, 1, 1] = torch.eye(planes)
        self.register_buffer('identity', identity)
        self.scale = depth**-.5

        # Warm-start vector for the power iteration in lipschitz().
        init_x = torch.ones(1, planes, *_pair(input_size))
        self.register_buffer('init_x', init_x)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the 3x3 convolution built by :meth:`get_weight`."""
        weight = self.get_weight()
        out = F.conv2d(x, weight, self.bias, padding=1)
        return out

    def lipschitz(self) -> Tensor:
        """Estimate the spectral norm of the convolution by power iteration.

        Runs ``num_lc_iter`` rounds of conv -> transposed conv -> normalize,
        starting from the ``init_x`` buffer, stores the resulting vector back
        into ``init_x`` as the warm start for the next call, and returns the
        norm of the convolution applied to that vector.

        Returns:
            Tensor: a scalar tensor, differentiable with respect to the
            block's weight (and affine) parameters.
        """
        weight = self.get_weight()
        # Work on a copy: the first conv saves its input for the weight
        # gradient, and the buffer is overwritten in place below. Sharing
        # storage (as ``.data`` does) would silently corrupt that saved
        # input without autograd noticing.
        x = self.init_x.detach().clone()
        for _ in range(self.num_lc_iter):
            x = F.conv2d(x, weight, bias=None, padding=1)
            x = F.conv_transpose2d(x, weight, bias=None, padding=1)
            x = F.normalize(x, dim=(1, 2, 3))

        # In-place update keeps the registered buffer object itself intact.
        self.init_x.copy_(x.detach())
        x = F.conv2d(x, weight, bias=None, padding=1)
        return x.norm()

    def get_weight(self) -> Tensor:
        """Return the effective kernel ``identity + scale * affine * weight``."""
        weight = self.weight.mul(self.affine)
        weight = self.identity + weight * self.scale
        return weight

    def extra_repr(self) -> str:
        """Return the summary string shown in the module's ``repr``."""
        string = f'{self.planes}, kernel_size=3, stride=1, padding=1'
        if self.bias is None:
            string += ', bias=False'
        return string
