from dataclasses import dataclass
import torch
from torch import nn
from transformers.utils import ModelOutput
from transformers import PretrainedConfig
from typing import Optional


@dataclass
class FourierNeuralOperatorOutput(ModelOutput):
    """Return container of ``FourierNeuralOperator2d.forward`` when ``return_dict`` is true.

    Field order matters: ``ModelOutput`` exposes the non-``None`` fields in
    declaration order when the output is indexed or converted to a tuple.
    """

    # Scalar loss against ``labels``; left as None when no labels were passed.
    loss: Optional[torch.FloatTensor] = None
    # Model prediction, laid out as (batch, num_out_channels, H, W).
    output: torch.FloatTensor = None


class FourierNeuralOperator2dConfig(PretrainedConfig):
    """Configuration for ``FourierNeuralOperator2d``.

    Every argument has a default. ``PretrainedConfig.to_diff_dict`` is used by
    ``to_json_string``, ``__repr__`` and ``save_pretrained``. It instantiates
    ``self.__class__()`` with no arguments to work out which values differ from
    the defaults, so required arguments make that call raise ``TypeError``.
    The ``None`` defaults are placeholders only; a usable model config must set
    them explicitly.

    Args:
        image_size: Spatial size of the inputs. Stored on the config;
            ``FourierNeuralOperator2d`` does not read it.
        num_channels: Number of input channels, not counting the optional
            time channel.
        num_out_channels: Number of channels produced by the projection MLP.
        num_modes: Number of Fourier modes kept along each spatial axis in
            every spectral convolution.
        width: Hidden channel count of the Fourier layers.
        num_layers: Number of Fourier layers.
        use_conditioning: If true, the time value is appended as an extra input
            channel and FiLM normalisation layers are used.
        padding: Zero padding added to the bottom and right of the hidden
            feature map before the Fourier layers and cropped afterwards;
            0 disables it.
        channel_slice_list_normalized_loss: Optional channel boundaries
            ``[c0, c1, ..., ck]``. When given, the loss is the mean over the
            groups ``[c_i, c_{i+1})`` of the group loss divided by the loss of
            the labels against zero.
        hidden_act: ``"leaky_relu"`` or ``"relu"``.
        norm: ``"Instance"`` or ``"Batch"``.
        p: 1 selects the L1 loss, 2 the mean-squared-error loss.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "fourier_neural_operator_2d"

    def __init__(
        self,
        image_size: Optional[int] = None,
        num_channels: Optional[int] = None,
        num_out_channels: Optional[int] = None,
        num_modes: Optional[int] = None,
        width: Optional[int] = None,
        num_layers: Optional[int] = None,
        use_conditioning: bool = False,
        padding=0,
        channel_slice_list_normalized_loss=None,
        hidden_act: Optional[str] = "leaky_relu",
        norm="Instance",
        p: Optional[int] = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
        self.num_out_channels = num_out_channels
        self.num_modes = num_modes
        self.width = width
        self.num_layers = num_layers
        self.padding = padding
        self.use_conditioning = use_conditioning
        self.channel_slice_list_normalized_loss = channel_slice_list_normalized_loss
        self.hidden_act = hidden_act
        self.p = p
        self.norm = norm


class Norm(nn.Module):
    """Unconditioned normalisation layer for NCHW feature maps.

    ``forward`` takes a second argument that it ignores. This lets it be called
    with the same ``(x, time)`` signature as the FiLM layer.

    Args:
        dim: Number of channels to normalise.
        norm: ``"Instance"`` (instance norm with learnable affine parameters)
            or ``"Batch"`` (batch norm).

    Raises:
        ValueError: If ``norm`` is neither ``"Instance"`` nor ``"Batch"``.
    """

    def __init__(self, dim, norm="Instance"):
        super().__init__()
        if norm == "Instance":
            layer = nn.InstanceNorm2d(dim, affine=True)
        elif norm == "Batch":
            layer = nn.BatchNorm2d(dim)
        else:
            raise ValueError("Normalisation not supported")
        self.norm = layer

    def forward(self, x, time):
        del time  # Not used; only present to match the conditioned layer's signature.
        return self.norm(x)


class FILM(nn.Module):
    """Normalisation followed by feature-wise linear modulation (FiLM) from a timestep.

    The timestep is fed through two independent two-layer linear branches. One
    branch produces a per-channel scale and the other a per-channel shift; both
    are applied to the normalised input. The last weight matrix of each branch
    starts at zero, the scale bias at one and the shift bias at zero. A freshly
    constructed layer therefore returns ``norm(x)`` unchanged.

    Args:
        dim: Number of channels of the feature map.
        norm: ``"Instance"`` (instance norm with learnable affine parameters)
            or ``"Batch"``.
        intermediate: Hidden width of the scale and shift branches.

    Raises:
        ValueError: If ``norm`` is neither ``"Instance"`` nor ``"Batch"``.
    """

    def __init__(self, dim, norm="Instance", intermediate=128):
        super().__init__()
        self.dim = dim

        # Scale branch: 1 -> intermediate -> dim.
        self.inp2lat_scale = nn.Linear(in_features=1, out_features=intermediate, bias=True)
        self.lat2scale = nn.Linear(in_features=intermediate, out_features=dim)
        # Shift branch: 1 -> intermediate -> dim.
        self.inp2lat_bias = nn.Linear(in_features=1, out_features=intermediate, bias=True)
        self.lat2bias = nn.Linear(in_features=intermediate, out_features=dim)

        # Identity initialisation: scale == 1 and shift == 0 for any timestep.
        # The first-layer biases keep their default random initialisation.
        with torch.no_grad():
            for branch_layer in (
                self.inp2lat_scale,
                self.lat2scale,
                self.inp2lat_bias,
                self.lat2bias,
            ):
                branch_layer.weight.zero_()
            self.lat2scale.bias.fill_(1)
            self.lat2bias.bias.zero_()

        if norm == "Instance":
            self.norm = nn.InstanceNorm2d(dim, affine=True)
        elif norm == "Batch":
            self.norm = nn.BatchNorm2d(dim)
        else:
            raise ValueError("Normalisation not supported")

    def forward(self, x, timestep):
        normed = self.norm(x)
        # One scalar per row; cast to the feature map's dtype/device.
        t = timestep.reshape(-1, 1).type_as(normed)
        gamma = self.lat2scale(self.inp2lat_scale(t))[:, :, None, None].expand_as(normed)
        beta = self.lat2bias(self.inp2lat_bias(t))[:, :, None, None].expand_as(normed)
        return normed * gamma + beta


class SpectralConv2d(nn.Module):
    """2D Fourier layer: real FFT, per-mode linear channel mixing, inverse real FFT.

    Only the lowest ``modes1`` row frequencies (both the leading block and the
    trailing block of the row axis) and the lowest ``modes2`` column
    frequencies are transformed. Every other Fourier coefficient of the output
    is zero.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        modes1: Retained modes along the row axis; at most floor(N/2) + 1.
        modes2: Retained modes along the column axis.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2)
        # Complex weights with real and imaginary parts drawn uniformly from [0, scale).
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(*weight_shape, dtype=torch.cfloat)
        )
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(*weight_shape, dtype=torch.cfloat)
        )

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        """Mix channels independently at every retained mode.

        (batch, in, x, y) x (in, out, x, y) -> (batch, out, x, y)
        """
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        n_rows, n_cols = x.size(-2), x.size(-1)
        m1, m2 = self.modes1, self.modes2

        # Real-input FFT over the last two axes; the last axis keeps n_cols // 2 + 1 bins.
        spectrum = torch.fft.rfft2(x)

        out_ft = x.new_zeros(
            (x.shape[0], self.out_channels, n_rows, n_cols // 2 + 1),
            dtype=torch.cfloat,
        )
        # Leading rows of the spectrum (non-negative row frequencies) ...
        out_ft[:, :, :m1, :m2] = self.compl_mul2d(spectrum[:, :, :m1, :m2], self.weights1)
        # ... and trailing rows (negative row frequencies), with separate weights.
        out_ft[:, :, -m1:, :m2] = self.compl_mul2d(spectrum[:, :, -m1:, :m2], self.weights2)

        # Back to physical space at the original spatial size.
        return torch.fft.irfft2(out_ft, s=(n_rows, n_cols))


class FourierNeuralOperator2d(nn.Module):
    """Two-dimensional Fourier Neural Operator.

    ``forward`` runs three stages:

    1. Lifting: the MLP ``r`` is applied per pixel and maps the input channels
       to ``config.width`` channels.
    2. Fourier layers: ``config.num_layers`` layers, each the sum of a spectral
       convolution and a 1x1 convolution. Every layer except the last is
       followed by normalisation and the activation.
    3. Projection: the MLP ``q`` is applied per pixel and maps to
       ``config.num_out_channels`` channels.

    With ``config.use_conditioning``, the time value is appended as an extra
    input channel and the normalisation layers are FiLM layers driven by
    ``time``.

    Args:
        config: Object exposing ``hidden_act``, ``num_channels``,
            ``num_out_channels``, ``width``, ``num_layers``, ``num_modes``,
            ``norm``, ``use_conditioning``, ``padding``, ``p`` and
            ``channel_slice_list_normalized_loss``.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"leaky_relu"`` or ``"relu"``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        if self.config.hidden_act == "leaky_relu":
            self.activation = nn.LeakyReLU()
        elif self.config.hidden_act == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(
                f"Activation function {self.config.hidden_act} not supported."
            )

        # Lifting MLP, applied per pixel over the channel axis. One extra input
        # channel carries the broadcast time value when conditioning is enabled.
        # The parameter-free activation module is shared by r, q and forward.
        self.r = nn.Sequential(
            nn.Linear(
                (
                    self.config.num_channels
                    if not self.config.use_conditioning
                    else self.config.num_channels + 1
                ),
                128,
            ),
            self.activation,
            nn.Linear(128, self.config.width),
        )

        # Projection MLP, applied per pixel over the channel axis.
        self.q = nn.Sequential(
            nn.Linear(self.config.width, 128),
            self.activation,
            nn.Linear(128, self.config.num_out_channels),
        )

        # Pointwise (1x1) path of each Fourier layer.
        self.conv_list = nn.ModuleList(
            [
                nn.Conv2d(self.config.width, self.config.width, kernel_size=1)
                for _ in range(self.config.num_layers)
            ]
        )

        # Spectral path of each Fourier layer.
        self.spectral_list = nn.ModuleList(
            [
                SpectralConv2d(
                    self.config.width,
                    self.config.width,
                    self.config.num_modes,
                    self.config.num_modes,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # One normalisation per Fourier layer except the last.
        self.norm_list = nn.ModuleList(
            [
                (
                    FILM(self.config.width, self.config.norm)
                    if self.config.use_conditioning
                    else Norm(self.config.width, self.config.norm)
                )
                for _ in range(self.config.num_layers - 1)
            ]
        )

    @staticmethod
    def _pointwise(mlp, x):
        """Apply ``mlp`` along the channel axis of a (batch, C, H, W) tensor."""
        return mlp(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def _fourier_layers(self, hidden, time):
        """Run the (optionally padded) stack of Fourier layers on a (batch, width, H, W) tensor."""
        pad = self.config.padding
        if pad > 0:
            # Pad the bottom and right edges; cropped again below.
            hidden = nn.functional.pad(hidden, [0, pad, 0, pad])

        last = self.config.num_layers - 1
        for k, (spectral, conv) in enumerate(zip(self.spectral_list, self.conv_list)):
            hidden = spectral(hidden) + conv(hidden)
            if k != last:
                hidden = self.norm_list[k](hidden, time)
                hidden = self.activation(hidden)

        if pad > 0:
            hidden = hidden[..., :-pad, :-pad]
        return hidden

    def _compute_loss(self, prediction, labels):
        """Return the configured L1/MSE loss, optionally normalised per channel group.

        Raises:
            ValueError: If ``config.p`` is not 1 or 2, or the channel slice list
                has fewer than two boundaries.
        """
        if self.config.p == 1:
            loss_fn = nn.functional.l1_loss
        elif self.config.p == 2:
            loss_fn = nn.functional.mse_loss
        else:
            raise ValueError("p must be 1 or 2")

        bounds = self.config.channel_slice_list_normalized_loss
        if bounds is None:
            return loss_fn(prediction, labels)
        if len(bounds) < 2:
            raise ValueError(
                "channel_slice_list_normalized_loss must contain at least two boundaries"
            )

        group_losses = []
        for start, stop in zip(bounds[:-1], bounds[1:]):
            pred_group = prediction[:, start:stop]
            label_group = labels[:, start:stop]
            # Normalise by the loss of the labels against zero (their L1/L2
            # magnitude); the epsilon guards against all-zero labels.
            reference = loss_fn(label_group, torch.zeros_like(label_group))
            group_losses.append(loss_fn(pred_group, label_group) / (reference + 1e-10))
        return torch.mean(torch.stack(group_losses))

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        time: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = True,
    ):
        """Run the operator and, when ``labels`` are given, compute the loss.

        Args:
            pixel_values: Input laid out as (batch, num_channels, H, W).
            time: One scalar per sample. It is appended as an input channel when
                ``config.use_conditioning`` is set, and it is passed to every
                normalisation layer; FiLM layers require it.
            pixel_mask: Boolean index into the output. The selected entries are
                overwritten in place with the corresponding ``labels`` values
                before the loss is computed. Requires ``labels``.
            labels: Targets with the same layout as the output.
            return_dict: If true, return a ``FourierNeuralOperatorOutput``.
                Otherwise return ``(loss, output)``, or ``(output,)`` when there
                is no loss.

        Raises:
            ValueError: If ``pixel_values`` is None, ``pixel_mask`` is given
                without ``labels``, or the loss configuration is invalid.
        """
        if pixel_values is None:
            raise ValueError("pixel_values cannot be None")
        if pixel_mask is not None and labels is None:
            raise ValueError("labels must be provided when pixel_mask is given")

        if time is not None and self.config.use_conditioning:
            # Broadcast each sample's time value to a full extra input channel.
            time_channel = time.reshape(-1, 1, 1, 1).expand(
                -1, 1, pixel_values.shape[-2], pixel_values.shape[-1]
            )
            pixel_values = torch.cat([pixel_values, time_channel], dim=1)

        hidden = self._pointwise(self.r, pixel_values)
        hidden = self._fourier_layers(hidden, time)
        pixel_values = self._pointwise(self.q, hidden)

        if pixel_mask is not None:
            pixel_values[pixel_mask] = labels[pixel_mask].type_as(pixel_values)

        loss = self._compute_loss(pixel_values, labels) if labels is not None else None

        if not return_dict:
            output = (pixel_values,)
            return ((loss,) + output) if loss is not None else output

        return FourierNeuralOperatorOutput(loss=loss, output=pixel_values)
