import lightning.pytorch as pl
import torch
import torch.nn as nn
from dlisa.model.cross_modal_module.attention import MultiHeadAttention, MultiHeadAttentionSpatial, FixedMultiHeadAttention

class MatchModule(pl.LightningModule):
    """Score every object proposal against the language features.

    Projected proposal features are refined by alternating self-attention over
    proposals (biased by pairwise center-distance weights) and cross-attention
    to word features, then mapped to one score per proposal by ``self.match``.

    Args:
        feat_channel: width of the attention / matching features.
        input_channel: width of ``output_dict['aabb_features']``.
        head: number of attention heads (``feat_channel // head`` per head).
        depth: number of (self-attention, cross-attention) layer pairs.
    """

    def __init__(self, feat_channel, input_channel, head, depth):
        super().__init__()
        # Layer 0 is applied in forward() before chunk replication; layers
        # 1..self.depth are applied in the loop afterwards.
        self.depth = depth - 1
        self.features_concat = nn.Sequential(
            nn.Conv1d(input_channel, feat_channel, 1),
            nn.BatchNorm1d(feat_channel),
            nn.PReLU(feat_channel),
            nn.Conv1d(feat_channel, feat_channel, 1),
        )
        self.self_attn = nn.ModuleList(
            MultiHeadAttention(
                d_model=feat_channel,
                h=head,
                d_k=feat_channel // head,
                d_v=feat_channel // head,
                dropout=0.1
            ) for _ in range(depth)
        )
        self.cross_attn = nn.ModuleList(
            MultiHeadAttention(
                d_model=feat_channel,
                h=head,
                d_k=feat_channel // head,
                d_v=feat_channel // head,
                dropout=0.1
            ) for _ in range(depth)
        )
        self.match = nn.Sequential(
            nn.Conv1d(feat_channel, feat_channel, 1),
            nn.BatchNorm1d(feat_channel),
            nn.PReLU(),
            nn.Conv1d(feat_channel, feat_channel, 1),
            nn.BatchNorm1d(feat_channel),
            nn.PReLU(),
            nn.Conv1d(feat_channel, 1, 1)
        )

    def forward(self, data_dict, output_dict):
        """Compute ``output_dict['pred_aabb_scores']`` in place; returns None.

        Reads ``data_dict['ann_id']`` (its first two dims give batch size B and
        chunk size), ``data_dict['lang_attention_mask']`` and, from
        ``output_dict``, ``pred_aabb_min_max_bounds``, ``proposal_masks_dense``,
        ``aabb_features`` and ``word_features``.

        Writes ``output_dict['aabb_features_inter']`` (projected proposal
        features, before any attention) and ``output_dict['pred_aabb_scores']``.
        """
        batch_size, chunk_size = data_dict["ann_id"].shape[0:2]
        num_proposals = output_dict["pred_aabb_min_max_bounds"].shape[1]
        word_features = output_dict["word_features"]
        lang_mask = data_dict["lang_attention_mask"]

        # (B, 4, N, N) spatial bias, reused by every self-attention layer
        attention_weights = self._calculate_spatial_weight(output_dict)

        # Conv1d wants channels first: (B, N, C_in) -> (B, C_in, N) -> (B, N, C)
        aabb_features = self.features_concat(output_dict['aabb_features'].permute(0, 2, 1)).permute(0, 2, 1)
        output_dict["aabb_features_inter"] = aabb_features

        # The first self-attention does not depend on language, so it runs once
        # per scene, before the features are replicated for each chunk entry.
        aabb_features = self.self_attn[0](
            aabb_features, aabb_features, aabb_features, attention_weights=attention_weights, way="add"
        )

        # Replicate per chunk entry: (B, ...) -> (B * chunk_size, ...)
        aabb_features = aabb_features.unsqueeze(dim=1).expand(-1, chunk_size, -1, -1).flatten(start_dim=0, end_dim=1)
        attention_weights = attention_weights.unsqueeze(dim=1).expand(-1, chunk_size, -1, -1, -1).reshape(
            batch_size * chunk_size, attention_weights.shape[1], num_proposals, num_proposals
        )

        aabb_features = self.cross_attn[0](
            aabb_features, word_features, word_features, attention_mask=lang_mask
        )
        for i in range(1, self.depth + 1):
            aabb_features = self.self_attn[i](
                aabb_features, aabb_features, aabb_features, attention_weights=attention_weights, way="add"
            )
            aabb_features = self.cross_attn[i](
                aabb_features, word_features, word_features, attention_mask=lang_mask
            )

        # match: (B*chunk, N, C) -> (B*chunk, C, N) -> (B*chunk, 1, N) -> flatten dims 0-1
        aabb_features = aabb_features.permute(0, 2, 1).contiguous()
        output_dict["pred_aabb_scores"] = self.match(aabb_features).flatten(start_dim=0, end_dim=1)

    @staticmethod
    def _pairwise_center_distance(output_dict):
        """Return Euclidean distances between proposal centers, shape (B, 1, N, N).

        Centers are the mean of ``pred_aabb_min_max_bounds`` over dim 2
        (assumed (B, N, 2, 3) min/max corners — TODO confirm). Element
        [b, 0, i, j] is ||center_j - center_i||.
        """
        centers = output_dict["pred_aabb_min_max_bounds"].mean(dim=2)  # (B, N, 3)
        # Broadcasting yields the same (B, N, N, 3) differences as repeating
        # both operands, without materialising the two repeated copies.
        diff = centers.unsqueeze(dim=1) - centers.unsqueeze(dim=2)
        return torch.sqrt(diff.pow(2).sum(dim=-1)).unsqueeze(dim=1)

    def _calculate_spatial_weight(self, output_dict):
        """Build the pairwise spatial attention bias between proposals.

        Reference: https://github.com/zlccccc/3DVG-Transformer

        Returns a detached tensor of shape (B, 4, N, N) whose channels are
        [masked inverse distance normalised over dim 2, -distance, 0, 0].
        Pairs involving a proposal whose ``proposal_masks_dense`` entry is 0
        get (almost) zero inverse-distance weight.
        """
        # The result is detached anyway, so don't build an autograd graph.
        with torch.no_grad():
            dist = self._pairwise_center_distance(output_dict)  # (B, 1, N, N)
            dist_weights = 1 / (dist + 1e-2)

            # mask placeholders: keep only pairs where both proposals are valid
            valid = output_dict["proposal_masks_dense"].unsqueeze(-1)  # (B, N, 1)
            dist_weights *= (valid.transpose(1, 2) * valid).unsqueeze(dim=1)  # (B, 1, N, N)

            dist_weights += torch.finfo(torch.float32).eps  # prevent zeros
            norm = dist_weights.sum(dim=2, keepdim=True)  # (B, 1, 1, N)
            dist_weights = dist_weights / norm
            zeros = torch.zeros_like(dist_weights)  # (B, 1, N, N)
            return torch.cat([dist_weights, -dist, zeros, zeros], dim=1).detach()  # (B, 4, N, N)


class FixedMatchModule(MatchModule):
    """``MatchModule`` variant whose attention stacks use ``FixedMultiHeadAttention``.

    Compared with the base class, as implemented below:
      * every (self-attention, cross-attention) pair runs after the proposal
        features have been replicated per chunk entry;
      * invalid proposals are excluded through an explicit pairwise
        ``attention_mask`` instead of by zeroing the spatial weights.
    """

    def __init__(self, feat_channel, input_channel, head, depth):
        super().__init__(feat_channel, input_channel, head, depth)
        self.head = head

        per_head = feat_channel // head

        def build_stack():
            # `depth` FixedMultiHeadAttention layers with the base hyper-parameters.
            return nn.ModuleList(
                FixedMultiHeadAttention(d_model=feat_channel, h=head, d_k=per_head, d_v=per_head, dropout=0.1)
                for _ in range(depth)
            )

        # Overwrite the base-class stacks (self first, then cross).
        self.self_attn = build_stack()
        self.cross_attn = build_stack()

    def _calculate_spatial_weight_no_mask(self, output_dict):
        """Pairwise spatial attention bias between proposals, without validity masking.

        Reference: https://github.com/zlccccc/3DVG-Transformer

        Returns a detached (B, 4, N, N) tensor with channels
        [inverse distance normalised over dim 2, -distance, 0, 0].
        """
        centers = output_dict["pred_aabb_min_max_bounds"].mean(dim=2)  # (B, N, 3)
        offsets = centers.unsqueeze(dim=1) - centers.unsqueeze(dim=2)  # (B, N, N, 3): c_j - c_i
        distance = offsets.pow(2).sum(dim=-1).sqrt().unsqueeze(dim=1)  # (B, 1, N, N)

        inverse = 1 / (distance + 1e-2)
        inverse += torch.finfo(torch.float32).eps  # prevent zeros
        inverse = inverse / inverse.sum(dim=2, keepdim=True)

        blank = torch.zeros_like(inverse)
        return torch.cat([inverse, -distance, blank, blank], dim=1).detach()

    def forward(self, data_dict, output_dict):
        """Score proposals against each chunk entry's language features; returns None.

        Reads ``data_dict['ann_id']`` (first two dims: batch size B, chunk size)
        and ``data_dict['lang_attention_mask']``; reads
        ``pred_aabb_min_max_bounds``, ``proposal_masks_dense``,
        ``aabb_features`` and ``word_features`` from ``output_dict``.
        Writes ``output_dict['aabb_features_inter']`` and
        ``output_dict['pred_aabb_scores']``.
        """
        batch_size, chunk_size = data_dict["ann_id"].shape[0:2]
        num_proposals = output_dict["pred_aabb_min_max_bounds"].shape[1]

        def per_chunk(pairwise):
            # (B, K, N, N) -> (B * chunk_size, K, N, N), each scene repeated chunk_size times
            tiled = pairwise.unsqueeze(dim=1).expand(-1, chunk_size, -1, -1, -1)
            return tiled.reshape(batch_size * chunk_size, pairwise.shape[1], num_proposals, num_proposals)

        spatial_bias = self._calculate_spatial_weight_no_mask(output_dict)  # (B, 4, N, N)

        # a pair (i, j) is attendable only if both proposals are valid
        valid = output_dict["proposal_masks_dense"].unsqueeze(-1)  # (B, N, 1)
        pair_mask = (valid.transpose(1, 2) * valid).unsqueeze(dim=1).expand(-1, self.head, -1, -1)  # (B, h, N, N)

        projected = self.features_concat(output_dict['aabb_features'].permute(0, 2, 1)).permute(0, 2, 1)
        output_dict["aabb_features_inter"] = projected

        feats = projected.unsqueeze(dim=1).expand(-1, chunk_size, -1, -1).flatten(start_dim=0, end_dim=1)
        spatial_bias = per_chunk(spatial_bias)
        pair_mask = per_chunk(pair_mask)

        words = output_dict["word_features"]
        lang_mask = data_dict["lang_attention_mask"]
        for layer_idx in range(self.depth + 1):
            feats = self.self_attn[layer_idx](
                feats, feats, feats, attention_mask=pair_mask, attention_weights=spatial_bias, way="add"
            )
            feats = self.cross_attn[layer_idx](feats, words, words, attention_mask=lang_mask)

        # (B*chunk, N, C) -> (B*chunk, C, N) for the Conv1d scoring head
        scores = self.match(feats.permute(0, 2, 1).contiguous())
        output_dict["pred_aabb_scores"] = scores.flatten(start_dim=0, end_dim=1)


class SpatialMatchModule(MatchModule):
    """``MatchModule`` variant with configurable spatial reasoning in self-attention.

    Args:
        feat_channel, input_channel, head, depth: as in ``MatchModule``.
        spatial_mode: one of 'abs', 'rel', 'none', 'dist_score'.
        spatial_way: one of 'diff', 'tanh', 'dot', 'center_nerf', 'dist'.
        attn_way: one of 'add', 'balanced'; forwarded as ``way`` to the
            spatial self-attention layers.
        use_nid, d_threshold: stored on the module; not read by ``forward``.
        score_mode: used when ``spatial_mode == 'dist_score'``; selects the
            inputs of the pairwise score: 's' (sentence), 'si' (sentence +
            proposal i), 'sj' (sentence + proposal j), 'sij' (all three).

    Raises:
        ValueError: for inconsistent option combinations (see ``_check_config``).
    """

    # in-feature multiplier (x feat_channel) of ``score_fc`` per ``score_mode``
    _SCORE_MODE_WIDTH = {'s': 1, 'si': 2, 'sj': 2, 'sij': 3}

    def __init__(self, feat_channel, input_channel, head, depth, spatial_mode, spatial_way, attn_way, use_nid, d_threshold, score_mode):
        # Validate first so a bad config fails before any layer is built.
        self._check_config(head, spatial_mode, spatial_way, attn_way, score_mode)
        super().__init__(feat_channel, input_channel, head, depth)

        self.head = head
        self.spatial_mode = spatial_mode  # ['abs', 'rel', 'none', 'dist_score']
        self.spatial_way = spatial_way  # ['diff', 'tanh', 'dot', 'center_nerf', 'dist']
        self.attn_way = attn_way  # ['add', 'balanced']
        self.use_nid = use_nid
        self.d_threshold = d_threshold
        self.score_mode = score_mode

        def fixed_mha():
            return FixedMultiHeadAttention(
                d_model=feat_channel,
                h=head,
                d_k=feat_channel // head,
                d_v=feat_channel // head,
                dropout=0.1
            )

        # Layers are created in the same order as before so that parameter
        # initialisation (which consumes the RNG) is unchanged.
        self.word_self_attn = fixed_mha()

        # Replaces the base-class self-attention; ``cross_attn`` is inherited.
        self.self_attn = nn.ModuleList(
            MultiHeadAttentionSpatial(
                d_model=feat_channel,
                h=head,
                d_k=feat_channel // head,
                d_v=feat_channel // head,
                dropout=0.1
            ) for _ in range(depth)
        )

        if spatial_mode == 'rel':
            # NOTE(review): forward() below does not use these three modules;
            # possibly kept for checkpoint compatibility — confirm.
            self.spatial_self_attn = nn.ModuleList(fixed_mha() for _ in range(depth))
            self.spatial_cross_attn = nn.ModuleList(fixed_mha() for _ in range(depth))
            self.spatial_fc = nn.Linear(2 * feat_channel, feat_channel)

        if spatial_mode == 'dist_score':
            # Assumes sentence and proposal features are feat_channel wide;
            # needs changing for other model sizes.
            self.score_fc = nn.Linear(self._SCORE_MODE_WIDTH[score_mode] * feat_channel, 1)

        if spatial_mode == 'abs':
            # NOTE(review): not used by forward() below.
            self.abs_fc = nn.Linear(2 * feat_channel, feat_channel)

        if spatial_way in ['diff', 'tanh']:
            self.score_fc = nn.Linear(feat_channel, 1)

    @classmethod
    def _check_config(cls, head, spatial_mode, spatial_way, attn_way, score_mode):
        """Raise ValueError for option combinations this module cannot run with."""
        if spatial_mode == 'dist_score':
            if spatial_way != 'dist' or attn_way != 'balanced':
                raise ValueError(
                    "Please recheck the spatial module setting! spatial_mode='dist_score' requires "
                    f"spatial_way='dist' and attn_way='balanced', got {spatial_way!r} and {attn_way!r}."
                )
            if score_mode not in cls._SCORE_MODE_WIDTH:
                raise ValueError(
                    "Please recheck the spatial module setting! Unknown score_mode "
                    f"{score_mode!r}; expected one of {sorted(cls._SCORE_MODE_WIDTH)}."
                )
        if spatial_mode == 'abs' and spatial_way not in ['center_nerf']:
            raise ValueError(
                "Spatial way inconsistent with mode! spatial_mode='abs' requires "
                f"spatial_way='center_nerf', got {spatial_way!r}."
            )
        if spatial_way == 'dist' and (head < 4 or head % 4 != 0):
            # forward() spreads the 4 spatial channels over heads with
            # repeat_interleave(head // 4), which only yields `head` channels
            # when head is a positive multiple of 4.
            raise ValueError(
                f"spatial_way='dist' needs head to be a positive multiple of 4, got head={head}."
            )

    @staticmethod
    def _repeat_per_chunk(tensor, chunk_size):
        """(B, ...) -> (B * chunk_size, ...), each batch entry repeated chunk_size times consecutively."""
        return tensor.unsqueeze(dim=1).expand(-1, chunk_size, *tensor.shape[1:]).flatten(start_dim=0, end_dim=1)

    def forward(self, data_dict, output_dict):
        """Compute ``output_dict['pred_aabb_scores']`` in place; returns None.

        Reads ``data_dict['ann_id']`` (first two dims: batch size B, chunk
        size), ``data_dict['lang_attention_mask']`` and, from ``output_dict``,
        ``pred_aabb_min_max_bounds``, ``proposal_masks_dense``,
        ``aabb_features``, ``word_features`` and (for 'dist_score')
        ``sentence_features``. Writes ``aabb_features_inter`` and
        ``pred_aabb_scores``.
        """
        batch_size, chunk_size = data_dict["ann_id"].shape[0:2]
        num_proposals = output_dict["pred_aabb_min_max_bounds"].shape[1]
        score = None

        # word features
        # NOTE(review): no padding mask is passed here, so padded word tokens
        # take part in this self-attention — confirm this is intended.
        word_features = output_dict["word_features"]
        word_features = self.word_self_attn(word_features, word_features, word_features)

        # attention weight
        if self.spatial_way == 'dist':
            attention_weights = self._calculate_spatial_weight_no_mask(output_dict)  # (B, 4, N, N)
            attention_weights = self._repeat_per_chunk(attention_weights, chunk_size)  # (B*chunk, 4, N, N)
            # spread the 4 spatial channels over the heads -> (B*chunk, head, N, N)
            attention_weights = torch.repeat_interleave(attention_weights, repeats=self.head // 4, dim=1)
        else:
            attention_weights = None

        # attention mask: pair (i, j) is attendable only if both proposals are valid
        valid = output_dict["proposal_masks_dense"].unsqueeze(-1)  # (B, N, 1)
        attention_mask = (valid.transpose(1, 2) * valid).unsqueeze(dim=1).expand(-1, self.head, -1, -1)
        # Replicate per chunk entry so the batch dim matches attention_weights
        # and the word features (previously left at B, breaking chunk_size > 1).
        attention_mask = self._repeat_per_chunk(attention_mask, chunk_size)  # (B*chunk, head, N, N)

        aabb_features = self.features_concat(output_dict['aabb_features'].permute(0, 2, 1)).permute(0, 2, 1)
        output_dict["aabb_features_inter"] = aabb_features
        aabb_features = self._repeat_per_chunk(aabb_features, chunk_size)  # (B*chunk, N, C)

        if self.spatial_mode == 'dist_score':
            # assumes sentence_features is (B*chunk, feat_channel) — TODO confirm
            sentence_feature = output_dict['sentence_features']
            sentence_feature = sentence_feature.unsqueeze(dim=1).unsqueeze(dim=1).expand(-1, num_proposals, num_proposals, -1)
            pair_inputs = [sentence_feature]
            if self.score_mode in ('si', 'sij'):
                # entry [b, i, j] holds the feature of proposal i
                pair_inputs.append(aabb_features.unsqueeze(dim=2).expand(-1, -1, num_proposals, -1))
            if self.score_mode in ('sj', 'sij'):
                # entry [b, i, j] holds the feature of proposal j
                pair_inputs.append(aabb_features.unsqueeze(dim=1).expand(-1, num_proposals, -1, -1))
            concat_feature = torch.cat(pair_inputs, dim=-1)

            score = torch.sigmoid(self.score_fc(concat_feature)).squeeze(-1)  # (B*chunk, N, N)
            score = score.unsqueeze(dim=1).expand(-1, self.head, -1, -1)  # (B*chunk, head, N, N)

        lang_mask = data_dict["lang_attention_mask"]
        for i in range(self.depth + 1):
            aabb_features = self.self_attn[i](
                aabb_features, aabb_features, aabb_features, attention_mask=attention_mask,
                attention_weights=attention_weights, way=self.attn_way, score=score
            )
            aabb_features = self.cross_attn[i](
                aabb_features, word_features, word_features, attention_mask=lang_mask
            )

        # match
        aabb_features = aabb_features.permute(0, 2, 1).contiguous()
        output_dict["pred_aabb_scores"] = self.match(aabb_features).flatten(start_dim=0, end_dim=1)

    def _calculate_spatial_weight_no_mask(self, output_dict):
        """Pairwise spatial attention bias between proposals, without validity masking.

        Reference: https://github.com/zlccccc/3DVG-Transformer

        Returns a detached (B, 4, N, N) tensor with channels
        [inverse distance normalised over dim 2, -distance, 0, 0].
        """
        # The result is detached anyway, so don't build an autograd graph.
        with torch.no_grad():
            objects_center = output_dict["pred_aabb_min_max_bounds"].mean(dim=2)  # (B, N, 3)
            # broadcasting: [b, i, j] = center_j - center_i, no repeated copies
            diff = objects_center.unsqueeze(dim=1) - objects_center.unsqueeze(dim=2)  # (B, N, N, 3)
            dist = torch.sqrt(diff.pow(2).sum(dim=-1)).unsqueeze(dim=1)  # (B, 1, N, N)
            dist_weights = 1 / (dist + 1e-2)

            dist_weights += torch.finfo(torch.float32).eps  # prevent zeros
            norm = dist_weights.sum(dim=2, keepdim=True)  # (B, 1, 1, N)
            dist_weights = dist_weights / norm
            zeros = torch.zeros_like(dist_weights)  # (B, 1, N, N)
            return torch.cat([dist_weights, -dist, zeros, zeros], dim=1).detach()  # (B, 4, N, N)
    