import torch
import numpy as np
import torch.nn.functional as F

from pytorch3d.ops import estimate_pointcloud_normals
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import knn_points, knn_gather

from lidarnerf.dataset.base_dataset import custom_meshgrid
from lidarnerf.convert import pano_to_lidar
def chamfer_distance_low_capacity(keypoints1, keypoints2):
    """Symmetric chamfer distance computed with pytorch3d KNN (no dense distance matrix).

    Args:
        keypoints1: point cloud of shape (B, M, 3).
        keypoints2: point cloud of shape (B, N, 3); N may differ from M.

    Returns:
        A tuple ``(loss, idx1, idx2)``:
            loss: mean nearest-neighbour distance keypoints1 -> keypoints2 plus the
                mean in the opposite direction, using the distances reported by
                ``knn_points`` (squared Euclidean distances per the pytorch3d docs).
            idx1: (B, M, 1) indices into keypoints2 of each keypoints1 point's neighbour.
            idx2: (B, N, 1) indices into keypoints1 of each keypoints2 point's neighbour.
    """
    forward_knn = knn_points(keypoints1, keypoints2, K=1, return_nn=False)
    backward_knn = knn_points(keypoints2, keypoints1, K=1, return_nn=False)
    loss = forward_knn.dists.mean() + backward_knn.dists.mean()
    return loss, forward_knn.idx, backward_knn.idx
def chamfer_based_norm_loss_low_capacity(keypoints1, keypoints2, idx1, idx2, neighborhood_size=30):
    """Normal-consistency loss between matched points of two clouds (KNN-based variant).

    Normals are estimated for both clouds with pytorch3d's
    ``estimate_pointcloud_normals``. Each point's normal is then compared with the normal
    of its nearest neighbour in the other cloud, as given by ``idx1`` / ``idx2``.
    These indices are typically the ones returned by ``chamfer_distance_low_capacity``.

    Args:
        keypoints1: xyz points, shape (B, M, 3).
        keypoints2: xyz points, shape (B, N, 3).
        idx1: (B, M, 1) indices into keypoints2 (nearest neighbour of each keypoints1 point).
        idx2: (B, N, 1) indices into keypoints1 (nearest neighbour of each keypoints2 point).
        neighborhood_size: neighbourhood used for normal estimation (default 30, as before).

    Returns:
        Scalar tensor: mean squared normal difference in both directions, summed.
    """
    normals1 = estimate_pointcloud_normals(keypoints1, neighborhood_size=neighborhood_size)
    normals2 = estimate_pointcloud_normals(keypoints2, neighborhood_size=neighborhood_size)

    # knn_gather returns (B, P, K, 3). Drop only the K axis (K == 1 here). A bare
    # squeeze() would also remove the batch axis when B == 1, or the point axis when P == 1.
    nearest_normals1 = knn_gather(normals2, idx1).squeeze(2)
    nearest_normals2 = knn_gather(normals1, idx2).squeeze(2)

    # NOTE(review): PCA-style normal estimates can have an arbitrary sign, and this
    # difference penalizes parallel-but-flipped normals. Consider 1 - |cos| if that matters.
    n1 = torch.norm(normals1 - nearest_normals1, dim=-1)
    n2 = torch.norm(normals2 - nearest_normals2, dim=-1)
    return (n1 ** 2).mean() + (n2 ** 2).mean()

def chamfer_distance(keypoints1, keypoints2):
    """Symmetric (non-probabilistic) chamfer distance between two batched point clouds.

    NOTE: this definition shadows ``pytorch3d.loss.chamfer_distance`` imported at the top
    of the module. The return signature differs from pytorch3d's.

    Args:
        keypoints1: tensor of shape (B, M, D) (D == 3 for xyz points).
        keypoints2: tensor of shape (B, N, D).

    Returns:
        A tuple ``(loss, idx_forward, idx_backward)``:
            loss: mean squared nearest-neighbour distance keypoints1 -> keypoints2 plus
                the same quantity keypoints2 -> keypoints1.
            idx_forward: (B, M) indices into keypoints2 of each keypoints1 point's neighbour.
            idx_backward: (B, N) indices into keypoints1 of each keypoints2 point's neighbour.
    """
    # Exact pairwise Euclidean distances, shape (B, M, N). The non-mm mode computes
    # sqrt(sum((a - b)^2)) directly, matching the previous explicit torch.norm, without
    # materializing the (B, 3, M, N) difference tensor.
    dist = torch.cdist(
        keypoints1, keypoints2, p=2.0, compute_mode="donot_use_mm_for_euclid_dist"
    )

    min_dist_forward, idx_forward = torch.min(dist, dim=2)    # over keypoints2 -> (B, M)
    min_dist_backward, idx_backward = torch.min(dist, dim=1)  # over keypoints1 -> (B, N)

    forward_loss = (min_dist_forward ** 2).mean()
    backward_loss = (min_dist_backward ** 2).mean()
    loss = forward_loss + backward_loss
    return loss, idx_forward, idx_backward

def chamfer_based_norm_loss(keypoints1, keypoints2, idx_forward, idx_backward, neighborhood_size=30):
    """Normal-consistency loss between chamfer-matched points of two clouds.

    The inputs are xyz coordinates. Normals are estimated here with pytorch3d's
    ``estimate_pointcloud_normals``. Each point's normal is compared (unsquared L2
    norm) with the normal of its matched point in the other cloud.

    Args:
        keypoints1: xyz points, shape (B, M, 3).
        keypoints2: xyz points, shape (B, N, 3).
        idx_forward: (B, M) indices into keypoints2, e.g. from ``chamfer_distance``.
            (B, M, 1) KNN-style indices are accepted as well.
        idx_backward: (B, N) indices into keypoints1; (B, N, 1) is accepted as well.
        neighborhood_size: neighbourhood used for normal estimation (default 30, as before).

    Returns:
        Scalar tensor: mean forward normal difference plus mean backward normal difference.
    """

    def _gather_matched(values, idx):
        # values[b, idx[b, i], :] for every batch b and query i -> (B, Q, C).
        if idx.dim() == 3:
            idx = idx.squeeze(-1)
        return torch.gather(values, 1, idx.unsqueeze(-1).expand(-1, -1, values.shape[-1]))

    normals1 = estimate_pointcloud_normals(keypoints1, neighborhood_size=neighborhood_size)
    normals2 = estimate_pointcloud_normals(keypoints2, neighborhood_size=neighborhood_size)

    # Gathering only the matched normals gives the same values as the previous dense
    # (B, M, N) normal-distance matrix indexed at the matches, at O(M + N) memory.
    matched_for_1 = _gather_matched(normals2, idx_forward)   # (B, M, 3)
    matched_for_2 = _gather_matched(normals1, idx_backward)  # (B, N, 3)

    forward_loss = torch.norm(normals1 - matched_for_1, dim=-1).mean()
    backward_loss = torch.norm(normals2 - matched_for_2, dim=-1).mean()
    loss = forward_loss + backward_loss
    return loss

def result_process(self, data, outputs_lidar):
    """Compute the masked LiDAR loss and build ground-truth / predicted point clouds.

    Module-level function that takes the trainer explicitly as ``self``. It reads
    ``self.opt`` (``alpha_d``, ``alpha_r``, ``alpha_i``, ``scale``, ``device``) and the
    ``self.criterion`` entries ``"depth"``, ``"raydrop"`` and ``"intensity"``.

    Args:
        self: trainer-like object providing ``opt`` and ``criterion``.
        data: batch dict; ``"index"`` and ``"image"`` are read.
        outputs_lidar: model output dict; ``"image_lidar_sample_rays"``, ``"intensity"``
            and ``"depth_lidar"`` are read.

    Locals built here: ``lidar_loss``, ``pcd1_on_gpu`` (cloud from the GT depth of
    ``data["image"]``) and ``pcd2_on_gpu`` (cloud from the predicted depth).
    NOTE(review): no value is returned in the visible code; confirm the function is complete.
    """
    # Assumed (B, N, 3), e.g. (1, 32*1080, 3); the [:, :, k] indexing requires rank 3.
    image_lidar_sample_rays = outputs_lidar["image_lidar_sample_rays"]
    # Channel 0 acts as a mask: it multiplies every GT/pred quantity below.
    gt_raydrop = image_lidar_sample_rays[:, :, 0]
    gt_intensity = image_lidar_sample_rays[:, :, 1] * gt_raydrop
    gt_depth = image_lidar_sample_rays[:, :, 2] * gt_raydrop
    pred_raydrop = outputs_lidar["intensity"][:, :, 0]
    pred_intensity = outputs_lidar["intensity"][:, :, 1] * gt_raydrop
    pred_depth = outputs_lidar["depth_lidar"] * gt_raydrop
    lidar_loss = (
        self.opt.alpha_d * self.criterion["depth"](pred_depth, gt_depth)
        + self.opt.alpha_r * self.criterion["raydrop"](pred_raydrop, gt_raydrop)
        + self.opt.alpha_i * self.criterion["intensity"](pred_intensity, gt_intensity)
        # TODO: optionally add 0.01 * outputs_lidar["loss_dist"]
    )

    idx = data["index"]  # NOTE(review): unused in the visible code
    rangemap = data["image"]  # presumably (1, 32, 1080, 3) — confirm against the dataset
    # First batch item, channel 2 -> a 2-D (H, W) numpy array for a 4-D rangemap.
    depth = rangemap[0, :, :, 2]
    depth = depth.cpu().numpy()
    # NOTE(review): [10, 40] is passed to pano_to_lidar as its intrinsics argument,
    # presumably the LiDAR vertical FOV — confirm against lidarnerf.convert.
    pcd = pano_to_lidar(depth, [10, 40]) / self.opt.scale  # (P, 3) per original comment
    pcd_on_cpu = torch.FloatTensor(pcd[None, ...])  # (1, P, 3)
    device = torch.device(self.opt.device)
    pcd1_on_gpu = pcd_on_cpu.to(device)

    # Bug fix: the previous code called `depth.cpu().numpy()` here. `depth` is already a
    # numpy array (AttributeError), and the prediction must be converted, not the GT.
    # `.detach()` is needed because `.numpy()` refuses tensors that require grad, which
    # a model output usually does.
    # NOTE(review): detaching means pcd2_on_gpu carries no gradient back to the model.
    # The view below assumes a single batch item (B == 1).
    pred_depth = pred_depth.view(depth.shape)
    pred_depth = pred_depth.detach().cpu().numpy()
    pcd2 = pano_to_lidar(pred_depth, [10, 40]) / self.opt.scale
    pcd2_on_cpu = torch.FloatTensor(pcd2[None, ...])  # (1, P, 3)
    pcd2_on_gpu = pcd2_on_cpu.to(device)
