import torch
import numpy as np


def get_batch_aabb_pair_ious(batch_boxes_1_bound, batch_boxes_2_bound):
    """Element-wise IoU between two equally sized batches of 3D axis-aligned boxes.

    Pair ``i`` compares ``batch_boxes_1_bound[i]`` with ``batch_boxes_2_bound[i]``.
    Each input is indexed as ``[:, 0]`` for the min corner and ``[:, 1]`` for
    the max corner, each holding (x, y, z) — i.e. a (B, 2, 3) layout.

    Returns:
        1-D tensor with one IoU per pair. The float32 machine epsilon is added
        to the union so that pairs with zero union volume yield 0 rather than
        dividing by zero.
    """
    mins_a, maxs_a = batch_boxes_1_bound[:, 0], batch_boxes_1_bound[:, 1]
    mins_b, maxs_b = batch_boxes_2_bound[:, 0], batch_boxes_2_bound[:, 1]

    # Overlap extent along each axis, clipped at zero for disjoint boxes.
    extent = torch.minimum(maxs_a, maxs_b) - torch.maximum(mins_a, mins_b)
    extent = torch.maximum(extent, torch.zeros_like(extent))

    def _volume(dims):
        # Multiply x, y, z extents in that fixed order.
        return dims[:, 0] * dims[:, 1] * dims[:, 2]

    overlap = _volume(extent)
    union = _volume(maxs_a - mins_a) + _volume(maxs_b - mins_b) - overlap
    return (overlap / (union + torch.finfo(torch.float32).eps)).flatten()


def get_batch_aabb_pair_ious_np(box1, box2):
    """Return the IoU of a single pair of 3D axis-aligned boxes (NumPy).

    Despite the name, this handles exactly one pair, not a batch. Each box must
    unpack into ``(min_corner, max_corner)``; the corners are expected to be
    NumPy arrays (e.g. shape (3,)) so that they support element-wise
    subtraction.

    Returns:
        ``intersection / union``, or ``0`` when the union volume is exactly 0.
    """
    (lo_a, hi_a), (lo_b, hi_b) = box1, box2

    # Per-axis overlap, clipped at zero when the boxes do not overlap.
    overlap_extent = np.maximum(np.minimum(hi_a, hi_b) - np.maximum(lo_a, lo_b), 0)
    inter_vol = np.prod(overlap_extent)
    union_vol = np.prod(hi_a - lo_a) + np.prod(hi_b - lo_b) - inter_vol

    if union_vol == 0:
        return 0
    return inter_vol / union_vol


def _aabb_ious_one_to_many_np(box, others):
    """IoU of one (2, 3) AABB against each of K AABBs in a (K, 2, 3) array; returns (K,).

    Vectorized counterpart of ``get_batch_aabb_pair_ious_np``: the IoU is 0
    wherever the union volume is exactly 0.
    """
    inter_min = np.maximum(box[0], others[:, 0])
    inter_max = np.minimum(box[1], others[:, 1])
    inter_dim = np.maximum(inter_max - inter_min, 0)
    intersection = np.prod(inter_dim, axis=-1)
    volume1 = np.prod(box[1] - box[0])
    volume2 = np.prod(others[:, 1] - others[:, 0], axis=-1)
    union = volume1 + volume2 - intersection
    # Zero-union entries would warn on division; they are replaced with 0 below.
    with np.errstate(divide="ignore", invalid="ignore"):
        ious = intersection / union
    return np.where(union != 0, ious, 0)


def nms(scores, boxes, iou_threshold):
    """Greedy, score-ordered non-maximum suppression for 3D AABBs (NumPy).

    Args:
        scores: array-like of shape (C,), one confidence per box.
        boxes: array of shape (C, 2, 3) holding [min_corner, max_corner] per
            box, or ``None``.
        iou_threshold: a box is suppressed when its IoU with a higher-scoring
            surviving box is strictly greater than this value.

    Returns:
        ``boxes`` indexed by the surviving indices, ordered by descending score,
        or ``None`` when ``boxes`` is ``None``.
    """
    if boxes is None:
        return None

    corners = np.asarray(boxes)
    # Visit boxes from highest to lowest score.
    order = np.argsort(scores)[::-1]
    is_suppressed = np.zeros(corners.shape[0], dtype=bool)

    for rank, idx in enumerate(order):
        if is_suppressed[idx]:
            continue
        # Only lower-ranked boxes that are still alive need testing; re-checking
        # already-suppressed boxes cannot change the outcome.
        rest = order[rank + 1:]
        rest = rest[~is_suppressed[rest]]
        if rest.size == 0:
            break  # every remaining lower-ranked box is already suppressed
        ious = _aabb_ious_one_to_many_np(corners[idx], corners[rest])
        is_suppressed[rest[ious > iou_threshold]] = True

    # Surviving indices, still in descending-score order.
    keep = order[~is_suppressed[order]]
    return boxes[keep]


def iou_3d(boxes1, boxes2):
    """
    Compute the pairwise Intersection over Union (IoU) of two sets of 3D boxes.
    Boxes are in shape (N, 2, 3) where 2 represents [min_corner, max_corner].

    Args:
        boxes1 (Tensor): shape (N, 2, 3), N boxes
        boxes2 (Tensor): shape (M, 2, 3), M boxes

    Returns:
        Tensor: IoU matrix of shape (N, M); entry [i, j] is the IoU of
        boxes1[i] and boxes2[j]. Pairs whose union volume is exactly zero
        (e.g. two zero-volume boxes) get IoU 0 instead of NaN, matching
        ``get_batch_aabb_pair_ious_np``.
    """
    # (N, 1, 3) against (1, M, 3) broadcasts to (N, M, 3) without explicit expand.
    min1 = boxes1[:, 0, :].unsqueeze(1)
    max1 = boxes1[:, 1, :].unsqueeze(1)
    min2 = boxes2[:, 0, :].unsqueeze(0)
    max2 = boxes2[:, 1, :].unsqueeze(0)

    # Intersection extent per axis, clipped at zero for disjoint pairs.
    inter_dims = torch.clamp(torch.minimum(max1, max2) - torch.maximum(min1, min2), min=0)
    inter_vol = inter_dims.prod(2)  # (N, M)

    # Per-box volumes computed once (O(N + M)) and broadcast over the pairs.
    vol1 = (boxes1[:, 1, :] - boxes1[:, 0, :]).prod(1).unsqueeze(1)  # (N, 1)
    vol2 = (boxes2[:, 1, :] - boxes2[:, 0, :]).prod(1).unsqueeze(0)  # (1, M)
    union = vol1 + vol2 - inter_vol

    # Divide by a safe denominator so masked-out entries never become NaN/inf,
    # which would otherwise also poison gradients flowing through torch.where.
    nonzero_union = union != 0
    safe_union = torch.where(nonzero_union, union, torch.ones_like(union))
    iou = inter_vol / safe_union
    return torch.where(nonzero_union, iou, torch.zeros_like(iou))


def batched_nms_3d(boxes, scores, proposal_batch_id, iou_threshold):
    """
    Greedy 3D non-maximum suppression applied independently within each batch id.

    Args:
        boxes (Tensor): shape (N, 2, 3), [min_corner, max_corner] for each box
        scores (Tensor): shape (N,), confidence score for each box
        proposal_batch_id (Tensor): shape (N,), batch id of each box; a box can
            only be suppressed by boxes that share its id
        iou_threshold (float): a lower-scoring box survives only while its IoU
            with the current kept box is strictly below this value

    Returns:
        Tensor: 1-D long tensor of global indices into ``boxes`` that survive,
        grouped by batch id in ``unique()`` order and, within a group, ordered
        by descending score.
    """
    kept_global = []
    for group in proposal_batch_id.unique():
        in_group = proposal_batch_id == group
        group_boxes = boxes[in_group]
        # Local indices into group_boxes, best score first.
        candidates = scores[in_group].sort(descending=True).indices
        kept_local = []

        while len(candidates) > 0:
            best = candidates[0]
            kept_local.append(best.item())
            remaining = candidates[1:]
            if len(remaining) == 0:
                break
            overlaps = iou_3d(group_boxes[best].unsqueeze(0), group_boxes[remaining]).squeeze(0)
            # Drop every remaining box that overlaps the kept one too much.
            candidates = remaining[overlaps < iou_threshold]

        # Map group-local positions back to positions in the full input.
        group_to_global = torch.nonzero(in_group).flatten()
        kept_global.extend(group_to_global[kept_local].tolist())

    return torch.tensor(kept_global, device=boxes.device, dtype=torch.long)