import numpy as np
import torch
import copy
import warnings
from typing import List

from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.utils import ext_loader

from mmengine.registry import MODELS
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
# Load mmcv's compiled `_ext` extension module, requesting the multi-scale
# deformable attention forward/backward ops.
# NOTE(review): `ext_module` is not referenced anywhere in the visible code of
# this file — confirm whether this module-level load is still required.
ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


@MODELS.register_module()
class DeformableAttentionEncoder(TransformerLayerSequence):
    """Encoder that refines BEV queries by attending to multi-modality
    feature maps.

    The layers held in ``self.layers`` (built by ``TransformerLayerSequence``)
    are applied sequentially, and each layer receives the previous layer's
    output as its query. According to the original description, each layer
    combines self- and cross-attention; confirm this against
    ``MyCustomBaseTransformerLayer``.

    Args:
        *args, **kwargs: Forwarded unchanged to
            ``mmcv.cnn.bricks.transformer.TransformerLayerSequence``
            (e.g. ``transformerlayers``, ``num_layers``, ``init_cfg``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_reference_points(self, bev_w, bev_h):
        """Return corner-aligned normalized coordinates of every BEV cell.

        Cell indices are scaled so that the first and last cell on each axis
        map to 0 and 1. An axis with a single cell maps to 0 rather than to
        NaN (0 / 0).

        Args:
            bev_w (int): Number of BEV cells along x.
            bev_h (int): Number of BEV cells along y.

        Returns:
            torch.Tensor: float32 tensor of shape ``(bev_h * bev_w, 2)`` on the
            default tensor device. It holds ``(x, y)`` pairs in row-major
            (y-major) order.
        """
        y_idx, x_idx = torch.meshgrid(
            torch.arange(bev_h), torch.arange(bev_w), indexing='ij')
        # max(..., 1) prevents 0 / 0 -> NaN when an axis has exactly one cell.
        y_norm = y_idx.float() / max(bev_h - 1, 1)
        x_norm = x_idx.float() / max(bev_w - 1, 1)
        return torch.stack((x_norm, y_norm), dim=-1).reshape(-1, 2)

    def get_reference_points_original(self, bev_w, bev_h, dtype, device):
        """Return normalized coordinates of the center of every BEV cell.

        Centers are ``(i + 0.5) / size`` along each axis, so every value lies
        strictly inside (0, 1).

        Args:
            bev_w (int): Number of BEV cells along x.
            bev_h (int): Number of BEV cells along y.
            dtype (torch.dtype): dtype of the returned tensor.
            device (torch.device): Device of the returned tensor.

        Returns:
            torch.Tensor: Tensor of shape ``(bev_h * bev_w, 2)`` holding
            ``(x, y)`` pairs in row-major (y-major) order.
        """
        # 'ij' is torch's historical default; passing it explicitly silences
        # the meshgrid deprecation warning without changing the result.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, bev_h - 0.5, bev_h, dtype=dtype, device=device),
            torch.linspace(0.5, bev_w - 0.5, bev_w, dtype=dtype, device=device),
            indexing='ij',
        )
        ref_y = ref_y.reshape(-1) / bev_h
        ref_x = ref_x.reshape(-1) / bev_w
        return torch.stack((ref_x, ref_y), -1)

    def forward(self,
                inputs: List[torch.Tensor], bev_queries, bev_h, bev_w,
                embed_dims, positional_encoding=None, *args, **kwargs):
        """Run the BEV queries through every encoder layer.

        Args:
            inputs (List[torch.Tensor]): Feature maps, one per modality, each
                laid out as (bs, C, H, W). They are permuted to (bs, H, W, C)
                before being passed to the layers. The first tensor sets the
                batch size, dtype and device.
            bev_queries (torch.Tensor): BEV query embeddings of shape
                (num_query, C). They are repeated over the batch. Presumably
                num_query == bev_h * bev_w, to match the reference points;
                TODO confirm.
            bev_h (int): BEV grid height.
            bev_w (int): BEV grid width.
            embed_dims (int): Intended as the final output dimension, shared
                by all attention operations for simplicity. It is not used
                directly by this method.
            positional_encoding: Forwarded unchanged to each layer.
            *args, **kwargs: Forwarded unchanged to each layer.

        Returns:
            torch.Tensor: Output of the last layer. Its shape is presumably
            (bs, num_query, C); confirm against the layer implementation.
            If there are no layers, the batch-repeated queries are returned.
        """
        first = inputs[0]
        bs = first.shape[0]
        device = first.device

        # (num_query, C) -> (bs, num_query, C)
        bev_queries = bev_queries.unsqueeze(0).repeat(bs, 1, 1)

        # A single attention level that covers the whole BEV grid.
        spatial_shapes = torch.tensor(
            [[bev_h, bev_w]], dtype=torch.long, device=device)
        level_start_index = torch.tensor([0], dtype=torch.long, device=device)

        # Cell-center reference points, created directly on the target
        # device: (bev_h * bev_w, 2) -> (bs, bev_h * bev_w, 2).
        reference_points = self.get_reference_points_original(
            bev_w, bev_h, dtype=first.dtype, device=device)
        reference_points = reference_points.unsqueeze(0).repeat(bs, 1, 1)

        # (b, c, h, w) -> (b, h, w, c)
        inputs = [t.permute(0, 2, 3, 1) for t in inputs]
        # Values are cloned so their storage is separate from keys/inputs.
        # NOTE(review): presumably this guards against in-place edits inside
        # a layer; confirm it is still needed.
        multi_modality_values = [t.clone() for t in inputs]

        # Start from the batched queries so that an empty layer stack returns
        # them instead of raising UnboundLocalError.
        output = bev_queries
        for layer in self.layers:
            output = layer(
                output,
                *args,
                multi_modality_keys=inputs,
                multi_modality_values=multi_modality_values,
                multi_modality_inputs=inputs,
                bev_h=bev_h,
                bev_w=bev_w,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                reference_points=reference_points,
                positional_encoding=positional_encoding,
                **kwargs)

        return output