from collections import OrderedDict
from copy import deepcopy
from typing import Dict, List, Optional, Tuple
import os

import numpy as np
import torch
import torch.distributed as dist
from mmengine.utils import is_list_of
from torch import Tensor
from torch.nn import functional as F

from mmdet3d.models import Base3DDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
from mmdet3d.utils import OptConfigType, OptMultiConfig, OptSampleList
from .ops import Voxelization
from .utils import visualize_map, extract_pattern

from datetime import datetime

@MODELS.register_module()
class BEVFusion(Base3DDetector):

    def __init__(
        self,
        multi_modality: bool,
        data_preprocessor: OptConfigType = None,
        pts_voxel_encoder: Optional[dict] = None,
        pts_middle_encoder: Optional[dict] = None,
        fusion_layer: Optional[dict] = None,
        img_backbone: Optional[dict] = None,
        pts_backbone: Optional[dict] = None,
        view_transform: Optional[dict] = None,
        img_neck: Optional[dict] = None,
        pts_neck: Optional[dict] = None,
        bbox_head: Optional[dict] = None,
        init_cfg: OptMultiConfig = None,
        seg_head: Optional[dict] = None,
        cameraonly: bool = False,
        missing_camera: bool = False,
        missing_lidar: bool = False,
        original_voxelization: bool = True,
        shuffle_train: bool = False,
        map_visualize: bool = False,
        map_visual_num: int = 100,
        out_dir: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Build the detector and its sub-modules from their configs.

        Args:
            multi_modality (bool): Stored flag; ``extract_feat`` uses it to
                enable the image branch and the fusion layer.
            data_preprocessor (dict): Data preprocessor config. It must
                contain a ``voxelize_cfg`` entry, which in turn must contain
                ``voxelize_reduce``; the remaining ``voxelize_cfg`` keys are
                passed to ``Voxelization``. The caller's config is not
                modified.
            pts_voxel_encoder (dict, optional): Voxel encoder config; only
                used by ``extract_pts_feat`` when ``original_voxelization``
                is False.
            pts_middle_encoder (dict, optional): Middle encoder config.
            fusion_layer (dict, optional): Fusion layer config.
            img_backbone (dict, optional): Image backbone config.
            pts_backbone (dict): Backbone config. Despite the ``None``
                default it is always passed to ``MODELS.build``.
            view_transform (dict, optional): View transform config.
            img_neck (dict, optional): Image neck config.
            pts_neck (dict): Neck config. Despite the ``None`` default it is
                always passed to ``MODELS.build``.
            bbox_head (dict, optional): Detection head config.
            init_cfg (dict or list[dict], optional): Forwarded to the base
                class.
            seg_head (dict, optional): Segmentation head config.
            cameraonly (bool): When True, ``extract_feat`` skips the lidar
                branch.
            missing_camera (bool): Robustness-test flag for a missing camera.
            missing_lidar (bool): Robustness-test flag for a missing lidar.
            original_voxelization (bool): When True, ``extract_pts_feat``
                voxelizes the raw points with the internal voxel layer;
                otherwise it consumes the pre-computed ``'voxels'`` input.
            shuffle_train (bool): Robustness training. When True, each
                training step drops the lidar with probability 1/3, drops
                the camera with probability 1/3 and keeps both otherwise
                (never both dropped). It only affects training.
            map_visualize (bool): Whether ``predict`` dumps segmentation
                maps as images.
            map_visual_num (int): Visualisation budget decremented by
                ``predict`` for every visualised sample.
            out_dir (str, optional): Root directory for the visualised maps.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``data_preprocessor`` is None.
        """
        if data_preprocessor is None:
            raise ValueError(
                '`data_preprocessor` must be a config containing '
                '`voxelize_cfg`')

        # Work on a private copy: popping from the caller's config would
        # make a second build from the same config fail with a KeyError.
        data_preprocessor = deepcopy(data_preprocessor)
        voxelize_cfg = data_preprocessor.pop('voxelize_cfg')
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        self.voxelize_reduce = voxelize_cfg.pop('voxelize_reduce')
        self.pts_voxel_layer = Voxelization(**voxelize_cfg)

        self.pts_voxel_encoder = (
            MODELS.build(pts_voxel_encoder)
            if pts_voxel_encoder is not None else None)

        self.img_backbone = (
            MODELS.build(img_backbone) if img_backbone is not None else None)
        self.img_neck = (
            MODELS.build(img_neck) if img_neck is not None else None)
        self.view_transform = (
            MODELS.build(view_transform)
            if view_transform is not None else None)
        self.pts_middle_encoder = (
            MODELS.build(pts_middle_encoder)
            if pts_middle_encoder is not None else None)

        self.fusion_layer = (
            MODELS.build(fusion_layer) if fusion_layer is not None else None)

        # The BEV backbone and neck are mandatory: `extract_feat` calls them
        # unconditionally.
        self.pts_backbone = MODELS.build(pts_backbone)
        self.pts_neck = MODELS.build(pts_neck)

        self.bbox_head = (
            MODELS.build(bbox_head) if bbox_head is not None else None)
        self.seg_head = (
            MODELS.build(seg_head) if seg_head is not None else None)

        self.init_weights()

        self.multi_modality = multi_modality

        self.cameraonly = cameraonly

        # Robustness testing.
        self.missing_lidar = missing_lidar
        self.missing_camera = missing_camera

        # Robustness training: 1/3 lidar dropped, 1/3 camera dropped,
        # 1/3 untouched; both sensors are never dropped together.
        # Only training is affected; it is not applied at test time.
        self.shuffle_train = shuffle_train

        # Voxelization control.
        self.original_voxelization = original_voxelization

        # Segmentation visualisation options.
        self.map_visualize = map_visualize
        self.map_visual_num = map_visual_num
        self.out_dir = out_dir

    def _forward(self,
                 batch_inputs: Tensor,
                 batch_data_samples: OptSampleList = None):
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """
        pass

    def parse_losses(
        self, losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Parses the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: There are two elements. The first is the
            loss tensor passed to optim_wrapper which may be a weighted sum
            of all losses, and the second is log_vars which will be sent to
            the logger.
        """
        log_vars = []
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars.append([loss_name, loss_value.mean()])
            elif is_list_of(loss_value, torch.Tensor):
                log_vars.append(
                    [loss_name,
                     sum(_loss.mean() for _loss in loss_value)])
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(value for key, value in log_vars if 'loss' in key)
        log_vars.insert(0, ['loss', loss])
        log_vars = OrderedDict(log_vars)  # type: ignore

        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars  # type: ignore

    def init_weights(self) -> None:
        if self.img_backbone is not None:
            self.img_backbone.init_weights()

    @property
    def with_bbox_head(self):
        """bool: Whether the detector has a box head."""
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_seg_head(self):
        """bool: Whether the detector has a segmentation head.
        """
        return hasattr(self, 'seg_head') and self.seg_head is not None

    def extract_img_feat(
        self,
        x,
        points,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        img_metas,
    ) -> torch.Tensor:
        B, N, C, H, W = x.size()
        x = x.view(B * N, C, H, W).contiguous()

        x = self.img_backbone(x)
        x = self.img_neck(x)

        if not isinstance(x, torch.Tensor):
            x = x[0]

        BN, C, H, W = x.size()
        x = x.view(B, int(BN / B), C, H, W)

        # with torch.autocast(device_type='cuda', dtype=torch.float32):
        #     x = self.view_transform(
        #         x,
        #         points,
        #         lidar2image,
        #         camera_intrinsics,
        #         camera2lidar,
        #         img_aug_matrix,
        #         lidar_aug_matrix,
        #         img_metas,
        #     )
        x = self.view_transform(
            x,
            points,
            lidar2image,
            camera_intrinsics,
            camera2lidar,
            img_aug_matrix,
            lidar_aug_matrix,
            img_metas,
        )
        
        return x

    def extract_pts_feat(self, batch_inputs_dict) -> torch.Tensor:

        if self.original_voxelization:
            points = batch_inputs_dict['points']
            with torch.autocast('cuda', enabled=False):
                points = [point.float() for point in points]
                feats, coords, sizes = self.voxelize(points)
                batch_size = coords[-1, 0] + 1
            x = self.pts_middle_encoder(feats, coords, batch_size)
            return x
        else:
            voxel_dict = batch_inputs_dict.get('voxels', None)
            voxel_features = self.pts_voxel_encoder(voxel_dict['voxels'],
                                                    voxel_dict['num_points'],
                                                    voxel_dict['coors'])
            batch_size = voxel_dict['coors'][-1, 0] + 1
            x = self.pts_middle_encoder(voxel_features, voxel_dict['coors'],
                                        batch_size)
            return x

    @torch.no_grad()
    def voxelize(self, points):
        feats, coords, sizes = [], [], []
        for k, res in enumerate(points):
            ret = self.pts_voxel_layer(res)
            if len(ret) == 3:
                # hard voxelize
                f, c, n = ret
            else:
                assert len(ret) == 2
                f, c = ret
                n = None
            feats.append(f)
            coords.append(F.pad(c, (1, 0), mode='constant', value=k))
            if n is not None:
                sizes.append(n)

        feats = torch.cat(feats, dim=0)
        coords = torch.cat(coords, dim=0)
        if len(sizes) > 0:
            sizes = torch.cat(sizes, dim=0)
            if self.voxelize_reduce:
                feats = feats.sum(
                    dim=1, keepdim=False) / sizes.type_as(feats).view(-1, 1)
                feats = feats.contiguous()

        return feats, coords, sizes

    def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
                batch_data_samples: List[Det3DDataSample],
                **kwargs) -> List[Det3DDataSample]:
        """Forward of testing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input sample. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
                (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                contains a tensor with shape (num_instances, 7).

            When a segmentation head exists, every sample additionally
            carries ``pred_masks_bev``. An empty dict is returned when the
            model has neither a box head nor a segmentation head.

        Raises:
            ValueError: If map visualisation is requested but ``out_dir``
                was not configured.
        """
        batch_input_metas = [item.metainfo for item in batch_data_samples]
        feats = self.extract_feat(
            batch_inputs_dict, batch_input_metas, mode='predict')

        res = dict()
        if self.with_bbox_head:
            outputs = self.bbox_head.predict(feats, batch_input_metas)
            res = self.add_pred_to_datasample(batch_data_samples, outputs)

        # The box head and the segmentation head are handled separately.
        if self.with_seg_head:
            # The segmentation head needs no metas to produce its result.
            outputs = self.seg_head.predict(feats)
            # Attach the masks to the data samples for downstream parsing.
            res = self.add_predseg_to_datasample(batch_data_samples, outputs)

            if self.map_visualize and self.map_visual_num > 0:
                if self.out_dir is None:
                    raise ValueError(
                        '`out_dir` must be set when `map_visualize` is '
                        'enabled')
                classes = self.seg_head.classes
                for sample in res:
                    # Stop as soon as the budget is used up, even in the
                    # middle of a batch, so it is never exceeded.
                    if self.map_visual_num <= 0:
                        break
                    gt_masks = sample.gt_masks_bev.astype(np.bool_)
                    pred_masks = (
                        sample.pred_masks_bev.to('cpu').numpy() >= 0.5)
                    name = extract_pattern(sample.lidar_path)
                    visualize_map(
                        os.path.join(self.out_dir, "map", f"{name}_gt.png"),
                        gt_masks,
                        classes=classes)
                    visualize_map(
                        os.path.join(self.out_dir, "map",
                                     f"{name}_pred.png"),
                        pred_masks,
                        classes=classes)
                    self.map_visual_num -= 1

        return res


    def add_predseg_to_datasample(self, batch_data_samples, outputs):
        # val或者test的时候bs是多少，这里就会有多少个data_samples
        # batch_data_samples是一个list，长度为bs的值
        # 但是outputs的维度是b, classes, H, W，所以需要统一，统一成一个一个
        for i, data_sample in enumerate(batch_data_samples):
            data_sample.pred_masks_bev = outputs[i]

        return batch_data_samples
        

    def random_shuffle_input(self, batch_inputs_dict):
        #实现具体细节，1/3 1/2
        import random
        imgs = batch_inputs_dict.get('imgs', None)
        points = batch_inputs_dict.get('points', None)
        if points is not None and imgs is not None:
            B = len(points)
            lidar_flag = random.choices([0, 1], weights=[1, 2], k=1)[0]

            camera_flag = True
            if lidar_flag:
                camera_flag = random.choice([0, 1])

            if lidar_flag == 0:
                for index, value in enumerate(points):
                    points[index] = torch.zeros((1, 5), dtype=points[index].dtype, device=points[index].device)

            if camera_flag == 0:
                for index, value in enumerate(imgs):
                    imgs[index, :, :, :, :] = 0
            # 现在得到了两个list，根据list的值调整imgs和points
            # for index, value in enumerate(lidar_flag):
            #     if camera_flag[index] == 0:
            #         imgs[index, :, :, :, :] = 0
            #     if value == 0:
            #         points[index] = torch.zeros((1, 5), dtype=points[index].dtype, device=points[index].device)

            batch_inputs_dict['imgs'] = imgs
            batch_inputs_dict['points'] = points
        
        # 无论如何都return这个值
        return batch_inputs_dict


    def random_shuffle_onlylidar_input(self, batch_inputs_dict):
        #实现具体细节，1/3 1/3 1/3
        import random
        points = batch_inputs_dict.get('points', None)
        if points is not None:
            B = len(points)
            lidar_flag = random.choice([0, 1])

            # 现在得到了两个list，根据list的值调整imgs和points
            if lidar_flag == 0:
                for index, value in enumerate(points):
                    points[index] = torch.zeros((1, 5), dtype=points[index].dtype, device=points[index].device)

            batch_inputs_dict['points'] = points
        
        # 无论如何都return这个值
        return batch_inputs_dict


    def extract_feat(
        self,
        batch_inputs_dict,
        batch_input_metas,
        **kwargs,
    ):
        mode = kwargs.get('mode', 'default')

        imgs = batch_inputs_dict.get('imgs', None)
        points = batch_inputs_dict.get('points', None)
        points_for_camera = deepcopy(points)

        if mode == 'predict' and self.missing_lidar:
            B = len(points)
            for index in range(B):
                points[index] = torch.zeros((1, 5), dtype=points[index].dtype, device=points[index].device)
            batch_inputs_dict['points'] = points


        # 这种方式不会影响train的时候帮助bevdepth的深度估计的lidar
        if mode == 'train' and self.shuffle_train:
            batch_inputs_dict = self.random_shuffle_input(batch_inputs_dict)

        features = []
        
        if (self.multi_modality or self.cameraonly) and imgs is not None:
            imgs = imgs.contiguous()
            lidar2image, camera_intrinsics, camera2lidar = [], [], []
            img_aug_matrix, lidar_aug_matrix = [], []
            for i, meta in enumerate(batch_input_metas):
                lidar2image.append(meta['lidar2img'])
                camera_intrinsics.append(meta['cam2img'])
                camera2lidar.append(meta['cam2lidar'])
                img_aug_matrix.append(meta.get('img_aug_matrix', np.eye(4)))
                lidar_aug_matrix.append(
                    meta.get('lidar_aug_matrix', np.eye(4)))

            lidar2image = imgs.new_tensor(np.asarray(lidar2image))
            camera_intrinsics = imgs.new_tensor(np.array(camera_intrinsics))
            camera2lidar = imgs.new_tensor(np.asarray(camera2lidar))
            img_aug_matrix = imgs.new_tensor(np.asarray(img_aug_matrix))
            lidar_aug_matrix = imgs.new_tensor(np.asarray(lidar_aug_matrix))
            img_feature = self.extract_img_feat(imgs, points_for_camera,
                                                lidar2image, camera_intrinsics,
                                                camera2lidar, img_aug_matrix,
                                                lidar_aug_matrix,
                                                batch_input_metas)
            features.append(img_feature)

        if not self.cameraonly:
            pts_feature = self.extract_pts_feat(batch_inputs_dict)
            features.append(pts_feature)

        if self.multi_modality and self.fusion_layer is not None:
            # current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
            # print(current_time)
            # lidar_name = f"{current_time}_lidar.pt"
            # torch.save(features[0], lidar_name)
            # camera_name = f"{current_time}_camera.pt"
            # torch.save(features[1], camera_name)
            # 0是image 1是lidar
            # if self.missing_camera:
            #     B, C_camera, H, W = features[0].shape
            #     features[0] = torch.zeros((B, C_camera, H, W),\
            #         dtype=features[0].dtype, device=features[0].device)
                    
            x = self.fusion_layer(features)
        else:
            assert len(features) == 1, features
            x = features[0]

        x = self.pts_backbone(x)
        x = self.pts_neck(x)

        return x

    def loss(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
             batch_data_samples: List[Det3DDataSample],
             **kwargs) -> List[Det3DDataSample]:
        batch_input_metas = [item.metainfo for item in batch_data_samples]
        feats = self.extract_feat(batch_inputs_dict, batch_input_metas, mode='train')

        losses = dict()
        if self.with_bbox_head:
            bbox_loss = self.bbox_head.loss(feats, batch_data_samples)
            losses.update(bbox_loss)

        if self.with_seg_head:
            seg_loss = self.seg_head.loss(feats, batch_data_samples)
            losses.update(seg_loss)

        return losses
