# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from torch import Tensor
from torch import nn as nn

from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from .decode_head import Base3DDecodeHead

if IS_TORCHSPARSE_AVAILABLE:
    from torchsparse import SparseTensor
else:
    SparseTensor = None


@MODELS.register_module()
class MinkUNetHead(Base3DDecodeHead):
    r"""MinkUNet decoder head with TorchSparse backend.

    Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.

    Args:
        channels (int): The input channel of conv_seg.
        num_classes (int): Number of classes.
    """

    def __init__(self, channels: int, num_classes: int, **kwargs) -> None:
        super().__init__(channels, num_classes, **kwargs)

    def build_conv_seg(self, channels: int, num_classes: int,
                       kernel_size: int) -> nn.Module:
        """Build Convolutional Segmentation Layers."""
        return nn.Linear(channels, num_classes)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Concat voxel-wise Groud Truth."""
        gt_semantic_segs = [
            data_sample.gt_pts_seg.voxel_semantic_mask
            for data_sample in batch_data_samples
        ]
        return torch.cat(gt_semantic_segs)

    def predict(self, inputs: SparseTensor,
                batch_data_samples: SampleList) -> List[Tensor]:
        """Forward function for testing.

        Args:
            inputs (SparseTensor): Features from backone.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
                data samples.

        Returns:
            List[Tensor]: The segmentation prediction mask of each batch.
        """
        seg_logits = self.forward(inputs)

        batch_idx = inputs.C[:, -1]
        seg_logit_list = []
        for i, data_sample in enumerate(batch_data_samples):
            seg_logit = seg_logits[batch_idx == i]
            seg_logit = seg_logit[data_sample.voxel2point_map]
            seg_logit_list.append(seg_logit)

        return seg_logit_list

    def forward(self, x: SparseTensor) -> Tensor:
        """Forward function.

        Args:
            x (SparseTensor): Features from backbone.

        Returns:
            Tensor: Segmentation map of shape [N, C].
                Note that output contains all points from each batch.
        """
        output = self.cls_seg(x.F)
        return output
