from dlisa.model.m3dref_clip import M3DRefCLIP
import torch
import hydra
from dlisa.common_ops.functions import common_ops
from dlisa.model.vision_module.pointgroup import PointGroupNMS
from dlisa.model.cross_modal_module.match_module import SpatialMatchModule
from dlisa.model.vision_module.object_renderer import SizeDependentObjectRenderer
from dlisa.util.utils import nms


class Dlisa(M3DRefCLIP):
    """Visual-grounding model built on top of ``M3DRefCLIP``.

    On top of the parent this class can:
    * replace the detector with ``PointGroupNMS`` ("front" NMS mode),
    * render proposals with a ``SizeDependentObjectRenderer`` whose parameters
      are trained by a separate optimizer (manual optimization),
    * fuse modalities with a ``SpatialMatchModule``,
    * add a dynamic-box regularization loss,
    * pick the output score threshold on the validation set
      ("optimal inference threshold") and reuse it at test time.
    """

    # Checkpoint key used to persist the validation-selected output threshold,
    # so a test run restored from a checkpoint does not fall back to the
    # placeholder value set in __init__.
    _OPTIMAL_THRES_CKPT_KEY = "optimal_inference_thres"

    def __init__(self, cfg):
        super().__init__(cfg)
        self.save_hyperparameters()
        net_cfg = cfg.model.network

        self.stop_training = False
        self.multiple_optimizer = False
        self.use_optimal_inference_thres = False

        if cfg.model.inference.use_optimal_inference_thres:
            self.use_optimal_inference_thres = True
            # Placeholder until a validation epoch (or a checkpoint) provides the
            # threshold selected by the evaluator.
            self.optimal_inference_thres = 0
            self.inference_thres_list = cfg.model.inference.output_threshold_list

        # "front" NMS happens inside the detector; "end" NMS is applied while
        # parsing predictions (see _collect_pred_aabbs).
        if net_cfg.use_nms and net_cfg.nms.mode == 'front':
            self.detector = PointGroupNMS(
                input_channel=self.detector_input_channel, output_channel=net_cfg.detector.output_channel,
                max_proposals=net_cfg.max_num_proposals, semantic_class=cfg.data.semantic_class,
                use_gt=net_cfg.detector.use_gt_proposal, use_pt2=net_cfg.detector.use_pt2,
                freeze_pt2=net_cfg.detector.freeze_pt2, dynamic_box=net_cfg.use_dynamic_box,
                **net_cfg.dynamic_box_module, pt2_config=net_cfg.pointnet2
            )

        if net_cfg.learnable_camera_pose and net_cfg.use_2d_feature:
            # The renderer parameters get their own optimizer (see
            # configure_optimizers), which requires manual optimization.
            self.multiple_optimizer = True
            if net_cfg.camera_pose_generator.mode == 'size':
                self.object_renderer = SizeDependentObjectRenderer(
                    camera_pose_generator=net_cfg.camera_pose_generator, **net_cfg.object_renderer
                )

        # Width of the per-proposal feature fed to the matching module: each
        # enabled source contributes its channel count (flags act as 0/1).
        # NOTE(review): the semantic feature width is hard-coded to 20 -- confirm
        # it matches the width of output_dict["semantic_features"].
        feature_dim = (
            net_cfg.detector.output_channel * net_cfg.use_3d_features
            + net_cfg.use_2d_feature * net_cfg.clip_img_encoder.output_channel
            + net_cfg.use_global_feature * net_cfg.global_feature_encoder.global_feature_channel
            + net_cfg.use_semantic_feature * 20
        )

        if net_cfg.matching_mode == 'spatial':
            self.match_module = SpatialMatchModule(
                **net_cfg.matching_spatial_module,
                input_channel=feature_dim
            )

        if self.multiple_optimizer:
            self.automatic_optimization = False

        self.count_parameters()

        if net_cfg.use_dynamic_box:
            self.dynamic_loss = hydra.utils.instantiate(cfg.model.loss.dynamic_loss)

        # evaluator
        self.evaluator = hydra.utils.instantiate(cfg.data.evaluator)
        self.val_step_outputs = []
        self.test_step_outputs = []

    def count_parameters(self):
        """Print and return the number of trainable (``requires_grad``) parameters."""
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("Number of parameters:", num_params)
        return num_params

    def on_save_checkpoint(self, checkpoint):
        """Store the validation-selected output threshold in the checkpoint."""
        super().on_save_checkpoint(checkpoint)
        if self.use_optimal_inference_thres:
            # Plain float so the checkpoint stays loadable with restricted unpicklers.
            checkpoint[self._OPTIMAL_THRES_CKPT_KEY] = float(self.optimal_inference_thres)

    def on_load_checkpoint(self, checkpoint):
        """Restore the output threshold saved by ``on_save_checkpoint``, if present."""
        super().on_load_checkpoint(checkpoint)
        if self.use_optimal_inference_thres and self._OPTIMAL_THRES_CKPT_KEY in checkpoint:
            self.optimal_inference_thres = checkpoint[self._OPTIMAL_THRES_CKPT_KEY]

    def forward(self, data_dict):
        """Detect proposals, build per-proposal features and fuse them with language.

        Returns the detector's ``output_dict``. The method rewrites
        ``"aabb_features"`` and ``"pred_aabb_min_max_bounds"`` in it as dense
        per-batch tensors, and the text encoder and match module fill in
        their own entries.
        """
        net_cfg = self.hparams.cfg.model.network
        max_proposals = net_cfg.max_num_proposals

        output_dict = self.detector(data_dict)
        batch_size = len(data_dict["scene_id"])
        if net_cfg.use_3d_features:
            aabb_features = output_dict["aabb_features"]
        else:
            # Zero-width features (one row per row of "aabb_features") so the
            # concatenations below still line up.
            aabb_features = torch.empty(
                size=(output_dict["aabb_features"].shape[0], 0),
                dtype=output_dict["aabb_features"].dtype, device=self.device
            )

        self.text_encoder(data_dict, output_dict)
        data_dict["lang_attention_mask"] = None

        if net_cfg.use_semantic_feature:
            aabb_features = torch.cat((aabb_features, output_dict["semantic_features"]), dim=1)
            # Written back before rendering, since the renderer receives output_dict.
            output_dict["aabb_features"] = aabb_features

        if net_cfg.use_2d_feature:
            # [view * box_num, 224, 224, 3] / [view * csize * box_num, 224, 224, 3]
            rendered_imgs = self.object_renderer(data_dict, output_dict)
            img_features = self.clip_image(rendered_imgs.permute(dims=(0, 3, 1, 2)))
            views = len(net_cfg.object_renderer.eye)
            # Average each run of `views` consecutive image features into one
            # feature per box.
            aabb_img_features = torch.nn.functional.avg_pool1d(
                img_features.permute(1, 0), kernel_size=views, stride=views
            ).permute(1, 0)

            aabb_features = torch.nn.functional.normalize(
                torch.cat((aabb_features, aabb_img_features), dim=1), dim=1
            )

        output_dict["aabb_features"] = common_ops.convert_sparse_tensor_to_dense(
            aabb_features, output_dict["proposal_batch_offsets"], max_proposals
        )

        output_dict["pred_aabb_min_max_bounds"] = common_ops.convert_sparse_tensor_to_dense(
            output_dict["pred_aabb_min_max_bounds"].reshape(-1, 6), output_dict["proposal_batch_offsets"],
            max_proposals
        ).reshape(batch_size, max_proposals, 2, 3)

        if net_cfg.use_spatial_feature:
            output_dict["aabb_features"] = self.spatial_relation_encoder(data_dict, output_dict)

        # cross-modal fusion
        self.match_module(data_dict, output_dict)
        return output_dict

    def _loss(self, data_dict, output_dict):
        """Return a dict of named loss terms (detector, reference, optional extras)."""
        net_cfg = self.hparams.cfg.model.network
        loss_dict = self.detector.loss(data_dict, output_dict)

        # reference loss
        loss_dict["reference_loss"] = self.ref_loss(
            output_dict,
            output_dict["pred_aabb_min_max_bounds"],
            output_dict["pred_aabb_scores"],
            data_dict["gt_aabb_min_max_bounds"],
            data_dict["gt_target_obj_id_mask"].permute(dims=(1, 0)),
            data_dict["aabb_count_offsets"],
        )

        # contrastive loss
        if net_cfg.use_contrastive_loss:
            loss_dict["contrastive_loss"] = self.contrastive_loss(
                output_dict["aabb_features_inter"],
                output_dict["sentence_features"],
                output_dict["gt_labels"]
            )

        # dynamic_box regularization
        if net_cfg.use_dynamic_box:
            loss_dict["dynamic_loss"] = self.dynamic_loss(
                output_dict["thres_score"]
            )
        return loss_dict

    def _log_losses_and_stats(self, loss_dict, output_dict, stage, on_step, on_epoch):
        """Log every loss term, their sum and the mean box count; return the sum.

        Keys are ``{stage}_loss/<name>``, ``{stage}_loss/total_loss`` and
        ``{stage}_stats/avg_box_num``.
        """
        total_loss = 0
        for loss_name, loss_value in loss_dict.items():
            total_loss += loss_value
            self.log(f"{stage}_loss/{loss_name}", loss_value, on_step=on_step, on_epoch=on_epoch)
        self.log(f"{stage}_loss/total_loss", total_loss, on_step=on_step, on_epoch=on_epoch)

        # log number of boxes:
        box_number = output_dict["proposal_masks_dense"].float().sum(dim=1).mean().item()
        self.log(f"{stage}_stats/avg_box_num", box_number, on_step=on_step, on_epoch=on_epoch)
        return total_loss

    def training_step(self, data_dict, idx):
        output_dict = self(data_dict)
        loss_dict = self._loss(data_dict, output_dict)
        total_loss = self._log_losses_and_stats(loss_dict, output_dict, "train", on_step=True, on_epoch=False)

        if self.multiple_optimizer:
            # automatic_optimization is disabled in __init__ for this case, so
            # all optimizers are zeroed, back-propagated and stepped here.
            optimizers = self.optimizers()
            for opt in optimizers:
                opt.zero_grad()

            self.manual_backward(total_loss)

            for opt in optimizers:
                opt.step()

        return total_loss

    def validation_step(self, data_dict, idx):
        output_dict = self(data_dict)
        loss_dict = self._loss(data_dict, output_dict)
        self._log_losses_and_stats(loss_dict, output_dict, "val", on_step=False, on_epoch=True)

        # get predictions and gts
        self.val_step_outputs.append((self._parse_pred_results_val(data_dict, output_dict), self._parse_gt(data_dict)))

    def test_step(self, data_dict, idx):
        output_dict = self(data_dict)
        self.test_step_outputs.append(
            (self._parse_pred_results_test(data_dict, output_dict), self._parse_gt(data_dict))
        )

    def on_validation_epoch_end(self):
        """Evaluate the accumulated validation predictions and log the metrics.

        With ``use_optimal_inference_thres`` the predictions are grouped per
        candidate threshold, and the evaluator returns ``(results, thres)``. The
        selected ``thres`` is kept for test-time inference.
        """
        total_pred_results = {}
        total_gt_results = {}
        for pred_results, gt_results in self.val_step_outputs:
            total_gt_results.update(gt_results)

            if self.use_optimal_inference_thres:
                for thres, info in pred_results.items():
                    total_pred_results.setdefault(thres, {}).update(info)
            else:
                total_pred_results.update(pred_results)

        self.val_step_outputs.clear()
        self.evaluator.set_ground_truths(total_gt_results)

        if self.use_optimal_inference_thres:
            results, thres = self.evaluator.evaluate(total_pred_results)
            self.optimal_inference_thres = thres
        else:
            results = self.evaluator.evaluate(total_pred_results)

        # log
        for metric_name, result in results.items():
            for breakdown, value in result.items():
                self.log(f"val_eval/{metric_name}_{breakdown}", value)

        if self.use_optimal_inference_thres:
            self.log("val_eval/optimal_thres", thres)

        if self.hparams.cfg.scheduled_job:
            self.stop_training = True

    def on_test_epoch_end(self):
        """Merge the per-batch test predictions and save them."""
        total_pred_results = {}
        for pred_results, _ in self.test_step_outputs:
            total_pred_results.update(pred_results)
        self.test_step_outputs.clear()
        self._save_predictions(total_pred_results)

    def _pred_aabb_score_masks(self, output_dict, batch_size, lang_chunk_size, threshold=None):
        """Select proposals for each (scene, description) pair.

        * ScanRefer / Nr3D: the arg-max proposal index.
        * Multi3DRefer: a boolean mask of proposals whose sigmoid score is
          ``>= threshold``. When ``threshold`` is None, the configured
          ``model.inference.output_threshold`` is used.

        The result is reshaped to ``(batch_size, lang_chunk_size, -1)``.

        Raises:
            NotImplementedError: for any other dataset.
        """
        scores = output_dict["pred_aabb_scores"]
        if self.dataset_name in ("ScanRefer", "Nr3D"):
            masks = scores.argmax(dim=1)
        elif self.dataset_name == "Multi3DRefer":
            if threshold is None:
                threshold = self.hparams.cfg.model.inference.output_threshold
            masks = torch.sigmoid(scores) >= threshold
        else:
            raise NotImplementedError
        return masks.reshape(shape=(batch_size, lang_chunk_size, -1))

    def _collect_pred_aabbs(self, data_dict, output_dict, pred_aabb_score_masks):
        """Turn selected proposals into evaluator entries.

        Returns a dict keyed by ``(scene_id, object_id, ann_id)`` whose values
        are ``{"aabb_bound": numpy array}``. Each array holds the selected boxes
        shifted by ``scene_center_xyz``. When NMS runs in "end" mode, the
        selected boxes are first filtered by ``nms`` using the raw
        ``pred_aabb_scores``.
        """
        net_cfg = self.hparams.cfg.model.network
        batch_size, lang_chunk_size = data_dict["ann_id"].shape
        use_end_nms = net_cfg.use_nms and net_cfg.nms.mode == 'end'

        if use_end_nms:
            # nms works on numpy arrays; convert once per call, not per pair.
            masks_numpy = pred_aabb_score_masks.cpu().numpy()
            bounds_numpy = output_dict["pred_aabb_min_max_bounds"].cpu().numpy()
            scores_numpy = output_dict["pred_aabb_scores"].reshape(batch_size, lang_chunk_size, -1).cpu().numpy()
            iou_threshold = net_cfg.nms.iou_threshold

        pred_results = {}
        for i in range(batch_size):
            scene_center = data_dict["scene_center_xyz"][i]
            for j in range(lang_chunk_size):
                if use_end_nms:
                    mask = masks_numpy[i][j]
                    pred_aabbs = nms(scores_numpy[i][j][mask], bounds_numpy[i][mask], iou_threshold)
                    aabb_bounds = pred_aabbs + scene_center.cpu().numpy()
                else:
                    pred_aabbs = output_dict["pred_aabb_min_max_bounds"][i][pred_aabb_score_masks[i, j]]
                    aabb_bounds = (pred_aabbs + scene_center).cpu().numpy()
                key = (
                    data_dict["scene_id"][i],
                    data_dict["object_id"][i][j].item(),
                    data_dict["ann_id"][i][j].item(),
                )
                pred_results[key] = {"aabb_bound": aabb_bounds}
        return pred_results

    def _parse_pred_results_val(self, data_dict, output_dict):
        """Build validation predictions.

        With ``use_optimal_inference_thres`` this returns
        ``{thres: predictions}`` for every candidate threshold (arg-max datasets
        get the same selection for each). Otherwise it returns a flat
        predictions dict.
        """
        batch_size, lang_chunk_size = data_dict["ann_id"].shape
        if self.use_optimal_inference_thres:
            return {
                thres: self._collect_pred_aabbs(
                    data_dict, output_dict,
                    self._pred_aabb_score_masks(output_dict, batch_size, lang_chunk_size, thres),
                )
                for thres in self.inference_thres_list
            }
        masks = self._pred_aabb_score_masks(output_dict, batch_size, lang_chunk_size)
        return self._collect_pred_aabbs(data_dict, output_dict, masks)

    def _parse_pred_results_test(self, data_dict, output_dict):
        """Build test predictions.

        Uses the validation-selected threshold when
        ``use_optimal_inference_thres`` is set, and the configured
        ``output_threshold`` otherwise.
        """
        batch_size, lang_chunk_size = data_dict["ann_id"].shape
        threshold = self.optimal_inference_thres if self.use_optimal_inference_thres else None
        masks = self._pred_aabb_score_masks(output_dict, batch_size, lang_chunk_size, threshold)
        return self._collect_pred_aabbs(data_dict, output_dict, masks)

    def configure_optimizers(self):
        """Return ``[main_optimizer]``, or ``[main_optimizer, pose_optimizer]``.

        The second form is used with a learnable camera pose and 2D features. In
        that case ``object_renderer.*`` parameters go only to the pose
        optimizer. This condition matches the one setting
        ``self.multiple_optimizer`` in ``__init__``.
        """
        cfg = self.hparams.cfg
        use_pose_optimizer = cfg.model.network.learnable_camera_pose and cfg.model.network.use_2d_feature

        main_parameters = [
            p for n, p in self.named_parameters()
            if not (use_pose_optimizer and n.startswith('object_renderer.'))
        ]
        optimizers = [hydra.utils.instantiate(cfg.model.optimizer, params=main_parameters)]

        if use_pose_optimizer:
            optimizers.append(
                hydra.utils.instantiate(cfg.model.pose_optimizer, params=self.object_renderer.parameters())
            )

        return optimizers