import pdb, sys
sys.path.append('utils/point_cloud_query')
sys.path.append('utils/pointnet2')
from pointnet2_modules import PointnetFPModule,PointnetSAModule

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal



def get_activation(activation):
    """Build the activation module named by ``activation``.

    The lookup is case-insensitive. Supported names are ``relu``,
    ``leakyrelu``, ``sigmoid``, ``softplus``, ``gelu``, ``selu`` and ``mish``.
    ReLU, LeakyReLU, SELU and Mish are created with ``inplace=True``.

    Args:
        activation: Name of the activation function.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        Exception: If the name is not one of the supported activations.
    """
    # Factories are wrapped in lambdas so that only the requested module class
    # is looked up and instantiated.
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'leakyrelu': lambda: nn.LeakyReLU(inplace=True),
        'sigmoid': lambda: nn.Sigmoid(),
        'softplus': lambda: nn.Softplus(),
        'gelu': lambda: nn.GELU(),
        'selu': lambda: nn.SELU(inplace=True),
        'mish': lambda: nn.Mish(inplace=True),
    }
    build = factories.get(activation.lower())
    if build is None:
        raise Exception("Activation Function Error")
    return build()


def get_norm(norm, width, num_groups=None):
    """Build the normalization layer named by ``norm`` for ``width`` channels.

    Args:
        norm: One of ``'LN'`` (LayerNorm), ``'BN'`` (BatchNorm1d),
            ``'IN'`` (InstanceNorm1d) or ``'GN'`` (GroupNorm).
        width: Number of features/channels to normalize.
        num_groups: Number of groups, used only for ``'GN'``. When ``None``,
            the largest divisor of ``width`` that does not exceed 32 is used,
            so the group count always divides the channel count.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        Exception: If ``norm`` is not one of the supported names.
    """
    if norm == 'LN':
        return nn.LayerNorm(width)
    elif norm == 'BN':
        return nn.BatchNorm1d(width)
    elif norm == 'IN':
        return nn.InstanceNorm1d(width)
    elif norm == 'GN':
        # nn.GroupNorm takes (num_groups, num_channels); passing only the
        # width raises a TypeError, so a group count has to be supplied here.
        if num_groups is None:
            num_groups = next(
                g for g in range(min(32, width), 0, -1) if width % g == 0
            )
        return nn.GroupNorm(num_groups, width)
    else:
        raise Exception("Normalization Layer Error")

class NeuralPCI_Layer(torch.nn.Module):
    """A linear projection optionally followed by normalization and activation.

    The sub-modules are stored, in that order, in ``self.layer``
    (an ``nn.Sequential``).

    Args:
        dim_in: Input feature size of the linear layer.
        dim_out: Output feature size of the linear layer.
        norm: Normalization name passed to ``get_norm``; falsy to skip.
        act_fn: Activation name passed to ``get_activation``; falsy to skip.
    """

    def __init__(self, dim_in, dim_out, norm=None, act_fn=None):
        super().__init__()
        modules = [nn.Linear(dim_in, dim_out)]
        if norm:
            modules += [get_norm(norm, dim_out)]
        if act_fn:
            modules += [get_activation(act_fn)]
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        """Apply linear -> (norm) -> (activation) to ``x``."""
        return self.layer(x)


class NeuralPCI_Block(torch.nn.Module):
    """A stack of ``depth`` constant-width linear stages.

    Each stage is ``Linear(width, width)`` optionally followed by a
    normalization layer and an activation. All stages are flattened into a
    single ``nn.Sequential`` stored in ``self.mlp``.

    Args:
        depth: Number of stages.
        width: Feature size kept constant through the block.
        norm: Normalization name passed to ``get_norm``; falsy to skip.
        act_fn: Activation name passed to ``get_activation``; falsy to skip.
    """

    def __init__(self, depth, width, norm=None, act_fn=None):
        super().__init__()
        stages = []
        for _ in range(depth):
            stage = [nn.Linear(width, width)]
            if norm:
                stage.append(get_norm(norm, width))
            if act_fn:
                stage.append(get_activation(act_fn))
            stages.extend(stage)
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every stage in order."""
        return self.mlp(x)
    
class DeformationFieldNetwork(nn.Module):
    """MLP that maps per-Gaussian parameters plus a time value to a deformation.

    The network is ``Linear(input_dim, hidden_dim)`` followed by a
    ``NeuralPCI_Block`` (LayerNorm + LeakyReLU, ``depth`` stages) and a final
    ``Linear(hidden_dim, output_dim)``.

    Args:
        input_dim: Size of the concatenated [gaussian_params, time] vector.
        hidden_dim: Width of the hidden block.
        output_dim: Size of the deformation vector produced per Gaussian.
        depth: Number of stages in the hidden block.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, depth):
        super().__init__()
        in_proj = nn.Linear(input_dim, hidden_dim)
        trunk = NeuralPCI_Block(depth, hidden_dim, norm='LN', act_fn='leakyrelu')
        out_proj = nn.Linear(hidden_dim, output_dim)
        self.mlp = nn.Sequential(in_proj, trunk, out_proj)

    def forward(self, gaussian_params, time_pred):
        """Predict a deformation for every Gaussian at ``time_pred``.

        ``time_pred`` is tiled along dim 1 to the number of Gaussians
        (``gaussian_params.size(1)``) and appended to the parameters along the
        last dimension before being fed to the MLP.
        """
        num_gaussians = gaussian_params.size(1)
        time_per_gaussian = time_pred.repeat(1, num_gaussians, 1)
        mlp_input = torch.cat((gaussian_params, time_per_gaussian), dim=-1)
        return self.mlp(mlp_input)

class TemporalPredictor(nn.Module):
    """GRU followed by a linear head applied to the last time step.

    Args:
        input_size: Feature size of each element of the input sequence.
        hidden_size: Hidden state size of the GRU.
        output_size: Feature size of the prediction.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch-first sequence ``x`` to one prediction per batch item."""
        sequence_out = self.gru(x)[0]
        # Only the GRU output at the final time step feeds the linear head.
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

class GaussianPointCloudPrediction(nn.Module):
    """Deform a set of 3-D Gaussians to a target time.

    A deformation network predicts, for every Gaussian, an offset of the mean,
    an offset of the log-Cholesky covariance parameters and an offset of the
    feature vector.

    This class is an ``nn.Module`` so that ``log_diag`` and ``off_diag`` are
    registered parameters: they follow ``.to()`` / ``.cuda()`` of the owning
    model, are returned by ``parameters()`` and are stored in the
    ``state_dict``. Previously they were created with
    ``nn.Parameter(...).cuda()``, which yields plain non-leaf tensors that no
    optimizer or checkpoint ever sees and which fails on hosts without CUDA.

    NOTE: a model that owns an instance of this class now has additional
    ``state_dict`` entries (``log_diag``, ``off_diag`` and the deformation
    network under this module's prefix); load older checkpoints with
    ``strict=False``.

    Args:
        deformation_field_network: Callable ``(gaussian_params, time_pred)``
            returning a tensor whose last dimension holds
            ``[mean offset (3) | covariance channels (9) | feature offset]``.
        num_gaussians: Number of Gaussians handled by ``predict``.
    """

    def __init__(self, deformation_field_network, num_gaussians):
        super().__init__()
        self.deformation_field_network = deformation_field_network
        self.num_gaussians = num_gaussians
        # Log-Cholesky parameterization of a base covariance per Gaussian:
        # 3 log-diagonal entries and 3 strictly-lower-triangular entries.
        self.log_diag = nn.Parameter(torch.randn(num_gaussians, 3))
        self.off_diag = nn.Parameter(torch.randn(num_gaussians, 3 * (3 - 1) // 2))

    def forward(self, gaussians_mean, gaussians_cov, gaussians_feature, time_pred):
        """Alias of :meth:`predict` so the module can be called directly."""
        return self.predict(gaussians_mean, gaussians_cov, gaussians_feature, time_pred)

    def predict(self, gaussians_mean, gaussians_cov, gaussians_feature, time_pred):
        """Predict mean, covariance and feature of every Gaussian at ``time_pred``.

        Args:
            gaussians_mean: Gaussian centres; a leading batch dimension of size
                one is expected (it is concatenated with the unsqueezed
                features and the covariance flattened over its last two dims).
            gaussians_cov: Covariances whose last two dimensions are 3 x 3.
            gaussians_feature: Per-Gaussian features without the leading batch
                dimension (it is added here with ``unsqueeze(0)``).
            time_pred: Target time forwarded to the deformation network.

        Returns:
            ``(predicted_mean, predicted_cov, predicted_feature)`` where
            ``predicted_mean`` keeps the shape of ``gaussians_mean``,
            ``predicted_cov`` is ``(num_gaussians, 3, 3)`` and
            ``predicted_feature`` has the leading batch dimension squeezed.
        """
        device = gaussians_mean.device
        gaussians_feature = gaussians_feature.unsqueeze(0)

        gaussian_params = torch.cat(
            [gaussians_mean, gaussians_cov.flatten(start_dim=-2), gaussians_feature],
            dim=-1,
        )
        deformation = self.deformation_field_network(gaussian_params, time_pred)

        predicted_mean = gaussians_mean + deformation[..., :3]

        # Channels 3:6 and 6:9 offset the log-Cholesky parameters; channels
        # 9:12 of the deformation are not used for the covariance.
        # `.to(device)` is a no-op when the parameters already live on the
        # input's device and keeps stand-alone use on another device working.
        log_diag = self.log_diag.to(device) + deformation[..., 3:6].squeeze(0)
        off_diag = self.off_diag.to(device) + deformation[..., 6:9].squeeze(0)

        # Assemble the lower-triangular factor L; exp() keeps its diagonal
        # strictly positive.
        chol = torch.zeros(self.num_gaussians, 3, 3, device=device)
        chol[:, range(3), range(3)] = torch.exp(log_diag)
        tril_indices = torch.tril_indices(row=3, col=3, offset=-1)
        chol[:, tril_indices[0], tril_indices[1]] = off_diag

        predicted_cov = chol @ chol.transpose(-1, -2)

        predicted_feature = gaussians_feature + deformation[..., 12:]

        return predicted_mean, predicted_cov, predicted_feature.squeeze(0)

class SpatioTemporalSA(nn.Module):
    """Fuse point coordinates, point features and a time encoding with a GRU.

    Args:
        in_channels: Size of the per-point feature vector; also the GRU hidden
            size and the width of the residual MLP.
        time_dim: Size of the raw time input fed to the linear time encoder.
    """

    def __init__(self, in_channels, time_dim=1):
        super().__init__()
        time_channels = in_channels // 32
        self.time_dim = time_dim
        self.time_encoding = nn.Linear(time_dim, time_channels)
        self.mlp = NeuralPCI_Block(depth=2, width=in_channels, norm=None, act_fn='relu')
        self.gru = nn.GRU(
            input_size=in_channels + 3 + time_channels,
            hidden_size=in_channels,
            num_layers=1,
        )

    def forward(self, xyz, points, time):
        """Return ``(xyz, fused_features)``; ``xyz`` is passed through unchanged.

        ``xyz``, ``points`` and the encoded ``time`` are concatenated along the
        last dimension, given a singleton dim 1 and run through the GRU. The
        GRU is built without ``batch_first``, so dim 0 (the points) is the
        sequence axis. The GRU output then goes through a residual MLP.
        """
        # Unpacking doubles as a check that xyz is two-dimensional.
        _num_points, _ = xyz.shape

        time_feat = self.time_encoding(time)
        stacked = torch.cat([xyz, points, time_feat], dim=-1).unsqueeze(1)

        gru_out, _ = self.gru(stacked)
        hidden = gru_out.squeeze(1)
        fused = self.mlp(hidden) + hidden

        return xyz, fused

class LogCholeskyLayer(nn.Module):
    """Learnable 3 x 3 covariance matrices in log-Cholesky form.

    Every Gaussian owns three log-diagonal values and three strictly
    lower-triangular values of a Cholesky factor ``L``; the covariance is
    ``L @ L^T``.

    Args:
        num_gaussians: Number of covariance matrices to parameterize.
    """

    def __init__(self, num_gaussians):
        super().__init__()
        n_off_diag = 3 * (3 - 1) // 2
        self.log_diag = nn.Parameter(torch.randn(num_gaussians, 3))
        self.off_diag = nn.Parameter(torch.randn(num_gaussians, n_off_diag))
        self.num_gaussians = num_gaussians

    def forward(self):
        """Return the ``(num_gaussians, 3, 3)`` covariance matrices."""
        factor = torch.zeros(self.num_gaussians, 3, 3, device=self.log_diag.device)

        # Exponentiating the stored logs keeps the diagonal positive.
        diag_idx = list(range(3))
        factor[:, diag_idx, diag_idx] = self.log_diag.exp()

        lower_rows, lower_cols = torch.tril_indices(row=3, col=3, offset=-1)
        factor[:, lower_rows, lower_cols] = self.off_diag

        return torch.matmul(factor, factor.transpose(-1, -2))

class NeuralPCI(torch.nn.Module):
    """Coordinate MLP for point cloud interpolation with a Gaussian branch.

    Pipeline of ``forward``:
      1. Encode ``[points, time_current]`` (optionally through random Fourier
         features) with a sinusoidal positional encoding and an MLP.
      2. Soft-cluster the input points into ``args.n_gaussians`` Gaussians and
         deform them to ``time_pred`` with the deformation network.
      3. Fuse the max-pooled deformed Gaussian features with the per-point
         features, append the encoded ``time_pred`` and decode the output.

    Fields read from ``args``: ``dim_pc``, ``dim_time``, ``layer_width``,
    ``act_fn``, ``norm``, ``depth_encode``, ``depth_pred``, ``pe_mul``,
    ``use_rrf``, ``dim_rrf`` (only when ``use_rrf``), ``n_gaussians``,
    ``num_points`` and ``zero_init``.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args
        dim_pc = args.dim_pc
        dim_time = args.dim_time
        layer_width = args.layer_width
        act_fn = args.act_fn
        norm = args.norm
        depth_encode = args.depth_encode
        depth_pred = args.depth_pred
        pe_mul = args.pe_mul

        if args.use_rrf:
            dim_rrf = args.dim_rrf
            # Fixed random projection for the Fourier features. Registered as
            # a non-persistent buffer: it follows .to()/.cuda() of the module
            # but is NOT written to the state_dict, so the checkpoint layout
            # is unchanged.
            transform = 0.1 * torch.normal(0, 1, size=[dim_pc, dim_rrf])
            self.register_buffer('transform', transform, persistent=False)
        else:
            dim_rrf = dim_pc

        # input layer
        self.layer_input = NeuralPCI_Layer(dim_in=(dim_rrf + dim_time) * pe_mul,
                                           dim_out=layer_width,
                                           norm=norm,
                                           act_fn=act_fn)
        self.hidden_encode = NeuralPCI_Block(depth=depth_encode,
                                             width=layer_width,
                                             norm=norm,
                                             act_fn=act_fn)

        # Gaussian branch. Each Gaussian is described by 3 (mean) + 9 (cov)
        # + layer_width (feature) values, plus one time channel on the input.
        self.args_n_gaussians = args.n_gaussians
        layer_width_gs = layer_width
        self.num_points = args.num_points
        self.deformation_field_network = DeformationFieldNetwork(
            layer_width + 12 + 1, layer_width_gs + 12, layer_width_gs + 12, depth_encode)
        self.gaussian_point_cloud_prediction = GaussianPointCloudPrediction(
            self.deformation_field_network, args.n_gaussians)

        # insert interpolation time
        # forward() feeds [fused features (layer_width), point features
        # (layer_width), encoded time (dim_time * pe_mul)] into this layer, so
        # the input size must be 2 * layer_width + dim_time * pe_mul. The
        # previous value (layer_width + 3 + ...) only matched the unused
        # render_point_cloud path and made forward() fail with a shape error.
        self.layer_time = NeuralPCI_Layer(dim_in=2 * layer_width + dim_time * pe_mul,
                                          dim_out=layer_width,
                                          norm=norm,
                                          act_fn=act_fn)

        # hidden layers
        self.hidden_pred = NeuralPCI_Block(depth=depth_pred,
                                           width=layer_width,
                                           norm=norm,
                                           act_fn=act_fn)

        # output layer
        self.layer_output = NeuralPCI_Layer(dim_in=layer_width,
                                            dim_out=dim_pc,
                                            norm=norm,
                                            act_fn=None)

        # zero init for last layer
        if args.zero_init:
            for m in self.layer_output.layer:
                if isinstance(m, nn.Linear):
                    nn.init.zeros_(m.weight)
                    nn.init.zeros_(m.bias)

        # Projects [max-pooled Gaussian feature | point feature] back to
        # layer_width.
        self.fc = nn.Linear(layer_width * 2, layer_width)

    @staticmethod
    def _time_tensor(value, rows, device):
        """Return ``value`` as a detached float32 tensor tiled to ``rows`` rows."""
        time = torch.as_tensor(value, dtype=torch.float32, device=device)
        return time.detach().repeat(rows, 1)

    def posenc(self, x):
        """Sinusoidal positional encoding.

        Maps ``x`` of shape [N, D] to ``[x, sin(x), cos(x)]`` of shape
        [N, 3 * D] (concatenation along dim 1).
        """
        sinx = torch.sin(x)
        cosx = torch.cos(x)
        x = torch.cat((x, sinx, cosx), dim=1)
        return x

    def forward(self, pc_current, time_current, time_pred, train=True):
        """Predict the output for every input point at ``time_pred``.

        Args:
            pc_current: tensor, [N, 3], input point coordinates.
            time_current: float, time stamp of ``pc_current``.
            time_pred: float, time stamp to predict.
            train: Unused; kept for interface compatibility.

        Returns:
            tensor, [N, dim_pc].
        """
        # Every helper tensor is created on the device of the input points
        # instead of a hard-coded CUDA device.
        device = pc_current.device
        num_points = pc_current.shape[0]

        time_pred_gs = self._time_tensor(time_pred, 1, device)
        time_current = self._time_tensor(time_current, num_points, device)
        time_pred = self._time_tensor(time_pred, num_points, device)

        # Keep the raw coordinates: the Gaussian clustering below works in
        # 3-D, while the encoder may consume random Fourier features instead.
        xyz = pc_current
        if self.args.use_rrf:
            encoded = torch.matmul(2. * torch.pi * xyz, self.transform.to(device))
        else:
            encoded = xyz

        x = torch.cat((encoded, time_current), dim=1)
        x = self.posenc(x)
        x = self.layer_input(x)

        x = self.hidden_encode(x)

        means, covs, gaussians_feature = self.gaussian_point_cloud(xyz, x, self.args_n_gaussians)
        # Only the deformed features are consumed below; the predicted mean
        # and covariance are currently unused by this forward pass.
        predicted_mean, predicted_cov, predicted_feature = \
            self.gaussian_point_cloud_prediction.predict(means, covs, gaussians_feature, time_pred_gs)

        # Max over the Gaussians of [gaussian feature | point feature]. The
        # point half is identical for every Gaussian, so the max reduces to
        # the point feature itself; the Gaussian half is one vector shared by
        # all points. This avoids building an (M, N, 2C) tensor and uses the
        # actual number of points rather than args.num_points.
        gaussian_max = torch.max(predicted_feature, dim=0)[0]  # (C,)
        fused_feat = torch.cat((gaussian_max.unsqueeze(0).expand(num_points, -1), x), dim=-1)  # (N, 2C)
        pooled_feat_fc = self.fc(fused_feat)  # (N, C)

        time_pred = self.posenc(time_pred)
        x = torch.cat((pooled_feat_fc, x, time_pred), dim=1)

        x = self.layer_time(x)
        x = self.hidden_pred(x)

        x = self.layer_output(x)
        return x

    def gaussian_point_cloud(self, pc_current, features, num_gaussians, num_iters=10, epsilon=1e-5):
        """Soft-cluster the points into Gaussians.

        Centres are initialised from randomly chosen points and refined with
        ``num_iters`` soft k-means steps (softmax over negative squared
        distances). The assignment weights of the last step are used to
        compute the weighted covariance and the weighted mean feature of
        every cluster.

        Args:
            pc_current: tensor, [N, 3].
            features: tensor, [N, C].
            num_gaussians: Number of clusters M.
            num_iters: Number of refinement steps, must be >= 1.
            epsilon: Added to the weight sums and to the covariance diagonal
                for numerical stability.

        Returns:
            ``(gaussians_mean [1, M, 3], gaussians_cov [1, M, 3, 3],
            gaussians_feature [M, C])``.

        Raises:
            ValueError: If ``num_iters`` is smaller than 1 (the assignment
                weights would never be computed).
        """
        if num_iters < 1:
            raise ValueError("num_iters must be >= 1")

        N, _ = pc_current.shape

        cluster_centers = pc_current[torch.randperm(N)[:num_gaussians]]

        for _ in range(num_iters):
            distances = torch.cdist(pc_current, cluster_centers)
            probs = torch.softmax(-distances**2, dim=1)
            cluster_centers = torch.matmul(probs.t(), pc_current) / (probs.sum(dim=0).unsqueeze(1) + epsilon)
        gaussians_mean = cluster_centers

        # Weighted outer products of the centred points: (N, M, 3, 3).
        pc_centered = pc_current.unsqueeze(1) - gaussians_mean.unsqueeze(0)
        pc_centered_squared = torch.matmul(pc_centered.unsqueeze(-1), pc_centered.unsqueeze(-2))
        gaussians_cov = torch.sum(pc_centered_squared * probs.unsqueeze(-1).unsqueeze(-1), dim=0) / (probs.sum(dim=0).unsqueeze(-1).unsqueeze(-1) + epsilon)
        gaussians_cov += torch.eye(3, device=pc_current.device) * epsilon
        gaussians_feature = torch.matmul(probs.t(), features) / (probs.sum(dim=0).unsqueeze(1) + epsilon)  # (M, C)

        gaussians_mean = gaussians_mean.unsqueeze(0)  # (1, M, 3)
        gaussians_cov = gaussians_cov.unsqueeze(0)  # (1, M, 3, 3)

        return gaussians_mean, gaussians_cov, gaussians_feature

    def render_point_cloud(self, pc_current, predicted_mean, predicted_cov, predicted_feature, n_samples_per_gaussian):
        """Sample points from every Gaussian and attach its feature to them.

        Not called by ``forward`` at present.

        Args:
            pc_current: Unused.
            predicted_mean: tensor, [M, 3].
            predicted_cov: tensor, [M, 3, 3]; each matrix must admit a
                Cholesky factorization.
            predicted_feature: tensor, [M, C].
            n_samples_per_gaussian: Number of points drawn per Gaussian.

        Returns:
            ``(point_cloud [M * n, 3], features [M * n, C])``. Both are
            detached from the autograd graph and returned as fresh leaves
            with ``requires_grad=True``.
        """
        n_gaussians, feature_dim = predicted_feature.shape
        device = predicted_cov.device
        total_samples = n_samples_per_gaussian * n_gaussians

        point_cloud = torch.empty(size=(total_samples, 3), dtype=torch.float32, device=device)
        features = torch.empty(size=(total_samples, feature_dim), dtype=torch.float32, device=device)

        for i in range(n_gaussians):
            mean = predicted_mean[i].unsqueeze(1)
            cov = predicted_cov[i]
            feature = predicted_feature[i]

            # Reparameterized sampling: mean + L @ eps with cov = L @ L^T.
            L = torch.linalg.cholesky(cov)
            epsilon = torch.randn(3, n_samples_per_gaussian, device=device)
            samples = mean + L @ epsilon
            start_index = i * n_samples_per_gaussian
            end_index = start_index + n_samples_per_gaussian

            point_cloud[start_index:end_index] = samples.T
            features[start_index:end_index] = feature.unsqueeze(0).expand(n_samples_per_gaussian, -1)

        # NOTE(review): detaching cuts the gradient path back to the Gaussian
        # parameters - confirm this is intended before wiring this into
        # forward().
        point_cloud = point_cloud.detach().requires_grad_(True)
        features = features.detach().requires_grad_(True)

        return point_cloud, features