# Student config that is distilled into; any OpenMMLab student config can be
# substituted here.
_base_ = ['mmdet3d::vexkd/student_lidar_centerpoint_epoch20_detection.py']

# The distilled student is the `model` defined by the `_base_` config.
student = _base_.model
# TODO change this path to the teacher checkpoint path
teacher_ckpt = ''
model = dict(
    _scope_='mmrazor',
    # Replace the inherited `model` dict entirely instead of merging into it.
    _delete_=True,
    type='BEVQueryGuidedDistillCascadeTeacherAssist',
    architecture=student,
    teacher=dict(
        cfg_path='mmdet3d::vexkd/teacher_bevfusion_mgfm_segmentation.py',
        pretrained=False),
    teacher_ckpt=teacher_ckpt,
    cascade_prompt=['before_pts_backbone', 'before_head'],
    distiller=dict(
        type='BEVQueryGuidedMultiLayerDistiller',
        # The inputs of `pts_backbone` and `seg_head` are recorded on both the
        # student and the teacher so they can be compared by the losses below.
        # NOTE(review): assumes the student also exposes a `seg_head`
        # submodule (the recorder is named `fpn_head_in`) -- confirm against
        # the student config.
        student_recorders=dict(
            pts_backbone_in=dict(type='ModuleInputs', source='pts_backbone'),
            fpn_head_in=dict(type='ModuleInputs', source='seg_head')),
        teacher_recorders=dict(
            pts_backbone_in=dict(type='ModuleInputs', source='pts_backbone'),
            fpn_head_in=dict(type='ModuleInputs', source='seg_head'),
            # The teacher's BEV query embedding weights, fed to both losses
            # as `bev_queries` (see loss_forward_mappings).
            bev_queries=dict(
                type='Parameter',
                source='fusion_layer.bev_embedding.weight')),
        # Each loss owns a 3-layer deformable-attention encoder.
        # NOTE(review): embed_dims presumably has to equal the channel count
        # of the recorded feature (256 at the pts_backbone input, 512 at the
        # seg_head input) -- confirm against the model definitions.
        distill_losses=dict(
            loss_pts_backbone=dict(
                type='BEVQueryGuidedDeformableMultiLayerAttentionTransferLoss',
                encoder=dict(
                    type='DeformableAttentionEncoder',
                    num_layers=3,
                    transformerlayers=dict(
                        type='MM_BEVFormerLayer',
                        num_modality=1,
                        attn_cfgs=[
                            dict(
                                type='SpatialDeformableAttention',
                                init_dims=256,
                                embed_dims=256,
                                batch_first=True,
                                num_points=16,
                                num_levels=1)
                        ],
                        ffn_cfgs=dict(
                            type='FFN',
                            embed_dims=256,
                            feedforward_channels=512,
                            num_fcs=1,
                            ffn_drop=0.1,
                            act_cfg=dict(type='ReLU', inplace=True),
                        ),
                        feedforward_channels=512,
                        operation_order=('cross_attn', 'norm', 'ffn', 'norm'))),
                loss_weight=1.0,
                embed_dims=256),
            loss_head=dict(
                type='BEVQueryGuidedDeformableMultiLayerAttentionTransferLoss',
                encoder=dict(
                    type='DeformableAttentionEncoder',
                    num_layers=3,
                    transformerlayers=dict(
                        type='MM_BEVFormerLayer',
                        num_modality=1,
                        attn_cfgs=[
                            dict(
                                type='SpatialDeformableAttention',
                                init_dims=512,
                                embed_dims=512,
                                batch_first=True,
                                num_points=16,
                                num_levels=1)
                        ],
                        ffn_cfgs=dict(
                            type='FFN',
                            embed_dims=512,
                            feedforward_channels=512,
                            num_fcs=1,
                            ffn_drop=0.1,
                            act_cfg=dict(type='ReLU', inplace=True),
                        ),
                        feedforward_channels=512,
                        operation_order=('cross_attn', 'norm', 'ffn', 'norm'))),
                # The head-level loss is weighted twice as much as the
                # backbone-level one.
                loss_weight=2.0,
                embed_dims=512)),
        # Route recorder contents to each loss's forward() arguments;
        # data_idx=0 selects the first element of the recorded module inputs.
        loss_forward_mappings=dict(
            loss_pts_backbone=dict(
                preds_S=dict(
                    from_student=True, recorder='pts_backbone_in',
                    data_idx=0),
                preds_T=dict(
                    from_student=False, recorder='pts_backbone_in',
                    data_idx=0),
                bev_queries=dict(from_student=False, recorder='bev_queries')),
            loss_head=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn_head_in', data_idx=0),
                preds_T=dict(
                    from_student=False, recorder='fpn_head_in', data_idx=0),
                bev_queries=dict(from_student=False, recorder='bev_queries')),
        ),
    ),
)

# BEV map-segmentation classes, passed as `classes` to the
# `LoadBEVSegmentation` transform in the training pipelines.
map_classes = [
    'drivable_area',
    'ped_crossing',
    'walkway',
    'stop_line',
    'carpark_area',
    'divider',
]

# Use the union of the teacher's and the student's data pipelines so that both
# models receive aligned inputs for the same sample.
train_dataloader = dict(
    _delete_=True,
    batch_size=4,
    dataset=dict(
        dataset=dict(
            ann_file='nuscenes_infos_train.pkl',
            box_type_3d='LiDAR',
            data_prefix=dict(
                CAM_BACK='samples/CAM_BACK',
                CAM_BACK_LEFT='samples/CAM_BACK_LEFT',
                CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT',
                CAM_FRONT='samples/CAM_FRONT',
                CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT',
                CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT',
                pts='samples/LIDAR_TOP',
                sweeps='sweeps/LIDAR_TOP'),
            data_root='data/nuscenes/',
            metainfo=dict(classes=[
                'car',
                'truck',
                'construction_vehicle',
                'bus',
                'trailer',
                'barrier',
                'motorcycle',
                'bicycle',
                'pedestrian',
                'traffic_cone',
            ]),
            modality=dict(use_camera=True, use_lidar=True),
            pipeline=[
                dict(
                    backend_args=None,
                    color_type='color',
                    to_float32=True,
                    type='BEVLoadMultiViewImageFromFiles'),
                dict(
                    backend_args=None,
                    coord_type='LIDAR',
                    load_dim=5,
                    type='LoadPointsFromFile',
                    use_dim=5),
                dict(
                    backend_args=None,
                    load_dim=5,
                    pad_empty_sweeps=True,
                    remove_close=True,
                    sweeps_num=9,
                    type='LoadPointsFromMultiSweeps',
                    use_dim=5),
                dict(
                    type='LoadAnnotations3D',
                    with_attr_label=False,
                    with_bbox_3d=True,
                    with_label_3d=True),
                # Paste objects sampled from the GT database
                # (nuscenes_dbinfos_train.pkl) into the scene.
                dict(
                    db_sampler=dict(
                        classes=[
                            'car',
                            'truck',
                            'construction_vehicle',
                            'bus',
                            'trailer',
                            'barrier',
                            'motorcycle',
                            'bicycle',
                            'pedestrian',
                            'traffic_cone',
                        ],
                        data_root='data/nuscenes/',
                        info_path='data/nuscenes/nuscenes_dbinfos_train.pkl',
                        points_loader=dict(
                            backend_args=None,
                            coord_type='LIDAR',
                            load_dim=5,
                            type='LoadPointsFromFile',
                            use_dim=[0, 1, 2, 3, 4]),
                        prepare=dict(
                            filter_by_difficulty=[-1],
                            filter_by_min_points=dict(
                                barrier=5,
                                bicycle=5,
                                bus=5,
                                car=5,
                                construction_vehicle=5,
                                motorcycle=5,
                                pedestrian=5,
                                traffic_cone=5,
                                trailer=5,
                                truck=5)),
                        rate=1.0,
                        sample_groups=dict(
                            barrier=2,
                            bicycle=6,
                            bus=4,
                            car=2,
                            construction_vehicle=7,
                            motorcycle=6,
                            pedestrian=2,
                            traffic_cone=2,
                            trailer=6,
                            truck=3)),
                    type='ObjectSample'),
                dict(
                    bot_pct_lim=[0.0, 0.0],
                    final_dim=[256, 704],
                    is_train=True,
                    rand_flip=True,
                    resize_lim=[0.38, 0.55],
                    rot_lim=[-5.4, 5.4],
                    type='ImageAug3D'),
                dict(
                    rot_range=[-0.78539816, 0.78539816],
                    scale_ratio_range=[0.9, 1.1],
                    translation_std=0.5,
                    type='BEVFusionGlobalRotScaleTrans'),
                dict(
                    type='LoadBEVSegmentation',
                    dataset_root='data/nuscenes/',
                    xbound=[-50.0, 50.0, 0.5],
                    ybound=[-50.0, 50.0, 0.5],
                    classes=map_classes),
                dict(type='BEVFusionRandomFlip3D'),
                dict(
                    point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0],
                    type='PointsRangeFilter'),
                dict(
                    point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0],
                    type='ObjectRangeFilter'),
                dict(
                    classes=[
                        'car',
                        'truck',
                        'construction_vehicle',
                        'bus',
                        'trailer',
                        'barrier',
                        'motorcycle',
                        'bicycle',
                        'pedestrian',
                        'traffic_cone',
                    ],
                    type='ObjectNameFilter'),
                # NOTE(review): with prob=0.0 this GridMask is presumably a
                # no-op -- confirm it is intentionally disabled.
                dict(
                    fixed_prob=True,
                    max_epoch=6,
                    mode=1,
                    offset=False,
                    prob=0.0,
                    ratio=0.5,
                    rotate=1,
                    type='GridMask',
                    use_h=True,
                    use_w=True),
                dict(type='PointShuffle'),
                dict(
                    keys=[
                        'points',
                        'img',
                        'gt_bboxes_3d',
                        'gt_labels_3d',
                        'gt_bboxes',
                        'gt_labels',
                    ],
                    # Each meta key is listed once ('img_aug_matrix' used to
                    # appear twice).
                    meta_keys=[
                        'cam2img',
                        'ori_cam2img',
                        'lidar2cam',
                        'lidar2img',
                        'cam2lidar',
                        'ori_lidar2img',
                        'img_aug_matrix',
                        'box_type_3d',
                        'sample_idx',
                        'lidar_path',
                        'img_path',
                        'transformation_3d_flow',
                        'pcd_rotation',
                        'pcd_scale_factor',
                        'pcd_trans',
                        'lidar_aug_matrix',
                        'num_pts_feats',
                    ],
                    type='Pack3DDetInputs'),
            ],
            test_mode=False,
            type='NuScenesDataset',
            use_valid_flag=True),
        type='CBGSDataset'),
    num_workers=4,
    persistent_workers=True,
    # A training sampler must shuffle. Otherwise every epoch walks the indices
    # in the fixed order built by CBGSDataset (assumed to be grouped by class
    # in mmdet3d's implementation -- verify), which skews each batch.
    sampler=dict(shuffle=True, type='DefaultSampler'))

# Standalone training pipeline: the same transforms as the inline pipeline of
# `train_dataloader`, minus ObjectSample and GridMask.
# NOTE(review): `train_dataloader` in this file does not reference this
# variable -- confirm whether anything still consumes it.
train_pipeline = [
    dict(
        backend_args=None,
        color_type='color',
        to_float32=True,
        type='BEVLoadMultiViewImageFromFiles'),
    dict(
        backend_args=None,
        coord_type='LIDAR',
        load_dim=5,
        type='LoadPointsFromFile',
        use_dim=5),
    dict(
        backend_args=None,
        load_dim=5,
        pad_empty_sweeps=True,
        remove_close=True,
        sweeps_num=9,
        type='LoadPointsFromMultiSweeps',
        use_dim=5),
    dict(
        type='LoadAnnotations3D',
        with_attr_label=False,
        with_bbox_3d=True,
        with_label_3d=True),
    dict(
        bot_pct_lim=[0.0, 0.0],
        final_dim=[256, 704],
        is_train=True,
        rand_flip=True,
        resize_lim=[0.38, 0.55],
        rot_lim=[-5.4, 5.4],
        type='ImageAug3D'),
    dict(
        rot_range=[-0.78539816, 0.78539816],
        scale_ratio_range=[0.9, 1.1],
        translation_std=0.5,
        type='BEVFusionGlobalRotScaleTrans'),
    dict(
        type='LoadBEVSegmentation',
        dataset_root='data/nuscenes/',
        xbound=[-50.0, 50.0, 0.5],
        ybound=[-50.0, 50.0, 0.5],
        classes=map_classes),
    dict(type='BEVFusionRandomFlip3D'),
    dict(
        point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0],
        type='PointsRangeFilter'),
    dict(
        point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0],
        type='ObjectRangeFilter'),
    dict(
        classes=[
            'car',
            'truck',
            'construction_vehicle',
            'bus',
            'trailer',
            'barrier',
            'motorcycle',
            'bicycle',
            'pedestrian',
            'traffic_cone',
        ],
        type='ObjectNameFilter'),
    dict(type='PointShuffle'),
    dict(
        keys=[
            'points',
            'img',
            'gt_bboxes_3d',
            'gt_labels_3d',
            'gt_bboxes',
            'gt_labels',
        ],
        # Each meta key is listed once ('img_aug_matrix' used to appear twice).
        meta_keys=[
            'cam2img',
            'ori_cam2img',
            'lidar2cam',
            'lidar2img',
            'cam2lidar',
            'ori_lidar2img',
            'img_aug_matrix',
            'box_type_3d',
            'sample_idx',
            'lidar_path',
            'img_path',
            'transformation_3d_flow',
            'pcd_rotation',
            'pcd_scale_factor',
            'pcd_trans',
            'lidar_aug_matrix',
            'num_pts_feats',
        ],
        type='Pack3DDetInputs'),
]

# Allow the distributed wrapper to tolerate parameters that receive no
# gradient in a given step.
find_unused_parameters = True
# Validate after every epoch.
train_cfg = dict(val_interval=1)

# Make the mmrazor hook module importable before hooks are built.
custom_imports = dict(
    imports=['mmrazor.engine.hooks.stop_mask_learning_epoch_hook'],
    allow_failed_imports=False,
)
# NOTE(review): the module is named "..._epoch_hook" and the argument is
# `stop_epoch`, but the hook type is "...IterHook" -- confirm the intended
# unit.
custom_hooks = [
    dict(
        _scope_='mmrazor',
        type='StopMaskLearningIterHook',
        stop_epoch=2,
    ),
]
