from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.nn.utils.parametrize as P

from omegaconf import OmegaConf

from .cross_attention import CrossAttentionBlock, MultiheadCrossAttention
from .activations import create_activation
from .fourier_mapping import FourierMapping
from .configs import MultiBandDecoderWithCrossAttentionConfig


class RowNormalize(nn.Module):
    """Parametrization that rescales a weight to unit L2 norm along dim 1.

    Intended for use with ``torch.nn.utils.parametrize``: the wrapped
    parameter is replaced by the value this module returns.
    """

    def forward(self, weight):
        """Return ``weight`` with every slice along ``dim=1`` L2-normalized."""
        unit_norm_weight = F.normalize(weight, p=2.0, dim=1)
        return unit_norm_weight


class FourierMLPBlock(nn.Module):
    """Fourier feature mapping followed by a linear layer and a ReLU.

    Args:
        input_dim: Forwarded to ``FourierMapping`` as ``input_dim``.
        output_dim: Number of output features of the linear layer.
        ff_type: Forwarded to ``FourierMapping`` as ``ff_type``.
        ff_dim: Forwarded to ``FourierMapping`` as ``ff_dim``; the linear
            layer consumes ``2 * ff_dim`` features.
        ff_sigma: Forwarded to ``FourierMapping`` as ``ff_sigma``.
        ff_trainable: Forwarded to ``FourierMapping`` as ``trainable``.
        ff_sigma_min: Forwarded to ``FourierMapping`` as ``ff_sigma_min``.
        bias: Whether the linear layer has a (zero-initialized) bias.
    """

    def __init__(
        self, input_dim, output_dim, ff_type, ff_dim, ff_sigma, ff_trainable=False, ff_sigma_min=None, bias=True
    ):
        super().__init__()
        mapping_kwargs = dict(
            input_dim=input_dim,
            ff_type=ff_type,
            ff_dim=ff_dim,
            ff_sigma=ff_sigma,
            trainable=ff_trainable,
            ff_sigma_min=ff_sigma_min,
        )
        self.fourier_mapping = FourierMapping(**mapping_kwargs)

        # The mapping is expected to emit 2 * ff_dim features
        # (presumably a sin/cos pair per frequency -- confirm in FourierMapping).
        projection = nn.Linear(2 * ff_dim, output_dim, bias=bias)
        if projection.bias is not None:
            nn.init.zeros_(projection.bias)
        self.linear = projection

        self.activation = nn.ReLU()

    def forward(self, coord):
        """Map ``coord`` to Fourier features, project them and apply ReLU."""
        return self.activation(self.linear(self.fourier_mapping(coord)))


class Fuse(nn.Module):
    """Merge a query feature with a cross-attention output.

    The attention output is linearly projected to the query width, added to
    the query, and passed through a ReLU.

    Args:
        query_dim: Size of the last dimension of ``query``.
        attn_out_dim: Size of the last dimension of ``attn_out``.
        bias: Whether the projection has a (zero-initialized) bias.
    """

    def __init__(self, query_dim, attn_out_dim, bias=True):
        super().__init__()
        self.query_dim, self.attn_out_dim = query_dim, attn_out_dim

        projection = nn.Linear(attn_out_dim, query_dim, bias=bias)
        if projection.bias is not None:
            nn.init.zeros_(projection.bias)
        self.attn_proj = projection

        self.relu = nn.ReLU()

    def forward(self, query, attn_out):
        """Return ``relu(query + attn_proj(attn_out))``.

        Both inputs are checked against the feature widths given at
        construction time.
        """
        assert query.shape[-1] == self.query_dim
        assert attn_out.shape[-1] == self.attn_out_dim

        # Bring the attention output to the query width, then combine.
        projected = self.attn_proj(attn_out)
        return self.relu(query + projected)


class MultiBandDecoderWithCrossAttention(nn.Module):
    """Coordinate MLP decoder modulated by a shared cross-attention output.

    A single cross-attention pass (coordinate query -> ``latents``) is computed
    once per forward call. For each of the first ``n_mod_layer`` MLP layers, a
    per-layer Fourier query of the coordinates is fused with that attention
    output and added to the hidden state before the layer's linear transform.
    Outputs are read from the last layer only, or from every layer and summed,
    depending on ``config.output_from_every_layer``.
    """

    def __init__(self, config: MultiBandDecoderWithCrossAttentionConfig):
        super().__init__()
        self.config = config
        self.num_mlp_layer = config.n_mlp_layer
        self.hidden_dims = list(config.hidden_dim)
        if len(self.hidden_dims) == 1:
            # A single width is broadcast to every layer boundary.
            self.hidden_dims = [self.hidden_dims[0]] * (self.num_mlp_layer + 1)  # exclude output layer
        else:
            assert len(self.hidden_dims) == self.num_mlp_layer + 1, (
                f"hidden_dim must have 1 or n_mlp_layer + 1 = {self.num_mlp_layer + 1} entries, "
                f"got {len(self.hidden_dims)}"
            )

        self.num_mod_layer = config.n_mod_layer
        self.output_from_every_layer = self.config.output_from_every_layer
        self.use_first_query_for_init_hidden = self.config.use_first_query_for_init_hidden

        # cross-attn layer definition
        xattn_config = self.config.cross_attention
        self.cross_attention = MultiheadCrossAttention(
            embed_dim=xattn_config.embed_dim,
            n_head=xattn_config.n_head,
            input_dim=self.hidden_dims[0],
            context_dim=self.config.latent_dim,
            output_dim=xattn_config.embed_dim,
            dropout=xattn_config.dropout,
            bias=xattn_config.bias,
        )
        attn_out_dim = xattn_config.embed_dim

        # Coordinate -> attention query; its width matches the cross-attention input_dim.
        attn_ff_config = self.config.attn_fourier_mapping
        self.attn_query = FourierMLPBlock(
            input_dim=self.config.input_dim,
            output_dim=self.hidden_dims[0],
            ff_sigma=attn_ff_config.ff_sigma,
            ff_sigma_min=attn_ff_config.ff_sigma_min,
            ff_type=attn_ff_config.type,
            ff_dim=attn_ff_config.ff_dim,
            ff_trainable=attn_ff_config.trainable,
            bias=self.config.use_bias,
        )

        # layer definitions
        in_dims = self.hidden_dims[:-1]
        out_dims = self.hidden_dims[1:]

        def create_to_query_layer(output_dim, ff_sigma, ff_sigma_min):
            """Build the per-layer coordinate query block with the given bandwidth."""
            ff_config = config.query_fourier_mapping
            return FourierMLPBlock(
                input_dim=self.config.input_dim,
                output_dim=output_dim,
                ff_sigma=ff_sigma,
                ff_sigma_min=ff_sigma_min,
                ff_type=ff_config.type,
                ff_dim=ff_config.ff_dim,
                ff_trainable=ff_config.trainable,
                bias=self.config.use_bias,
            )

        mlp_layers = []
        to_query_layers = []
        fuse_layers = []
        out_proj_layers = []

        for layer_idx, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):

            # Only the first `num_mod_layer` layers receive a query/fuse pair;
            # each one reads its own sigma entries from the config lists.
            if layer_idx < self.num_mod_layer:
                ff_sigma = config.query_fourier_mapping.ff_sigma[layer_idx]
                ff_sigma_min = config.query_fourier_mapping.ff_sigma_min[layer_idx]
                to_query = create_to_query_layer(output_dim=in_dim, ff_sigma=ff_sigma, ff_sigma_min=ff_sigma_min)

                fuse = Fuse(in_dim, attn_out_dim, bias=xattn_config.bias)

                to_query_layers.append(to_query)
                fuse_layers.append(fuse)

            layer = nn.Linear(in_dim, out_dim, bias=self.config.use_bias)
            if self.config.use_bias:
                nn.init.zeros_(layer.bias)
            if config.normalize_mlp_weights:
                P.register_parametrization(layer, "weight", RowNormalize())

            mlp_layers.append(layer)

            # One output head per layer, or a single head after the last layer.
            is_last_layer = layer_idx == self.num_mlp_layer - 1
            if self.output_from_every_layer or is_last_layer:
                out_proj = nn.Linear(out_dim, self.config.output_dim, bias=self.config.use_bias)
                if self.config.use_bias:
                    nn.init.zeros_(out_proj.bias)
                out_proj_layers.append(out_proj)

        self.mlp_layers = nn.ModuleList(mlp_layers)
        self.to_query_layers = nn.ModuleList(to_query_layers)
        self.fuse_layers = nn.ModuleList(fuse_layers)
        self.out_proj_layers = nn.ModuleList(out_proj_layers)

        self.activation = create_activation(self.config.activation)
        self.output_bias = config.output_bias

    def forward(self, coord, latents):
        """Computes the signal value for each coordinate.

        The returned tensor keeps every dimension of ``coord`` except the
        last, i.e. ``outputs.shape[:-1] == coord.shape[:-1]``.

        Args
            coord (torch.Tensor): Input coordinates. The first dimension is
                treated as the batch, the last as the coordinate channels, and
                everything in between is flattened internally. Non-contiguous
                tensors are accepted.
            latents (torch.Tensor): Latent vectors to be cross-attended.
                Currently, all modulated layers share the same attention
                output computed from these latents.

        Returns
            outputs (torch.Tensor): evaluated values by INR

        Raises
            ValueError: if the decoder was configured with ``n_mod_layer < 1``;
                the initial hidden state comes from the first modulated layer,
                so at least one is required.
        """
        if self.num_mod_layer < 1:
            raise ValueError(
                "n_mod_layer must be >= 1: the initial hidden state is produced by the first modulated layer"
            )

        batch_size, coord_shape, input_dim = coord.shape[0], coord.shape[1:-1], coord.shape[-1]
        # reshape (not view) so non-contiguous coordinates are handled too.
        coord = coord.reshape(batch_size, -1, input_dim)  # flatten the coordinates

        attn_query = self.attn_query(coord)
        attn_out = self.cross_attention(attn_query, latents)

        hidden = None
        all_outputs = []
        last_layer_idx = len(self.mlp_layers) - 1

        for layer_idx, layer in enumerate(self.mlp_layers):

            if layer_idx < self.num_mod_layer:
                query = self.to_query_layers[layer_idx](coord)
                fused_feat = self.fuse_layers[layer_idx](query=query, attn_out=attn_out)

                if hidden is None:
                    # First layer: the fused feature (optionally plus the raw
                    # query) seeds the hidden state.
                    if self.use_first_query_for_init_hidden:
                        hidden = fused_feat + query
                    else:
                        hidden = fused_feat
                else:
                    hidden = hidden + fused_feat

            hidden = layer(hidden)
            hidden = self.activation(hidden)

            if self.output_from_every_layer:
                all_outputs.append(self.out_proj_layers[layer_idx](hidden))
            elif layer_idx == last_layer_idx:
                all_outputs.append(self.out_proj_layers[-1](hidden))

        # Sum the per-layer outputs (a single entry when only the last layer emits).
        outputs = torch.stack(all_outputs, dim=0).sum(dim=0)
        outputs = outputs + self.output_bias
        outputs = outputs.reshape(batch_size, *coord_shape, -1)
        return outputs

    def compute_modulated_params_dict(self, modulation_params_dict):
        """Not supported by this decoder."""
        raise NotImplementedError

    def forward_with_params(self, coord, params_dict):
        """Not supported by this decoder."""
        raise NotImplementedError
