import argparse
import copy
import os
import random
import time
import json
from builtins import print
from pathlib import Path
import math
import roma
import pickle

import imageio
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

from lib import utils, temporalpoints, tineuvox
from lib.load_data import load_data
# from lib import tineuvox2 as tineuvox

from torch_efficient_distloss import flatten_eff_distloss
from skeletonizer import create_skeleton

import tensorboardX as tbx
import torchvision

from skimage.morphology import remove_small_holes
from cc3d import largest_k


def config_parser():
    '''Define command line arguments.

    Returns:
        argparse.ArgumentParser: parser with every option registered. Only
        ``--config`` is required; all flags default to ``False`` and
        ``--ablation_tag`` defaults to ``None``. The caller is responsible
        for calling ``parse_args``.
    '''
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config', required=True,
                        help='config file path')
    parser.add_argument("--seed", type=int, default=0,
                        help='Random seed')
    parser.add_argument("--ft_path", type=str, default='',
                        help='specific weights npy file to reload for coarse network')
    # testing options
    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_test", action='store_true')
    parser.add_argument("--test", action='store_true')
    parser.add_argument("--overwrite_cache", action='store_true')
    parser.add_argument("--use_cache", action='store_true')
    parser.add_argument("--render_train", action='store_true')
    parser.add_argument("--render_video", action='store_true')
    parser.add_argument("--load_test_val", action='store_true')
    parser.add_argument("--joint_placement", action='store_true')
    parser.add_argument("--visualise_weights", action='store_true')
    parser.add_argument("--visualise_canonical", action='store_true')
    parser.add_argument("--repose_pcd", action='store_true')
    parser.add_argument("--first_stage_only", action='store_true')
    parser.add_argument("--second_stage_only", action='store_true')
    parser.add_argument("--debug_bone_merging", action='store_true')
    parser.add_argument("--visualise_warp", action='store_true')
    parser.add_argument("--render_pcd_direct", action='store_true')
    parser.add_argument("--render_pcd", action='store_true')
    parser.add_argument("--render_video_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
    parser.add_argument("--eval_ssim", action='store_true')
    parser.add_argument("--eval_lpips_alex", action='store_true')
    parser.add_argument("--eval_lpips_vgg", action='store_true')
    parser.add_argument("--eval_psnr", action='store_true')
    parser.add_argument("--benchmark", action='store_true')
    parser.add_argument("--degree_threshold", type=float, default=0.)
    parser.add_argument("--ablation_tag", type=str)
    parser.add_argument("--skip_load_images", action='store_true')

    # logging/saving options
    parser.add_argument("--i_print",   type=int, default=1000,
                        help='frequency of console printout and metric logging')
    parser.add_argument("--i_save",   type=int, default=5000)
    parser.add_argument("--fre_test", type=int, default=500000,
                        help='frequency of test')
    parser.add_argument("--basedir_append_suffix", type=str, default='',)
    parser.add_argument("--step_to_half", type=int, default=100000,
                        help='The iteration when fp32 becomes fp16')
    parser.add_argument("--export_bbox_and_cams_only", type=str, default='',
                        help='export scene bbox and camera poses for debugging and 3d visualization')
    return parser

@torch.no_grad()
def render_viewpoints(model, render_poses, HW, Ks, ndc, render_kwargs,
                      gt_imgs=None, savedir=None, test_times=None, render_factor=0, eval_psnr=False,
                      eval_ssim=False, eval_lpips_alex=False, eval_lpips_vgg=False, benchmark=False,
                      inverse_y=False, flip_x=False, flip_y=False, batch_size = 4096 * 2, verbose=True,
                      render_pcd_direct=False, direct_eps=1e-1, render_flow=False, kinematic_warp=True):
    '''Render images for the given viewpoints; run evaluation if gt given.

    Args:
        model: renderer to call. ``temporalpoints.TemporalPoints`` instances
            receive their rays through ``render_kwargs``; any other model is
            called as ``model(rays_o, rays_d, viewdirs, times, **render_kwargs)``.
        render_poses, HW, Ks: per-view poses, (H, W) sizes and intrinsics.
            All three must have the same length.
        ndc, inverse_y, flip_x, flip_y: forwarded to
            ``tineuvox.get_rays_of_a_view``.
        render_kwargs: dict of rendering options. It is modified in place
            (the ``rays_o``, ``rays_d``, ``viewdirs`` and ``pixel_coords``
            entries are overwritten for every chunk).
        gt_imgs: optional ground-truth images indexed per view. Metrics are
            only computed when this is given and ``render_factor == 0``.
        savedir: optional directory; when set, ``img_XXX.png`` /
            ``weights_XXX.png`` files and (if any metric was computed)
            ``results.txt`` are written there.
        test_times: per-view time values. It is indexed for every view, so it
            must be provided whenever ``render_poses`` is non-empty.
        render_factor: when non-zero, ``HW`` and the first two rows of ``Ks``
            are floor-divided by it (on copies).
        eval_psnr, eval_ssim, eval_lpips_alex, eval_lpips_vgg: select metrics.
        benchmark: profile the first view, write the profiler tables to
            ``./logs`` and terminate the process via ``exit()``.
        batch_size: number of rays rendered per model call.
        verbose: show the progress bar.
        render_pcd_direct, direct_eps, kinematic_warp: forwarded to the
            ``TemporalPoints`` model; with ``render_pcd_direct`` the
            ``rgb_marched_direct`` output replaces ``rgb_marched``.
        render_flow: additionally collect the ``flow`` output (only requested
            from ``TemporalPoints`` models) and convert it with
            ``flow_to_image``.

    Returns:
        Tuple ``(rgbs, depths, weights, flows)`` of numpy arrays, one entry per
        rendered view (``weights`` / ``flows`` are empty when not produced).
    '''
    assert len(render_poses) == len(HW) and len(HW) == len(Ks)

    if render_factor!=0:
        HW = np.copy(HW)
        Ks = np.copy(Ks)
        HW = HW // render_factor
        # NOTE(review): floor division also truncates the scaled intrinsics
        # when Ks is a float array; true division may have been intended -
        # confirm before changing, kept as-is to preserve current renders.
        Ks[:, :2, :3] = Ks[:, :2, :3] // render_factor

    Ks = torch.tensor(Ks)
    rgbs = []
    depths = []
    weights = []
    flows = []
    psnrs = []
    ssims = []
    lpips_alex = []
    lpips_vgg = []

    for i, c2w in enumerate(tqdm(render_poses, disable = not verbose)):

        H, W = HW[i]
        K = Ks[i]
        rays_o, rays_d, viewdirs = tineuvox.get_rays_of_a_view(
                H, W, K, c2w, ndc, inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y)

        pixel_coords = torch.stack(torch.meshgrid(torch.arange(0, W), torch.arange(0, H)), dim=-1).reshape(-1, 2).to(torch.float32).to(rays_o.device)

        rays_o = rays_o.flatten(0,-2)
        rays_d = rays_d.flatten(0,-2)
        viewdirs = viewdirs.flatten(0,-2)
        # One time value per ray so it can be split in lockstep with the rays.
        time_one = test_times[i]*torch.ones_like(rays_o[:,0:1])

        if benchmark:
            # Prime: run a single chunk once before the profiler starts.
            for ro, rd, vd ,ts in zip(rays_o.split(batch_size, 0), rays_d.split(batch_size, 0), viewdirs.split(batch_size, 0), time_one.split(batch_size, 0)):
                render_kwargs['rays_o'] = ro
                render_kwargs['rays_d'] = rd
                render_kwargs['viewdirs'] = vd
                model(ts[0], render_image=True, render_depth=True, render_kwargs=render_kwargs, render_weights=True, benchmark=benchmark, direct_eps=direct_eps)
                break

            profiler_activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA ]
            prof = torch.profiler.profile(with_stack=False, profile_memory=False, activities=profiler_activities)
            torch.cuda.synchronize()
            prof.__enter__()
            t0 = time.time()

        if type(model) is not temporalpoints.TemporalPoints:
            keys = ['rgb_marched', 'depth']
            render_result_chunks = [
                {k: v for k, v in model(ro, rd, vd, ts, **render_kwargs).items() if k in keys}
                for ro, rd, vd, ts in zip(rays_o.split(batch_size, 0), rays_d.split(batch_size, 0), viewdirs.split(batch_size, 0),time_one.split(batch_size, 0))
            ]
        else:
            keys = ['rgb_marched', 'depth', 'weights']
            if render_flow: keys.append('flow')
            render_result_chunks = []

            # The time delta towards the previous view is the same for every
            # chunk of this view (zero for the first view).
            if render_flow:
                i_delta = max(i-1, 0)
                flow_t_delta = test_times[i_delta] - test_times[i]
            else:
                flow_t_delta = None

            for ro, rd, vd, ts, px in zip(rays_o.split(batch_size, 0), rays_d.split(batch_size, 0), viewdirs.split(batch_size, 0), time_one.split(batch_size, 0), pixel_coords.split(batch_size, 0)):
                render_kwargs['rays_o'] = ro
                render_kwargs['rays_d'] = rd
                render_kwargs['viewdirs'] = vd
                render_kwargs['pixel_coords'] = px
                # Only one camera is passed per call, so every ray maps to index 0.
                cam_per_ray = torch.zeros(len(ro))[:,None]

                out = model(ts[0], render_image=True, render_depth=True, render_kwargs=render_kwargs, render_weights=True,
                            benchmark=benchmark, render_pcd_direct=render_pcd_direct, direct_eps=direct_eps, flow_t_delta=flow_t_delta,
                            poses=c2w[None], Ks=Ks[i][None], cam_per_ray=cam_per_ray, kinematic_warp=kinematic_warp)
                if render_pcd_direct:
                    out['rgb_marched'] = out['rgb_marched_direct']
                chunk = {k: v for k, v in out.items() if k in keys}
                render_result_chunks.append(chunk)

        if benchmark:
            torch.cuda.synchronize()
            dt = time.time() - t0
            prof.__exit__(None, None, None)

            # Make sure the report directory exists before writing into it.
            Path('./logs').mkdir(parents=True, exist_ok=True)
            report = prof.key_averages(group_by_stack_n=10).table(sort_by='cuda_time_total')
            with Path('./logs/benchmark_cuda.txt').open('w') as fid:
                fid.write(report)
            report = prof.key_averages(group_by_stack_n=10).table(sort_by='cpu_time_total')
            with Path('./logs/benchmark_cpu.txt').open('w') as fid:
                fid.write(report)
            #prof.export_chrome_trace('./logs/chrome_trace.txt')
            print(report)
            print(f'[Benchmark] dt = {dt*1e3:.6f} ms')
            # Benchmarking only measures the first view, then stops the process.
            exit()

        render_result = {
            k: torch.cat([ret[k] for ret in render_result_chunks]).reshape(H,W,-1)
            for k in render_result_chunks[0].keys()
        }
        rgb = render_result['rgb_marched'].cpu().numpy()
        depth = render_result['depth'].cpu().numpy()
        if render_flow:
            flow = render_result['flow'].cpu().numpy()
            flow = flow_to_image(flow)

        # 'weights' is only collected for TemporalPoints models.
        if 'weights' in render_result:
            weights.append(render_result['weights'].cpu().numpy())

        rgbs.append(rgb)
        depths.append(depth)
        if render_flow:
            flows.append(flow)

        if gt_imgs is not None and render_factor == 0:
            if eval_psnr:
                p = -10. * np.log10(np.mean(np.square(rgb - gt_imgs[i])))
                psnrs.append(p)
            if eval_ssim:
                ssims.append(utils.rgb_ssim(rgb, gt_imgs[i], max_val=1))
            if eval_lpips_alex:
                lpips_alex.append(utils.rgb_lpips(rgb, gt_imgs[i], net_name = 'alex', device = c2w.device))
            if eval_lpips_vgg:
                lpips_vgg.append(utils.rgb_lpips(rgb, gt_imgs[i], net_name = 'vgg', device = c2w.device))

    # Report every metric that was actually computed (not only when PSNR was
    # requested); each list is non-empty only if its flag was set.
    if psnrs or ssims or lpips_alex or lpips_vgg:
        # create text file and write results into a single file
        if savedir is not None:
            with open(os.path.join(savedir, 'results.txt'), 'w') as f:
                if psnrs: f.write('psnr: ' + str(np.mean(psnrs)) + '\n')
                if ssims: f.write('ssim: ' + str(np.mean(ssims)) + '\n')
                if lpips_alex: f.write('lpips_alex: ' + str(np.mean(lpips_alex)) + '\n')
                if lpips_vgg: f.write('lpips_vgg: ' + str(np.mean(lpips_vgg)) + '\n')

        if psnrs: print('Testing psnr', np.mean(psnrs), '(avg)')
        if ssims: print('Testing ssim', np.mean(ssims), '(avg)')
        if lpips_vgg: print('Testing lpips (vgg)', np.mean(lpips_vgg), '(avg)')
        if lpips_alex: print('Testing lpips (alex)', np.mean(lpips_alex), '(avg)')

    if savedir is not None:
        print(f'Writing images to {savedir}')
        for i in trange(len(rgbs)):
            rgb8 = utils.to8b(rgbs[i])
            filename = os.path.join(savedir, 'img_{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

        for i in trange(len(weights)):
            rgb8 = utils.to8b(weights[i])
            filename = os.path.join(savedir, 'weights_{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.array(rgbs)
    depths = np.array(depths)
    weights = np.array(weights)
    flows = np.array(flows)

    return rgbs, depths, weights, flows

@torch.no_grad()
def render_repose(rot_params, render_poses, HW, Ks, ndc, model, render_kwargs,
                gt_imgs=None, savedir=None, render_factor=0, eval_psnr=False,
                eval_ssim=False, eval_lpips_alex=False, eval_lpips_vgg=False,
                inverse_y=False, flip_x=False, flip_y=False):
    '''Render re-posed images for the given viewpoints; run evaluation if gt given.

    Args:
        rot_params: per-view rotation parameters; ``rot_params[i]`` is passed
            to the model for view ``i``.
        render_poses, HW, Ks: per-view poses, (H, W) sizes and intrinsics.
            All three must have the same length.
        ndc, inverse_y, flip_x, flip_y: forwarded to
            ``tineuvox.get_rays_of_a_view``.
        model: must be exactly a ``temporalpoints.TemporalPoints`` instance.
        render_kwargs: dict of rendering options. It is modified in place
            (``rays_o``, ``rays_d`` and ``viewdirs`` are overwritten per chunk).
        gt_imgs: optional ground-truth images indexed per view. Metrics are
            only computed when this is given and ``render_factor == 0``.
        savedir: optional directory for ``img_XXX.png`` / ``weights_XXX.png``.
        render_factor: when non-zero, ``HW`` and the first two rows of ``Ks``
            are floor-divided by it (on copies).
        eval_psnr, eval_ssim, eval_lpips_alex, eval_lpips_vgg: select metrics.

    Returns:
        Tuple ``(rgbs, depths, weights)`` of numpy arrays, one entry per view.
    '''
    assert len(render_poses) == len(HW) and len(HW) == len(Ks)
    assert type(model) is temporalpoints.TemporalPoints

    if render_factor!=0:
        HW = np.copy(HW)
        Ks = np.copy(Ks)
        HW //= render_factor
        Ks[:, :2, :3] //= render_factor
    rgbs = []
    depths = []
    weights = []
    psnrs = []
    ssims = []
    lpips_alex = []
    lpips_vgg = []

    # Number of rays per model call; constant for all views.
    batch_size = 4096 * 2
    keys = ['rgb_marched', 'depth', 'weights']

    for i, c2w in enumerate(tqdm(render_poses)):

        H, W = HW[i]
        K = Ks[i]
        rays_o, rays_d, viewdirs = tineuvox.get_rays_of_a_view(
                H, W, K, c2w, ndc, inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y)

        rays_o = rays_o.flatten(0,-2)
        rays_d = rays_d.flatten(0,-2)
        viewdirs = viewdirs.flatten(0,-2)

        render_result_chunks = []

        for ro, rd, vd in zip(rays_o.split(batch_size, 0), rays_d.split(batch_size, 0), viewdirs.split(batch_size, 0)):
            render_kwargs['rays_o'] = ro
            render_kwargs['rays_d'] = rd
            render_kwargs['viewdirs'] = vd
            # No time is passed: the pose comes from rot_params[i] instead.
            out = model(None, render_image=True, render_depth=True, render_kwargs=render_kwargs, render_weights=True, rot_params=rot_params[i], calc_min_max=True)
            chunk = {k: v for k, v in out.items() if k in keys}
            render_result_chunks.append(chunk)

        render_result = {
            k: torch.cat([ret[k] for ret in render_result_chunks]).reshape(H,W,-1)
            for k in render_result_chunks[0].keys()
        }
        rgb = render_result['rgb_marched'].cpu().numpy()
        depth = render_result['depth'].cpu().numpy()
        # Weights are optional: keep them only if the model returned them.
        if 'weights' in render_result:
            weights.append(render_result['weights'].cpu().numpy())

        rgbs.append(rgb)
        depths.append(depth)

        if i==0:
            print('Testing', rgb.shape)

        if gt_imgs is not None and render_factor == 0:
            if eval_psnr:
                p = -10. * np.log10(np.mean(np.square(rgb - gt_imgs[i])))
                psnrs.append(p)
            if eval_ssim:
                ssims.append(utils.rgb_ssim(rgb, gt_imgs[i], max_val=1))
            if eval_lpips_alex:
                lpips_alex.append(utils.rgb_lpips(rgb, gt_imgs[i], net_name = 'alex', device = c2w.device))
            if eval_lpips_vgg:
                lpips_vgg.append(utils.rgb_lpips(rgb, gt_imgs[i], net_name = 'vgg', device = c2w.device))

    # Report each metric that was actually computed, independent of whether
    # PSNR was requested; a list is non-empty only if its flag was set.
    if psnrs: print('Testing psnr', np.mean(psnrs), '(avg)')
    if ssims: print('Testing ssim', np.mean(ssims), '(avg)')
    if lpips_vgg: print('Testing lpips (vgg)', np.mean(lpips_vgg), '(avg)')
    if lpips_alex: print('Testing lpips (alex)', np.mean(lpips_alex), '(avg)')

    if savedir is not None:
        print(f'Writing images to {savedir}')
        for i in trange(len(rgbs)):
            rgb8 = utils.to8b(rgbs[i])
            filename = os.path.join(savedir, 'img_{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

        for i in trange(len(weights)):
            rgb8 = utils.to8b(weights[i])
            filename = os.path.join(savedir, 'weights_{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.array(rgbs)
    depths = np.array(depths)
    weights = np.array(weights)

    return rgbs, depths, weights



def seed_everything(seed=None):
    '''Seed everything for better reproducibility.
    (some pytorch operations are non-deterministic, like the backprop of grid_samples)

    Args:
        seed: integer seed applied to torch, numpy and Python's ``random``.
            When ``None`` (the default), the value is read from the global
            ``args.seed``, which keeps existing zero-argument calls working.
    '''
    if seed is None:
        # Backward-compatible path: relies on a global `args` namespace.
        seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def load_everything(args, cfg, use_cache=False, overwrite=False):
    '''Load images / poses / camera settings / data split.

    Args:
        args: parsed command line namespace; ``skip_load_images``, ``test`` and
            ``load_test_val`` are read.
        cfg: configuration object; ``cfg.data.skip_images`` is set here and
            ``cfg.data.datadir`` locates the cache file ``cache.pth``.
        use_cache: read the pickled cache if it exists, and write it after a
            fresh load.
        overwrite: ignore an existing cache file and reload from scratch.

    Returns:
        dict restricted to the kept keys, with ``poses`` and ``times``
        converted to float32 tensors (or the cached dict, returned as stored).
    '''
    cfg.data.skip_images = bool(args.skip_load_images)

    cache_path = Path(cfg.data.datadir) / 'cache.pth'
    cache_usable = use_cache and not overwrite and cache_path.is_file()
    if cache_usable:
        # NOTE: pickle executes arbitrary code on load; only use cache files
        # that were produced by this script.
        with cache_path.open("rb") as fh:
            return pickle.load(fh)

    data_dict = load_data(cfg.data, cfg, args.test, args.load_test_val)

    # Drop every field that the rest of the pipeline does not use.
    kept_keys = {
            'hwf', 'HW', 'Ks', 'near', 'far',
            'i_train', 'i_val', 'i_test', 'irregular_shape',
            'poses', 'render_poses', 'images','times', 'render_times', 'img_to_cam', 'masks'}
    for unused_key in [key for key in data_dict if key not in kept_keys]:
        del data_dict[unused_key]

    # Build the data tensors. Images are left as returned by load_data:
    # converting them here could result in memory issues.
    for field in ('poses', 'times'):
        data_dict[field] = torch.tensor(data_dict[field], dtype=torch.float32)

    if use_cache:
        with cache_path.open('wb') as fh:
            pickle.dump(data_dict, fh)

    return data_dict


def compute_bbox_by_cam_frustrm(args, cfg, HW, Ks, poses, i_train, near, far, **kwargs):
    '''Compute the axis-aligned bounding box spanned by the training camera frustums.

    For every training image the near and far points along each ray are
    computed and the running per-axis minimum / maximum is updated.
    ``kwargs['img_to_cam']`` maps image indices to camera indices and is used
    to select the matching entries of ``Ks`` and ``poses``.

    Returns:
        Tuple ``(xyz_min, xyz_max)`` of 3-element tensors.
    '''
    print('compute_bbox_by_cam_frustrm: start')
    lower = torch.tensor([np.inf] * 3)
    upper = -lower

    cam_ids = kwargs['img_to_cam'][i_train]
    for (H, W), K, c2w in zip(HW[i_train], Ks[cam_ids], poses[cam_ids]):
        rays_o, rays_d, viewdirs = tineuvox.get_rays_of_a_view(
                H=H, W=W, K=K, c2w=c2w, ndc=cfg.data.ndc, flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y, inverse_y=cfg.data.inverse_y)
        # In NDC the raw ray directions are used, otherwise the view directions.
        directions = rays_d if cfg.data.ndc else viewdirs
        near_pts = rays_o + directions * near
        far_pts = rays_o + directions * far
        pts_nf = torch.stack([near_pts, far_pts])
        lower = torch.minimum(lower, pts_nf.amin(dim=(0, 1, 2)))
        upper = torch.maximum(upper, pts_nf.amax(dim=(0, 1, 2)))

    xyz_min, xyz_max = lower, upper
    print('compute_bbox_by_cam_frustrm: xyz_min', xyz_min)
    print('compute_bbox_by_cam_frustrm: xyz_max', xyz_max)
    print('compute_bbox_by_cam_frustrm: finish')
    return xyz_min, xyz_max


def compute_bbox_by_cam_frustrm_hyper(args, cfg,data_class):
    '''Compute the bounding box of the training frustums from a dataset object.

    For each index in ``data_class.i_train`` the ray origins and view
    directions are fetched through ``data_class.load_idx(i, not_dic=True)``;
    the points at ``data_class.near`` and ``data_class.far`` along every ray
    update the running per-axis minimum / maximum.

    Returns:
        Tuple ``(xyz_min, xyz_max)`` of 3-element tensors.
    '''
    print('compute_bbox_by_cam_frustrm: start')
    lower = torch.tensor([np.inf] * 3)
    upper = -lower

    for train_idx in data_class.i_train:
        loaded = data_class.load_idx(train_idx, not_dic=True)
        # Only the first (origins) and third (view directions) items are used.
        rays_o, viewdirs = loaded[0], loaded[2]
        near_pts = rays_o + viewdirs * data_class.near
        far_pts = rays_o + viewdirs * data_class.far
        pts_nf = torch.stack([near_pts, far_pts])
        lower = torch.minimum(lower, pts_nf.amin(dim=(0, 1, 2)))
        upper = torch.maximum(upper, pts_nf.amax(dim=(0, 1, 2)))

    xyz_min, xyz_max = lower, upper
    print('compute_bbox_by_cam_frustrm: xyz_min', xyz_min)
    print('compute_bbox_by_cam_frustrm: xyz_max', xyz_max)
    print('compute_bbox_by_cam_frustrm: finish')
    return xyz_min, xyz_max

def train_pcd(args, cfg, cfg_model, cfg_train, read_path, save_path, data_dict, tineuvox_model, canonical_t, tensorboard_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    writer = tbx.SummaryWriter(tensorboard_path)
    os.makedirs(tensorboard_path, exist_ok=True)
    os.chmod(tensorboard_path, 0o755)

    ## SET UP TRAINING RAYS ##

    HW, Ks, near, far, i_train, i_val, i_test, poses, render_poses, images, times, render_times, masks = [
        data_dict[k] for k in [
            'HW', 'Ks', 'near', 'far', 'i_train', 'i_val', 'i_test', 'poses', 
            'render_poses', 'images',
            'times','render_times','masks'
        ]
    ]
    # times = torch.tensor(times)
    times_i_train = times[i_train].to('cpu' if cfg.data.load2gpu_on_the_fly else device)

    # init rendering setup
    render_kwargs = {
        'near': near,
        'far': far,
        'bg': 1 if cfg.data.white_bkgd else 0,
        'stepsize': cfg_model.stepsize,
    }

    # init batch rays sampler
    rgb_tr, index_to_times, rays_o_tr, rays_d_tr, viewdirs_tr, pix_to_ray, masks_tr = tineuvox.get_training_rays_in_maskcache_sampling(
                    rgb_tr_ori=images,
                    masks_tr_ori=masks,
                    times=times_i_train,
                    train_poses=poses,
                    Ks=Ks,
                    HW=HW,
                    i_train = i_train,
                    ndc=cfg.data.ndc, 
                    model=tineuvox_model, render_kwargs=render_kwargs, img_to_cam = data_dict['img_to_cam'],  **render_kwargs)
    # index_to_times = index_to_times.to(device) # We query times_flatten, need cuda for acceleration

    ## SET UP POINT CLOUDS ##
    # Read pcds and get time

    # batch_size = 64
    # pcd_dataset = PCDDataset(os.path.join(folder_path, 'pcd_paths.json'))
    # pcd_dataloader = iter(DataLoader(pcd_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, generator=torch.Generator(device='cuda'), persistent_workers=True, multiprocessing_context='spawn', drop_last=True))
    with open(os.path.join(read_path, 'pcd_paths.json'), "r") as fp:
        time_path_dict =  json.load(fp)
    
    time_keys = list(time_path_dict.keys())
    time_keys = [key for key in time_keys if ('0.' in key) or ('1.' in key)]
    time_keys.sort() # WARNING: sorting based on string representation

    # print('num pcds:', len(time_keys))

    canonical_data = torch.load(os.path.join(read_path, 'pcds', f'{canonical_t}.tar'))
    canonical_pcd = canonical_data['pcd']
    canonical_feat = canonical_data['feat']
    canonical_raw_feat = canonical_data['raw_feat']
    canonical_alpha = canonical_data['alphas']
    canonical_rgbs = canonical_data['rgbs']

    ## SET UP MODEL ##
    last_ckpt_path = os.path.join(save_path, 'temporalpoints_last.tar')
    benchmark_weight_path = os.path.join(save_path, 'benchmark_weights')

    os.makedirs(benchmark_weight_path, exist_ok=True)
    skeleton_data = torch.load(os.path.join(read_path, 'pcds', f'skeleton_{canonical_t}.tar'))
    hierachy = None 
    skeleton_pcd = torch.tensor(skeleton_data['skeleton_pcd'])
    joints = torch.tensor(skeleton_data['joints'])
    bones = skeleton_data['bones']
    # joint_neighbours = skeleton_data['joint_neighbours']

    # init model
    model_kwargs = copy.deepcopy(cfg_model)
    xyz_min = torch.tensor(eval(time_path_dict['xyz_min'])) * model_kwargs['world_bound_scale']
    xyz_max = torch.tensor(eval(time_path_dict['xyz_max'])) * model_kwargs['world_bound_scale']
    voxel_size = eval(time_path_dict['voxel_size'])

    if cfg_train.use_global_view_dir:
        frozen_view_dir = viewdirs_tr.median(dim=0)[0]
    else:
        frozen_view_dir = None
    model = temporalpoints.TemporalPoints(
        canonical_pcd=canonical_pcd,
        canonical_feat=canonical_feat,
        canonical_alpha=canonical_alpha,
        canonical_rgbs=canonical_rgbs,
        skeleton_pcd=skeleton_pcd,
        joints=joints,
        hierachy=hierachy,
        bones=bones,
        xyz_min=xyz_min,
        xyz_max=xyz_max,
        voxel_size=voxel_size,
        tineuvox=tineuvox_model,
        embedding=cfg_train.embedding,
        frozen_view_dir=frozen_view_dir,
        over_parameterized_rot=cfg_train.over_parameterized_rot,
        avg_procrustes=cfg_train.avg_procrustes,
        # joint_neighbours=joint_neighbours,
        **model_kwargs)

    model = model.to(device)
    optimizer = utils.create_optimizer_or_freeze_model(model, cfg_train, global_step=0)

    if cfg_train.load_3DConv:
         model.load_3DConv(last_ckpt_path)
        # model.load_state_dict(torch.load(cfg_train.load_3DConv_path)['model_state_dict'], strict=False)
        # print('load 3DConv model from', cfg_train.load_3DConv_path)

    chamfer_lst = []
    arap_lst = []
    weight_tv_lst = []
    feat_consistency_lst = []
    mse_lst = []
    psnr_lst = []
    trans_reg_loss_lst = []
    joint_arap_lst = []
    joint_chamfer_lst = []
    weight_sparsity_lst = []
    gp_lst = []
    mask_loss_lst = []
    time_tv_lst = []
    chamfer2D_loss_lst = []
    teacher_loss_lst = []
    optical_flow_lst = []
    flow_diff_lst = []
    scene_flow_tv_lst = []
    we_entropy_lst = []

    ## TRAINING ##
    start = 0
    time0 = time.time()
    global_step = -1
    accum_iter = 1
    render_iter = 1 # should be in [1, accum_iter]
    merge_weights_steps = cfg_train.weight_merge_iter
    rho_c = cfg_train.rho_c


    save_iterations = np.unique(np.logspace(np.log10(1), np.log10(cfg_train.N_iters), cfg_train.N_iters // 100).round().astype(int))
    
    num_keep_iters = min(30, len(save_iterations))
    sub_iter_idsx = np.linspace(0, len(save_iterations) - 1, num_keep_iters).round().astype(int)
    save_iterations = save_iterations[sub_iter_idsx]

    # NOTE: To save memory for now, only the last iteration is saved
    save_iterations = save_iterations[-1:]

    print('save at:', save_iterations)

    ### Tensorboard prep ###
    tb_num_imgs = 5
    tb_factor = 2

    tb_mask = np.random.randint(0, len(images), tb_num_imgs)
    gt_images = images[tb_mask].permute(0,3,1,2)
    resize = torchvision.transforms.Resize([gt_images.shape[2] // tb_factor, gt_images.shape[2] // tb_factor])
    gt_images = resize(gt_images)

    cam_indx = 10
    cam_indices = torch.where(torch.tensor(data_dict['img_to_cam']) == cam_indx)[0]
    cam_indices = cam_indices[torch.linspace(0, len(cam_indices)-1, 40).round().long()].cpu().numpy()

    gt_images_vid = resize(images[cam_indices].to(device).permute(0,3,1,2))
    
    ###

    ### Optical flow prep ###
    # from RAFT.core.raft import RAFT
    # flow_t_delta = torch.sub(*times.unique().sort()[0][:2])
    # class NameSpace:    
    #     def __init__(self):
    #         self.model = 'RAFT/models/raft-things.pth'
    #         self.small = False
    #         self.alternate_corr = False
    #         self.mixed_precision = False

    #         self.keys = [attr for attr in dir(self) if not callable(getattr(self, attr)) and not attr.startswith("__")]
    #         self._current_index = 0 
            
    #     def __iter__(self):
    #         return self    

    #     def __next__(self):
    #         if self._current_index < len(self.keys):
    #             self._current_index += 1
    #             return self.keys[self._current_index-1]      
    #         raise StopIteration

    # raft_args = NameSpace()
    # raft = torch.nn.DataParallel(RAFT(raft_args))
    # raft.load_state_dict(torch.load(raft_args.model))

    # raft = raft.module
    # raft.to(device)
    # raft.eval()

    ###



    pre_train_iter = cfg_train.pre_train_iter
    pre_train_done = False if pre_train_iter > 1 else True

    after_train_iter = cfg_train.after_train_iter
    do_after_train = False

    conv_render_until = cfg_train.conv_render_until

    # renderer_only_range = cfg_train.renderer_only_range
    # assert (renderer_only_range is not None) ^ (not pre_train_iter == 0)

    if not pre_train_done: # Only fix joints for pretraining
        for _, param_group in enumerate(optimizer.param_groups):
            p_name = param_group['name']
            fixed = ['joints']
            if p_name in fixed:
                param_group['lr'] = 0

    # Assume that we always have at least 10 timesteps
    canonical_t_indx = torch.argmin(((torch.tensor(np.array(time_keys).astype(float)) - cfg.data.canonical_t)**2).sqrt()).long()
    canonical_t_indx = canonical_t_indx.item()
    def get_range(max_len, num=10):
        t_max = math.ceil(canonical_t_indx + num / 2)
        t_min = math.ceil(canonical_t_indx - num / 2)

        if num >= max_len:
            t_min = 0
            t_max = max_len
        elif t_max > max_len:
            overflow = t_max % max_len
            t_min -= overflow
            t_max = max_len
        elif t_min < 0:
            underflow = abs(t_min)
            t_max += underflow
            t_min = 0

        return t_max, t_min

    sampler = utils.InverseProportionalSampler(len(time_keys))

    for global_step in trange(1+start, 1+cfg_train.N_iters):
        optimizer.zero_grad(set_to_none = True)

        # Imitate batch, accumulate gradient (in the future use actual batches)
        for ac_i in range(accum_iter):
            
            ## SAMPLE TIME
            if not pre_train_done or do_after_train:
                t_max, t_min = get_range(len(time_keys), cfg_train.pre_train_t_num)
            else:
                offset = 1
                num = min(max((len(time_keys) / cfg_train.full_t_iter) * (global_step - pre_train_iter), 1), len(time_keys))
                # num =  len(time_keys)
                t_max, t_min = get_range(len(time_keys), num)

            rnd_i = torch.randint(t_min, t_max, [1])
            # rnd_i = sampler.sample(t_min, t_max)
            
            time_key = time_keys[rnd_i.item()]
            time_sel = torch.tensor([float(time_key)])

            # t_data = torch.load(time_path_dict[time_key])
            # t_pcd = t_data['pcd']

            ## SAMPLE RAYS ##
            render_image = ((ac_i + 1) % render_iter) == 0
            if render_image:
                # b = (time_sel == times_flatten).nonzero()[:,0]
                b_range = index_to_times[time_sel.item()]
                # b = torch.arange(b_range[0], b_range[1])

                # sel_i = b[torch.randint(b.shape[0], [cfg_train.N_rand])].squeeze(-1)

                # sel_i = (torch.tensor(np.random.choice(b_range[1], cfg_train.N_rand, replace=False)) + b_range[0]).long().to(rgb_tr.device) 
                sel_i = torch.randint(b_range[0], b_range[1], (cfg_train.N_rand,)).long().to(rgb_tr.device)
                # sel_r = torch.randint(rays_o_tr.shape[1], [cfg_train.N_rand])
                # sel_c = torch.randint(rays_o_tr.shape[2], [cfg_train.N_rand])

                img_i = torch.div(sel_i, (images.shape[1] * images.shape[1]), rounding_mode='floor').unsqueeze(-1)
                cam_per_ray = img_i % len(poses)
                # mask = (img_i[:,0], pix_x[:,0], pix_y[:,0])

                target = rgb_tr[sel_i]
                target_alpha_inv_last = 1 - masks_tr[sel_i]
                sel_ray = pix_to_ray[sel_i].long()
                rays_o = rays_o_tr[sel_ray]
                rays_d = rays_d_tr[sel_ray]
                viewdirs = viewdirs_tr[sel_ray]

                if cfg.data.load2gpu_on_the_fly:
                    target = target.to(device)
                    rays_o = rays_o.to(device)
                    rays_d = rays_d.to(device)
                    viewdirs = viewdirs.to(device)
                    time_sel = time_sel.to(device)
                    # t_pcd = t_pcd.to(device)
                    target_alpha_inv_last = target_alpha_inv_last.to(device)
                    # pixel_coords = pixel_coords.to(device)

                # render_kwargs['pixel_coords'] = pixel_coords
                render_kwargs['rays_o'] = rays_o
                render_kwargs['rays_d'] = rays_d
                render_kwargs['viewdirs'] = viewdirs
            else:
                render_kwargs['rays_o'] = None
                render_kwargs['rays_d'] = None
                render_kwargs['viewdirs'] = None

            # render_pcd_direct = torch.rand(1) < 0.1
            render_pcd_direct = False
            # r1 = 0.8
            # r2 = 1
            # direct_eps = (r1 - r2) * torch.rand(1) + r2
            render_conv = global_step <= conv_render_until
            direct_eps = 1e-1
            kinematic_warp = cfg_train.kinematic_warp
            res = model(time_sel, render_image, False, render_kwargs, render_pcd_direct=render_pcd_direct, direct_eps=direct_eps, 
                        kinematic_warp=kinematic_warp, poses=poses, Ks=torch.tensor(Ks), cam_per_ray=cam_per_ray, flow_t_delta=None, render_conv=render_conv)
            t_hat_pcd = res['t_hat_pcd']
            conv_t_hat_pcd = res['conv_t_hat_pcd']
            rgb_marched = res['rgb_marched']
            rgb_marched_direct = res['rgb_marched_direct']

            ## LOSSES ##
            loss = 0
            if cfg_train.weight_render > 0 and render_image:
                mse_loss = F.mse_loss(rgb_marched, target)
                mse_loss_direct = 0
                if rgb_marched_direct is not None and cfg_train.use_direct_loss and cfg_train.full_t_iter <= global_step:
                    mse_loss_direct = F.mse_loss(rgb_marched_direct, target)

                img_loss = mse_loss + mse_loss_direct
                
                psnr = utils.mse2psnr(mse_loss.clone().detach())
                mse_loss = img_loss * cfg_train.weight_render
                
                mse_lst.append(mse_loss.item())
                psnr_lst.append(psnr.item())
                loss += mse_loss

            if cfg_train.weight_time_tv > 0 and pre_train_done:
                rnd_i_n = min(rnd_i.item() + 1, len(time_keys) - 1)
                time_key_n = time_keys[rnd_i_n]
                time_sel_n = torch.tensor([float(time_key_n)])
                time_tv_loss = model.get_time_tv_loss(time_sel_n)
                time_tv_lst.append(time_tv_loss.item())
                loss += time_tv_loss

            if cfg_train.weight_scene_flow_tv > 0 and pre_train_done and render_conv:
                if res['grid'] is not None:
                    scene_flow_tv_loss = cfg_train.weight_scene_flow_tv * res['grid'].abs().mean()
                    scene_flow_tv_lst.append(scene_flow_tv_loss.item())
                    loss += scene_flow_tv_loss
            
            if cfg_train.weight_teacher > 0 and pre_train_done:
                if res['rgb_marched_direct'] is not None:
                    teacher_loss = cfg_train.weight_teacher * (res['rgb_marched'] - res['rgb_marched_direct']).pow(2).mean()
                    teacher_loss_lst.append(teacher_loss.item())
                    loss += teacher_loss
                
            if cfg_train.weight_chamfer > 0 and pre_train_done:
                chamfer_loss = model.get_chamfer_loss(t_hat_pcd, t_pcd, N=None, c=rho_c) 
                if render_conv:
                    chamfer_loss += model.get_chamfer_loss(conv_t_hat_pcd, t_pcd, N=None, c=rho_c)
                chamfer_loss = cfg_train.weight_chamfer * chamfer_loss
                chamfer_lst.append(chamfer_loss.item())
                loss += chamfer_loss

            if cfg_train.weight_arap > 0 and pre_train_done:
                arap_loss = cfg_train.weight_arap * model.get_arap_loss(t_hat_pcd)
                if render_conv:
                    arap_loss += cfg_train.weight_arap * model.get_arap_loss(conv_t_hat_pcd)
                arap_lst.append(arap_loss.item())
                loss += arap_loss

            if cfg_train.weight_collision_loss > 0 and pre_train_done:
                gp_loss = cfg_train.weight_collision_loss * model.collision_loss(t_hat_pcd)
                if render_conv:
                    gp_loss += cfg_train.weight_collision_loss * model.collision_loss(conv_t_hat_pcd)
                gp_lst.append(gp_loss.item())
                loss += gp_loss

            # if cfg_train.weight_mask_loss > 0:
            #     if res['alphainv_last'] is not None:
            #         pout = res['alphainv_last'].clamp(1e-6, 1-1e-6).unsqueeze(-1)
            #         mask_loss = cfg_train.weight_mask_loss * F.binary_cross_entropy(pout, target_alpha_inv_last)
            #         mask_loss_lst.append(mask_loss.item())
            #         loss += mask_loss

            if cfg_train.weight_tv > 0 and pre_train_done:
                weight_tv_loss = cfg_train.weight_tv * model.get_neighbour_weight_tv_loss()
                weight_tv_lst.append(weight_tv_loss.item())
                loss += weight_tv_loss
            
            if cfg_train.weight_we_entropy > 0 and pre_train_done:
                we_entropy_loss = cfg_train.weight_we_entropy * model.get_we_entropy_loss(t_min, t_max)
                we_entropy_lst.append(we_entropy_loss.item())
                loss += we_entropy_loss

            # if cfg_train.weight_feat_consistency > 0:
            #     feat_consistency_loss = cfg_train.weight_feat_consistency * model.feature_consistency_loss(
            #                                                             pred_pcd = t_hat_pcd, gt_pcd = t_pcd, 
            #                                                             pred_feat = canonical_feat, gt_feat = t_feat)
            #     feat_consistency_lst.append(feat_consistency_loss.item())
            #     loss += feat_consistency_loss
            
            if cfg_train.weight_transformation_reg > 0 and pre_train_done:
                trans_reg_loss = cfg_train.weight_transformation_reg * model.get_transformation_regularisation_loss()
                trans_reg_loss_lst.append(trans_reg_loss.item())
                loss += trans_reg_loss

            if cfg_train.weight_joint_arap > 0 and pre_train_done:
               joint_arap_loss = cfg_train.weight_joint_arap * model.get_joint_arap_loss()
               joint_arap_lst.append(joint_arap_loss.item())
               loss += joint_arap_loss

            if cfg_train.weight_joint_chamfer > 0 and pre_train_done:
               joint_chamfer_loss = cfg_train.weight_joint_chamfer * model.get_joint_chamfer_loss()
               joint_chamfer_lst.append(joint_chamfer_loss.item())
               loss += joint_chamfer_loss

            if (cfg_train.weight_sparsity > 0) and pre_train_done: # and (global_step > cfg_train.full_t_iter)
                weight_sparsity_loss = cfg_train.weight_sparsity * model.get_weight_sparsity_loss()
                weight_sparsity_lst.append(weight_sparsity_loss.item())
                loss += weight_sparsity_loss

            if cfg_train.weight_flow_diff > 0 and pre_train_done and render_conv:
                if res['flow_diff'] is not None:
                    flow_diff_los = cfg_train.weight_flow_diff * res['flow_diff'].pow(2).mean()
                    flow_diff_lst.append(flow_diff_los.item())
                    loss += flow_diff_los

            if cfg_train.weight_chamfer2D and pre_train_done:
                # Select a random time step
                chamfer_mask = torch.where(times == time_sel)[0]

                num_rnd_cam_i = min(5, len(chamfer_mask))
                rnd_cam_i = torch.randperm(len(chamfer_mask))[:num_rnd_cam_i].long()
                # rnd_cam_i = torch.randint(0, len(chamfer_mask), (num_rnd_cam_i,))
                chamfer_mask = chamfer_mask[rnd_cam_i]
                if cfg_train.pose_one_each: # Nerf
                    poses_temp = data_dict['poses'][chamfer_mask].squeeze(-1).to(device)
                    Ks_temp = torch.tensor(data_dict['Ks'][chamfer_mask]).unsqueeze(0).to(device).to(torch.float32)
                else:
                    # Select a random pose
                    poses_temp = data_dict['poses'].to(device)[rnd_cam_i]
                    Ks_temp = torch.tensor(data_dict['Ks']).to(device).to(torch.float32)[rnd_cam_i]
                
                chamfer2D_loss = 0

                iter_pcds = [t_hat_pcd]
                if render_conv:
                    iter_pcds.append(conv_t_hat_pcd)

                for pcd in iter_pcds:
                    projected_points_hat = utils.project_point_to_image_plane(pcd, poses_temp, Ks_temp)
                    # Mask compatibility
                    projected_points_hat[:,:,0] = (images.shape[1] - 1) - projected_points_hat[:,:,0]
                    projected_points_hat = projected_points_hat.flip(-1)
                    # Mask compatibility end

                    M = 3000
                    N = 3000
                    masks_iter = masks[chamfer_mask.cpu()].squeeze(-1)

                    mask_pcd = [torch.cat([mask.unsqueeze(-1) for mask in torch.where(masks_iter[i] > 0)], dim=-1) for i in range(len(masks_iter))]
                    mask_pcd = torch.cat([mask[torch.randint(0, mask.shape[0], (M,)).long().cpu()].unsqueeze(0) for mask in mask_pcd], dim=0).to(projected_points_hat.device).float()

                    chamfer2D_loss += model.get_batch_chamfer_loss(projected_points_hat, mask_pcd, N=N, M=None)
                
                chamfer2D_loss *= cfg_train.weight_chamfer2D
                chamfer2D_loss_lst.append(chamfer2D_loss.item())
                loss += chamfer2D_loss

                # size = images.shape[1]
                # projected_points_hat = projected_points_hat.round().clip(0,size-1)
                # projected_points_hat = projected_points_hat.detach().cpu().numpy().astype(np.int32)
                # mask_pcd = mask_pcd.detach().cpu().numpy().astype(np.int32)

                # import matplotlib.pyplot as plt
                # for i, j in zip(range(len(projected_points_hat)), torch.where(times == time_sel)[0]):

                #     img = np.zeros((size, int(size * 3), 3))
                #     img[projected_points_hat[i, :, 0], projected_points_hat[i, :, 1]] = 1
                #     img[mask_pcd[i, :, 0], size + mask_pcd[i, :, 1]] = 1

                #     # img[:, 512:1024, :] = masks[j].clip(0, 1).numpy().repeat(3, axis=-1)
                #     img[:, size*2:, :] = images[j].clip(0, 1).numpy()
                #     plt.imsave(f'test_{i}.png', img)

            # # if cfg_train.weight_optical_flow > 0 and pre_train_done:
            #    pass

            # if cfg_train.weight_optical_flow and pre_train_done:
            #     pix_y = ((sel_i % (images.shape[1] * images.shape[1])) % images.shape[1]).unsqueeze(-1)
            #     pix_x = torch.div(sel_i % (images.shape[1] * images.shape[1]), images.shape[1], rounding_mode='floor').unsqueeze(-1)
            #     rel_img_i = img_i % len(poses)
            #     img_i = img_i.unique().sort()[0]
            #     img_i_prev = img_i - len(poses)
            #     underflow_mask = torch.where(img_i_prev < 0)
            #     img_i_prev[underflow_mask] = img_i[underflow_mask]

            #     i_1 = img_i_prev
            #     i_2 = img_i

            #     images2 = images[i_2].permute(0,3,1,2).to(device) * 255
            #     images1 = images[i_1].permute(0,3,1,2).to(device) * 255
                
            #     with torch.no_grad():
            #         _, flow_up = raft(images2, images1, iters=1, test_mode=True)
            #         flow_up *= masks[i_2].permute(0,3,1,2).to(flow_up.device)
            #         flow_up = flow_up.permute(0,2,3,1)

            #     optical_flow_mask = (rel_img_i[:,0], pix_x[:,0], pix_y[:,0])
            #     optical_flow_loss = cfg_train.weight_optical_flow * (flow_up[optical_flow_mask] - res['flow']).pow(2).mean()
            #     optical_flow_lst.append(optical_flow_loss.item())

            #     loss += optical_flow_loss


            # Visualisation sanity check
            # import matplotlib.pyplot as plt
            # i_1 = 15
            # i_2 = i_1 + len(poses)
            # for _ in range(20):

            #     image2 = images[i_2].permute(2,0,1)[None].to(device) * 255
            #     image1 = images[i_1].permute(2,0,1)[None].to(device) * 255
            
            #     _, flow_up = raft(
            #         image2,
            #         image1, 
            #         iters=20,
            #         test_mode=True)
                
            #     flow_up *= masks[i_2].permute(2,0,1)[None].to(flow_up.device)
            #     flow_up = flow_up[0].permute(1,2,0).detach().cpu().numpy()
            #     flow_up = torch.tensor(flow_to_image(flow_up)).detach().cpu().numpy()

            #     # f, axarr = plt.subplots(3,1) 
            #     # axarr[0].imshow((image2[0] / 255).permute(1,2,0).detach().cpu().numpy())
            #     # axarr[1].imshow((image1[0] / 255).permute(1,2,0).detach().cpu().numpy())
            #     # axarr[2].imshow(flow_up)
            #     f, axarr = plt.subplots(1,1) 
            #     axarr.imshow(flow_up)

            #     plt.show()
            #     plt.close()
            #     i_1 = i_2
            #     i_2 = i_1 + len(poses)
            # for i, j in zip(img_i_prev.squeeze(-1), img_i.squeeze(-1)):

            #     plt.close()
            #     _, flow_up = raft(images[j].permute(2,1,0)[None].to(device) * 255, images[i].permute(2,1,0)[None].to(device), iters=20, test_mode=True)
            #     flow_up = flow_up[0].permute(1,2,0)
            #     flow_up *= masks[j].to(device)
            #     flow_up = flow_up.detach().cpu().numpy()
            #     flow_up= flow_to_image(flow_up)
                
            #     #subplot(r,c) provide the no. of rows and columns
            #     f, axarr = plt.subplots(3,1) 

            #     # use the created array to output your multiple images. In this case I have stacked 4 images vertically
            #     axarr[0].imshow(images[i])
            #     axarr[1].imshow(images[j])
            #     axarr[2].imshow(flow_up)
            #     print(i,j)
            #     plt.show()

            # mask = (img_i[:,0], pix_x[:,0], pix_y[:,0])


            loss = loss / accum_iter
            loss.backward()

        optimizer.step()

        # update lr
        decay_steps = cfg_train.lrate_decay * 1000
        decay_factor = 0.1 ** (1/decay_steps)
        for i_opt_g, param_group in enumerate(optimizer.param_groups):
            param_group['lr'] = param_group['lr'] * decay_factor

        if (global_step == pre_train_iter) and (not pre_train_done):
            pre_train_done = True
            optimizer = utils.create_optimizer_or_freeze_model(model, cfg_train, global_step=0)
            fixed = ['rgbnet', 'densitynet', 'featurenet', 'canonical_feat', 'feat_net']
            model.reinitialise_weights()

            for _, param_group in enumerate(optimizer.param_groups):
                p_name = param_group['name']
                if p_name in fixed:
                    param_group['lr'] = 1e-6

            for layer in next(model.forward_warp.transform_net.children()):
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()
                        
        if (global_step == after_train_iter) and (not do_after_train):
            do_after_train = True
            not_fixed = ['rgbnet', 'densitynet', 'featurenet', 'canonical_feat', 'feat_net']

            for _, param_group in enumerate(optimizer.param_groups):
                p_name = param_group['name']
                if not p_name in not_fixed: # fix non render weights
                    param_group['lr'] = 0.
                else: # fine-tune point renderer
                    param_group['lr'] = 1e-6

        # check log & save
        if global_step%args.i_print == 0:
            eps_time = time.time() - time0
            eps_time_str = f'{eps_time//3600:02.0f}:{eps_time//60%60:02.0f}:{eps_time%60:02.0f}'
            # tqdm.write(f'pcd training : iter {global_step:6d} / PSNR: {np.mean(psnr_lst):5.2f} / '
            #            f'Chamfer: {np.mean(chamfer_lst):.9f} / IMG Loss: {np.mean(mse_lst):.9f} / '
            #            f'ARAP: {np.mean(arap_lst):.9f} / Weight TV: {np.mean(weight_tv_lst):.9f} / '
            #            f'Joint ARAP: {np.mean(joint_arap_lst):.9f} / Trans. Reg.: {np.mean(trans_reg_loss_lst):.9f} / '
            #            f'Joint Chamfer: {np.mean(joint_chamfer_lst):.9f} / Feat. Cons.: {np.mean(feat_consistency_lst):.9f} / '
            #            f'Weight Sparsity: {np.mean(weight_sparsity_lst):.9f} / Grid Preserve: {np.mean(gp_lst):.9f}  / Eps: {eps_time_str}')

            tqdm.write(f'pcd training : iter {global_step:6d} / PSNR: {np.mean(psnr_lst):5.2f} / t_range: {t_min:.2f}-{t_max:.2f} / Eps: {eps_time_str}')

            writer.add_scalar("metrics/PSNR", np.mean(psnr_lst), global_step)
            writer.add_scalar("metrics/Chamfer", np.mean(chamfer_lst), global_step)
            writer.add_scalar("metrics/IMG_Loss", np.mean(mse_lst), global_step)
            writer.add_scalar("metrics/ARAP", np.mean(arap_lst), global_step)
            writer.add_scalar("metrics/Weight_TV", np.mean(weight_tv_lst), global_step)
            writer.add_scalar("metrics/Joint_ARAP", np.mean(joint_arap_lst), global_step)
            writer.add_scalar("metrics/Trans._Reg.", np.mean(trans_reg_loss_lst), global_step)
            writer.add_scalar("metrics/Joint_Chamfer", np.mean(joint_chamfer_lst), global_step)
            writer.add_scalar("metrics/Weight_Sparsity", np.mean(weight_sparsity_lst), global_step)
            writer.add_scalar("metrics/Mask_Loss", np.mean(mask_loss_lst), global_step)
            writer.add_scalar("metrics/Collision_Loss", np.mean(gp_lst), global_step)
            writer.add_scalar("metrics/Time_TV", np.mean(time_tv_lst), global_step)
            writer.add_scalar("metrics/Chamfer2D", np.mean(chamfer2D_loss_lst), global_step)
            writer.add_scalar("metrics/Teacher_Loss", np.mean(teacher_loss_lst), global_step)
            writer.add_scalar("metrics/optical_flow", np.mean(optical_flow_lst), global_step)
            writer.add_scalar("metrics/Flow_Diff", np.mean(flow_diff_lst), global_step)
            writer.add_scalar("metrics/Scene_Flow_TV", np.mean(scene_flow_tv_lst), global_step)
            writer.add_scalar("metrics/WE_Entropy", np.mean(we_entropy_lst), global_step)
            writer.add_scalar("metrics/eps_time", eps_time, global_step)

            mse_lst = []
            psnr_lst = []
            arap_lst = []
            chamfer_lst = []
            weight_tv_lst = []
            joint_arap_lst = []
            joint_chamfer_lst = []
            trans_reg_loss_lst = []
            weight_sparsity_lst = []
            feat_consistency_lst = []
            mask_loss_lst = []
            time_tv_lst = []
            chamfer2D_loss_lst = []
            optical_flow_lst = []
            flow_diff_lst = []
            scene_flow_tv_lst = []
            we_entropy_lst = []
        
        if (global_step % args.i_save == 0) or (global_step == 1): 
            
            ## Render training images
            pred_images, _, pred_weights, _ = render_viewpoints(
                model, 
                render_poses=data_dict['poses'][data_dict['img_to_cam'][tb_mask]],
                HW=data_dict['HW'][tb_mask],
                Ks=data_dict['Ks'][data_dict['img_to_cam'][tb_mask]],
                # gt_imgs=[data_dict['images'][i].cpu().numpy() for i in mask],
                test_times=data_dict['times'][tb_mask],
                ndc=cfg.data.ndc,
                render_kwargs=render_kwargs,
                batch_size = 4096,
                render_factor=tb_factor,
                verbose=False,
                render_flow=False,
                render_pcd_direct=False,
                kinematic_warp=kinematic_warp)
                
            pred_images = torch.tensor(pred_images).permute(0,3,1,2)
            pred_weights = torch.tensor(pred_weights).permute(0,3,1,2)
            payload = torch.concat([gt_images.to('cuda'), pred_images, pred_weights], dim=0)
            writer.add_image('payload', torchvision.utils.make_grid(payload, nrow=tb_num_imgs), global_step=global_step)

            # Render spherical pcd video
            # rgbs, disps, _ = render_viewpoints(
            #         model,
            #         render_poses=data_dict['render_poses'],
            #         HW=data_dict['HW'][0][None,...].repeat(len(data_dict['render_poses']), 0),
            #         Ks=data_dict['Ks'][0][None,...].repeat(len(data_dict['render_poses']), 0),
            #         test_times=torch.linspace(0., 1., len(data_dict['render_poses'])),
            #         render_kwargs=render_kwargs,
            #         ndc=cfg.data.ndc,
            #         batch_size = 4096 * 2,
            #         render_factor=tb_factor,
            #         verbose=False,
            #         render_pcd_direct=True)
            
            # rgbs = torch.tensor(rgbs).permute(0,3,1,2).unsqueeze(0)
            # disps = torch.tensor(disps).permute(0,3,1,2).unsqueeze(0)
            # disps /= torch.max(disps)

            # writer.add_video('pcd_rgb', rgbs, global_step=global_step, fps=10)
            # writer.add_video('pcd_disp', disps, global_step=global_step, fps=10)

            # Render static cam comparison video
            # Render full model
            pred_images, _, pred_weights, _ = render_viewpoints(
                model, 
                render_poses=data_dict['poses'][data_dict['img_to_cam'][cam_indices]],
                HW=data_dict['HW'][cam_indices],
                Ks=data_dict['Ks'][data_dict['img_to_cam'][cam_indices]],
                test_times=torch.linspace(0,1, len(cam_indices)), # data_dict['times'][cam_indices],
                ndc=cfg.data.ndc,
                render_kwargs=render_kwargs,
                batch_size = 4096 * 2,
                render_factor=tb_factor,
                verbose=False,
                direct_eps=direct_eps,
                render_flow=False,
                render_pcd_direct=False,
                kinematic_warp=True)
            
            # Render PCD based on frozen rgb and alphas
            pred_images_pcd, _, _, _ = render_viewpoints(
                    model, 
                    render_poses=data_dict['poses'][data_dict['img_to_cam'][cam_indices]],
                    HW=data_dict['HW'][cam_indices],
                    Ks=data_dict['Ks'][data_dict['img_to_cam'][cam_indices]],
                    test_times=torch.linspace(0,1, len(cam_indices)), # data_dict['times'][cam_indices],
                    ndc=cfg.data.ndc,
                    render_kwargs=render_kwargs,
                    batch_size = 4096 * 2,
                    render_factor=tb_factor,
                    verbose=False,
                    render_flow=False,
                    render_pcd_direct=True,
                    kinematic_warp=True)
            pred_images_pcd = torch.tensor(pred_images_pcd).permute(0,3,1,2)
            pred_images = torch.tensor(pred_images).permute(0,3,1,2)
            pred_weights = torch.tensor(pred_weights).permute(0,3,1,2)
            video_temp = torch.concat([gt_images_vid, pred_images_pcd, pred_images, pred_weights], dim=3).unsqueeze(0)

            writer.add_video('video', video_temp, global_step=global_step, fps=4)

        if ((global_step % merge_weights_steps) == 0) and pre_train_done:
            print('\nMerging weights')
            _times = torch.linspace(0., 1., 300).unsqueeze(-1)
            model.simplify_skeleton(_times, deg_threshold=cfg_train.degree_threshold, five_percent_heuristic=False)


        # if global_step == merge_weights_steps:
        # # if (global_step % merge_weights_steps) == 0:
        #     print('\nMerging weights')
        #     _times = torch.linspace(0., 1., 300).unsqueeze(-1)
        #     model.simplify_lbs_weights(_times, deg_threshold=cfg_train.degree_threshold, noise_std=0)

        #     # Reset learning rates for weight specific parameters
        #     for _, param_group in enumerate(optimizer.param_groups):
        #         p_name = param_group['name']
        #         if p_name == 'weights' or p_name  == 'theta_weight' or p_name == 'feat_net':
        #             param_group['lr'] = getattr(cfg_train, f'lrate_{p_name}')

        if global_step in save_iterations:
            iter_path = os.path.join(benchmark_weight_path, f'temporalpoints_{global_step}.tar')
            torch.save({
                'global_step': global_step,
                'model_kwargs': model.get_kwargs(),
                'model_state_dict': model.state_dict(),
            }, iter_path)
            print('pcd training: saved checkpoints at', iter_path)

    if global_step != -1:
        torch.save({
            'global_step': global_step,
            'model_kwargs': model.get_kwargs(),
            'model_state_dict': model.state_dict(),
        }, last_ckpt_path)
        print('pcd training: saved checkpoints at', last_ckpt_path)

def scene_rep_reconstruction(args, cfg, cfg_model, cfg_train, xyz_min, xyz_max, data_dict):
    # init
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if abs(cfg_model.world_bound_scale - 1) > 1e-9:
        xyz_shift = (xyz_max - xyz_min) * (cfg_model.world_bound_scale - 1) / 2
        xyz_min -= xyz_shift
        xyz_max += xyz_shift


    HW, Ks, near, far, i_train, i_val, i_test, poses, render_poses, images, times, render_times, masks = [
        data_dict[k] for k in [
            'HW', 'Ks', 'near', 'far', 'i_train', 'i_val', 'i_test', 'poses', 
            'render_poses', 'images',
            'times','render_times', 'masks'
        ]
    ]
    # times = torch.tensor(times)
    times_i_train = times[i_train].to('cpu' if cfg.data.load2gpu_on_the_fly else device)



    last_ckpt_path = os.path.join(cfg.basedir, cfg.expname, 'fine_last.tar')
    if os.path.isfile(last_ckpt_path):
        print('fine_last.tar already exists, skipping training.')
        return # right now, if there is already a file in the directory, just skip this function

    # init model and optimizer
    start = 0
    # init model
    model_kwargs = copy.deepcopy(cfg_model)

    num_voxels = model_kwargs.pop('num_voxels')
    if len(cfg_train.pg_scale) :
        num_voxels = int(num_voxels / (2**len(cfg_train.pg_scale)))
    model = tineuvox.TiNeuVox(
        xyz_min=xyz_min, xyz_max=xyz_max,
        num_voxels=num_voxels,
        **model_kwargs)
    model = model.to(device)
    optimizer = utils.create_optimizer_or_freeze_model(model, cfg_train, global_step=0)

    # init rendering setup
    render_kwargs = {
        'near': near,
        'far': far,
        'bg': 1 if cfg.data.white_bkgd else 0,
        'stepsize': cfg_model.stepsize,
        'inverse_y': cfg.data.inverse_y, 
        'flip_x': cfg.data.flip_x,
        'flip_y': cfg.data.flip_y,
    }

    # def gather_training_rays():
    #     # if data_dict['irregular_shape']:
    #     #     rgb_tr_ori = [images[i].to('cpu' if cfg.data.load2gpu_on_the_fly else device) for i in i_train]
    #     # else:
    #     #     rgb_tr_ori = images[i_train].to('cpu' if cfg.data.load2gpu_on_the_fly else device)

    #     if cfg_train.ray_sampler == 'in_maskcache':
    #         print('cfg_train.ray_sampler =in_maskcache')
    #         # NOTE: No mask cache here, only checking whether ray is within feature volume
    #         rgb_tr, times_flatten,rays_o_tr, rays_d_tr, viewdirs_tr = tineuvox.get_training_rays_in_maskcache_sampling(
    #                 rgb_tr_ori=images,times=times_i_train,
    #                 train_poses=poses,
    #                 HW=HW, Ks=Ks,
    #                 ndc=cfg.data.ndc, 
    #                 model=model, render_kwargs=render_kwargs, **render_kwargs)
    #     # elif cfg_train.ray_sampler == 'flatten':
    #     #     print('cfg_train.ray_sampler =flatten')
    #     #     rgb_tr, times_flatten,rays_o_tr, rays_d_tr, viewdirs_tr, imsz = tineuvox.get_training_rays_flatten(
    #     #         rgb_tr_ori=rgb_tr_ori,times=times_i_train,
    #     #         train_poses=poses[i_train],
    #     #         HW=HW[i_train], Ks=Ks[i_train], ndc=cfg.data.ndc, **render_kwargs)
    #     # else:
    #     #     print('cfg_train.ray_sampler =random')
    #     #     rgb_tr, times_flatten,rays_o_tr, rays_d_tr, viewdirs_tr, imsz = tineuvox.get_training_rays(
    #     #         rgb_tr=rgb_tr_ori,times=times_i_train,
    #     #         train_poses=poses[i_train],
    #     #         HW=HW[i_train], Ks=Ks[i_train], ndc=cfg.data.ndc, **render_kwargs)
    #     index_generator = tineuvox.batch_indices_generator(len(rgb_tr), cfg_train.N_rand)
    #     batch_index_sampler = lambda: next(index_generator)
    #     return rgb_tr,times_flatten, rays_o_tr, rays_d_tr, viewdirs_tr, imsz, batch_index_sampler
    
    # rgb_tr, times_flatten, rays_o_tr, rays_d_tr, viewdirs_tr, imsz, batch_index_sampler = gather_training_rays()

    # NOTE: No mask cache here, only checking whether ray is within feature volume
    # NOTE: Careful, ALL images are passed here, not only the i_train images (i_train is passed separately)
    rgb_tr, times_flatten, rays_o_tr, rays_d_tr, viewdirs_tr, pix_to_ray, masks_tr = tineuvox.get_training_rays_in_maskcache_samplingTINEUVOX(
                    rgb_tr_ori=images,
                    masks_tr_ori=masks,
                    times=times_i_train,
                    train_poses=poses,
                    i_train = i_train,
                    Ks=Ks,
                    HW=HW[i_train],
                    ndc=cfg.data.ndc, 
                    model=model, render_kwargs=render_kwargs, img_to_cam = data_dict['img_to_cam'],  **render_kwargs)
    index_generator = tineuvox.batch_indices_generator(len(rgb_tr), cfg_train.N_rand)
    batch_index_sampler = lambda: next(index_generator)

    
    temp = data_dict['images']
    data_dict['images'] = None
    del temp
    del images

    torch.cuda.empty_cache()
    psnr_lst = []
    dist_loss_lst = []
    inv_loss_lst = []
    delta_loss_lst = []
    mask_loss_lst = []
    time0 = time.time()
    global_step = -1

    for global_step in trange(1+start, 1+cfg_train.N_iters):

        if global_step == args.step_to_half:
            model.feature.data=model.feature.data.half()
        # progress scaling checkpoint
        if global_step in cfg_train.pg_scale:
            n_rest_scales = len(cfg_train.pg_scale)-cfg_train.pg_scale.index(global_step)-1
            cur_voxels = int(cfg_model.num_voxels / (2**n_rest_scales))
            if isinstance(model, tineuvox.TiNeuVox):
                model.scale_volume_grid(cur_voxels)
            else:
                raise NotImplementedError
            optimizer = utils.create_optimizer_or_freeze_model(model, cfg_train, global_step=0)

        # random sample rays
        if cfg_train.ray_sampler in ['flatten', 'in_maskcache']:
            sel_i = batch_index_sampler()
            target = rgb_tr[sel_i]
            target_alpha_inv_last = 1 - masks_tr[sel_i]

            sel_ray = pix_to_ray[sel_i].long()
            rays_o = rays_o_tr[sel_ray]
            rays_d = rays_d_tr[sel_ray]
            viewdirs = viewdirs_tr[sel_ray]
            times_sel = times_flatten[sel_i]

        # elif cfg_train.ray_sampler == 'random':
        #     sel_b = torch.randint(rgb_tr.shape[0], [cfg_train.N_rand])
        #     sel_r = torch.randint(rgb_tr.shape[1], [cfg_train.N_rand])
        #     sel_c = torch.randint(rgb_tr.shape[2], [cfg_train.N_rand])
        #     target = rgb_tr[sel_b, sel_r, sel_c]
        #     rays_o = rays_o_tr[sel_b, sel_r, sel_c]
        #     rays_d = rays_d_tr[sel_b, sel_r, sel_c]
        #     viewdirs = viewdirs_tr[sel_b, sel_r, sel_c]
        #     times_sel = times_flatten[sel_b, sel_r, sel_c]
        else:
            raise NotImplementedError

        u_N = None
        if cfg_train.unobserved_view_reg:
            fraction = 0.05
            u_N = int(fraction * cfg_train.N_rand)

            # rotation around center (volume center)
            rot = roma.random_rotmat(1).to('cpu')
            rays_o[-u_N:] = (rays_o[-u_N:] @ rot)[0]
            rays_d[-u_N:] = (rays_d[-u_N:] @ rot)[0]

        if cfg.data.load2gpu_on_the_fly:
            target = target.to(device)
            rays_o = rays_o.to(device)
            rays_d = rays_d.to(device)
            viewdirs = viewdirs.to(device)
            times_sel = times_sel.to(device)
            target_alpha_inv_last = target_alpha_inv_last.to(device)

        # volume rendering
        render_result = model(rays_o, rays_d, viewdirs, times_sel, global_step=global_step, **render_kwargs)

        # gradient descent step
        optimizer.zero_grad(set_to_none = True)
        if u_N is None:
            loss = cfg_train.weight_main * F.mse_loss(render_result['rgb_marched'], target)
        else:
            loss = cfg_train.weight_main * F.mse_loss(render_result['rgb_marched'][:-u_N], target[:-u_N])
        psnr = utils.mse2psnr(loss.detach())
        
        if cfg_train.weight_entropy_last > 0:
            pout = render_result['alphainv_last'].clamp(1e-6, 1-1e-6)
            entropy_last_loss = -(pout*torch.log(pout) + (1-pout)*torch.log(1-pout)).mean()
            loss += cfg_train.weight_entropy_last * entropy_last_loss

        if cfg_train.weight_mask_loss > 0:
            pout = render_result['alphainv_last'].clamp(1e-6, 1-1e-6).unsqueeze(-1)
            mask_loss = cfg_train.weight_mask_loss * F.binary_cross_entropy(pout, target_alpha_inv_last)
            mask_loss_lst.append(mask_loss.item())
            loss += mask_loss

        if cfg_train.weight_rgbper > 0:
            if u_N is None:
                rgbper = (render_result['raw_rgb'] - target[render_result['ray_id']]).pow(2).sum(-1)
                rgbper_loss = (rgbper * render_result['weights'].detach()).sum() / len(rays_o)
            else:
                rgbper = (render_result['raw_rgb'][:-u_N] - target[render_result['ray_id']][:-u_N]).pow(2).sum(-1)
                rgbper_loss = (rgbper * render_result['weights'][:-u_N].detach()).sum() / len(rays_o)
            
            loss += cfg_train.weight_rgbper * rgbper_loss

        if cfg_train.weight_distortion > 0:
            n_max = render_result['n_max']
            s = render_result['s']
            w = render_result['weights']
            ray_id = render_result['ray_id']
            loss_distortion = cfg_train.weight_distortion * flatten_eff_distloss(w, s, 1/n_max, ray_id)
            dist_loss_lst.append(loss_distortion.item())
            loss +=  loss_distortion
        
        if cfg_train.weight_inv > 0:
            if len(render_result['ray_pts_hat']) > 0:
                loss_inv = cfg_train.weight_inv * torch.sqrt((render_result['ray_pts'] - render_result['ray_pts_hat'])**2 + 1e-6).mean()
                inv_loss_lst.append(loss_inv.item())
                loss += loss_inv
        
        # if cfg_train.weight_delta > 0:
        #     mask = ~render_result['time_mask']
        #     if mask.any():
        #         loss_delta = cfg_train.weight_delta * (render_result['ray_pts_delta'][mask]**2 + 1e-6).sqrt().mean()
        #         delta_loss_lst.append(loss_delta.item())
        #         loss += loss_delta

        loss.backward()

        if global_step<cfg_train.tv_before and global_step>cfg_train.tv_after and global_step%cfg_train.tv_every==0:
            if cfg_train.weight_tv_feature>0:
                model.feature_total_variation_add_grad(
                    cfg_train.weight_tv_feature/len(rays_o), global_step<cfg_train.tv_feature_before)
        optimizer.step()
        psnr_lst.append(psnr.item())
        # update lr
        decay_steps = cfg_train.lrate_decay * 1000
        decay_factor = 0.1 ** (1/decay_steps)
        for i_opt_g, param_group in enumerate(optimizer.param_groups):
            param_group['lr'] = param_group['lr'] * decay_factor

        # check log & save
        if global_step%args.i_print == 0:
            eps_time = time.time() - time0
            eps_time_str = f'{eps_time//3600:02.0f}:{eps_time//60%60:02.0f}:{eps_time%60:02.0f}'
            tqdm.write(f'scene_rep_reconstruction : iter {global_step:6d} / '
                       f'Loss: {loss.item():.9f} / PSNR: {np.mean(psnr_lst):5.2f} / '
                       f'Dist: {np.mean(dist_loss_lst):.9f} / Inv: {np.mean(inv_loss_lst):.9f} / '
                       f'Delta: {np.mean(delta_loss_lst):.9f} / Mask Loss: {np.mean(mask_loss_lst):.9f} / Eps: {eps_time_str}')
            psnr_lst = []
            dist_loss_lst = []
            inv_loss_lst = []
            delta_loss_lst = []
            mask_loss_lst = []

        # if global_step%(args.fre_test) == 0:
        #     render_viewpoints_kwargs = {
        #         'model': model,
        #         'ndc': cfg.data.ndc,
        #         'inverse_y': cfg.data.inverse_y, 
        #         'flip_x': cfg.data.flip_x, 
        #         'flip_y': cfg.data.flip_y,
        #         'render_kwargs': {
        #             'near': near,
        #             'far': far,
        #             'bg': 1 if cfg.data.white_bkgd else 0,
        #             'stepsize': cfg_model.stepsize,

        #             },
        #         }
        #     testsavedir = os.path.join(cfg.basedir, cfg.expname, f'{global_step}-test')
        #     if os.path.exists(testsavedir) == False:
        #         os.makedirs(testsavedir)

        #     rgbs, disps = render_viewpoints(
        #         render_poses=data_dict['poses'][data_dict['i_test']],
        #         HW=data_dict['HW'][data_dict['i_test']],
        #         Ks=data_dict['Ks'][data_dict['i_test']],
        #         gt_imgs=[data_dict['images'][i].cpu().numpy() for i in data_dict['i_test']],
        #         savedir=testsavedir,
        #         test_times=data_dict['times'][data_dict['i_test']],
        #         eval_psnr=args.eval_psnr, eval_ssim=args.eval_ssim, eval_lpips_alex=args.eval_lpips_alex, eval_lpips_vgg=args.eval_lpips_vgg,
        #         **render_viewpoints_kwargs)


    if global_step != -1:
        torch.save({
            'global_step': global_step,
            'model_kwargs': model.get_kwargs(),
            'model_state_dict': model.state_dict(),
        }, last_ckpt_path)
        print('scene_rep_reconstruction : saved checkpoints at', last_ckpt_path)


def train(args, cfg, read_path, save_path, data_dict=None, stages=(1, 2)):
    '''Run the selected training stages and write their outputs to disk.

    Stage 1 fits the volumetric scene representation with
    ``scene_rep_reconstruction``. Stage 2 reloads the ``fine_last.tar``
    checkpoint, exports point clouds with ``export_point_cloud`` and then runs
    ``train_pcd``.

    Args:
        args: Parsed command line arguments; written to ``args.txt``.
        cfg: Experiment configuration; dumped to ``config.py``.
        read_path: Directory forwarded to ``export_point_cloud`` and
            ``train_pcd``.
        save_path: Directory that receives ``args.txt`` and ``config.py``;
            also forwarded to ``train_pcd`` and used to build the tensorboard
            path.
        data_dict: Loaded dataset. Required in practice even though it
            defaults to ``None``: it is unpacked into
            ``compute_bbox_by_cam_frustrm``.
        stages: Container of stage numbers to run (1 and/or 2). The default is
            an immutable tuple so it can never be mutated across calls.
    '''
    # init
    print('train: start')
    tensorboard_path = os.path.join("./logs/tensorboard", save_path)
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, 'args.txt'), 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    cfg.dump(os.path.join(save_path, 'config.py'))

    # coarse geometry searching
    xyz_min, xyz_max = compute_bbox_by_cam_frustrm(args = args, cfg = cfg, **data_dict)

    if 1 in stages:
        # fine detail reconstruction
        eps_time = time.time()
        scene_rep_reconstruction(
                args=args, cfg=cfg,
                cfg_model=cfg.model_and_render, cfg_train=cfg.train_config,
                xyz_min=xyz_min, xyz_max=xyz_max,
                data_dict=data_dict)
        eps_loop = time.time() - eps_time
        eps_time_str = f'{eps_loop//3600:02.0f}:{eps_loop//60%60:02.0f}:{eps_loop%60:02.0f}'
        print('train: finish (eps time', eps_time_str, ')')

    if 2 in stages:
        # Export point clouds from the stage-1 checkpoint
        ckpt_path = os.path.join(cfg.basedir, cfg.expname, 'fine_last.tar')
        model = utils.load_model(tineuvox.TiNeuVox, ckpt_path).to(device)
        stepsize = cfg.model_and_render.stepsize
        render_viewpoints_kwargs = {
            'model': model,
            'ndc': cfg.data.ndc,
            'inverse_y': cfg.data.inverse_y,
            'flip_x': cfg.data.flip_x,
            'flip_y': cfg.data.flip_y,
            'render_kwargs': {
                'near': data_dict['near'],
                'far': data_dict['far'],
                'bg': 1 if cfg.data.white_bkgd else 0,
                'stepsize': stepsize,
                'render_depth': True,
            },
        }
        bone_length = cfg.pcd_model_and_render.bone_length
        canonical_pcd_num = cfg.pcd_model_and_render.canonical_pcd_num
        other_pcd_num = cfg.pcd_model_and_render.other_pcd_num
        pcd_density_threshold = cfg.pcd_model_and_render.pcd_density_threshold

        # determine actual canonical t: the recorded time closest to the configured one
        canonical_t_indx = torch.argmin(((data_dict['times'].float() - cfg.data.canonical_t)**2).sqrt()).long()
        canonical_t = data_dict['times'].float()[canonical_t_indx].item()
        export_point_cloud(model, data_dict, read_path, render_viewpoints_kwargs, canonical_t, pcd_density_threshold,
            export='both', bone_length=bone_length, canonical_pcd_num=canonical_pcd_num, other_pcd_num=other_pcd_num)

        torch.cuda.empty_cache()
        # train point cloud reconstruction
        eps_time = time.time()
        train_pcd(
            args=args, cfg=cfg,
            cfg_model=cfg.pcd_model_and_render, cfg_train=cfg.pcd_train_config,
            read_path=read_path,save_path=save_path, data_dict=data_dict, tineuvox_model=model, canonical_t=canonical_t, tensorboard_path=tensorboard_path)
        eps_loop = time.time() - eps_time
        eps_time_str = f'{eps_loop//3600:02.0f}:{eps_loop//60%60:02.0f}:{eps_loop%60:02.0f}'
        print('train: finish (eps time', eps_time_str, ')')

def export_point_cloud(model, data_dict, path, render_viewpoints_kwargs, canonical_t=0., threshold=0.2, export='torch', bone_length=4., canonical_pcd_num=3e+4, other_pcd_num=5e+3):
    '''Extract a point cloud for every distinct time step and a skeleton for the canonical one.

    Output is written below ``path``: one ``pcds/<t>.tar`` per distinct time,
    an additional ``pcds/<t>.pcd`` and ``pcds/skeleton_<t>.tar`` for the
    canonical time, and ``pcd_paths.json`` which maps every time to its
    ``.tar`` file and records the bounding box of all exported points plus the
    model's voxel size. If ``path/pcds`` already exists, a notice is printed
    and nothing is extracted.

    Args:
        model: Object providing ``get_grid_as_point_cloud`` and ``voxel_size``.
        data_dict: Dataset dictionary; ``poses``, ``Ks``, ``HW``, ``times``,
            ``img_to_cam`` and ``i_train`` are read.
        path: Output root directory.
        render_viewpoints_kwargs: Dictionary providing ``ndc``, ``inverse_y``,
            ``flip_x``, ``flip_y`` and ``render_kwargs['stepsize']``.
        canonical_t: Time value treated as canonical (compared for equality).
        threshold: Forwarded to ``get_grid_as_point_cloud``.
        export: ``'torch'``, ``'o3d'`` or ``'both'``. The per-time file is
            written with ``torch.save`` for ``'torch'``/``'both'`` and with
            Open3D otherwise; the canonical ``.pcd`` copy is written with
            Open3D for ``'o3d'``/``'both'`` and with ``torch.save`` otherwise.
        bone_length: Forwarded to ``create_skeleton``.
        canonical_pcd_num: Approximate number of points wanted at the canonical time.
        other_pcd_num: Approximate number of points wanted at all other times.
    '''
    folder_path = os.path.join(path, 'pcds')
    # Only an existing folder means "already extracted"; any other failure of
    # makedirs below (permissions, a file in the way, ...) must surface.
    if os.path.isdir(folder_path):
        print('PCD folder already exists, skipping extraction')
        return

    # Imported before the folder is created so that a missing open3d does not
    # leave an empty folder behind that would make the next run skip extraction.
    import open3d as o3d

    os.makedirs(folder_path)

    def save_pcd(pcd, rgbs, feat, raw_feat, alphas, pcd_path, t, export_torch):
        # Either a torch archive with all per-point attributes, or an Open3D
        # point cloud holding positions and colours only.
        if export_torch:
            torch.save({
                'pcd': pcd,
                'rgbs': rgbs,
                'feat': feat,
                'raw_feat': raw_feat,
                'alphas': alphas,
                't': t.item()
            }, pcd_path)
        else:
            pcd_o3d = o3d.geometry.PointCloud()
            pcd_o3d.points = o3d.utility.Vector3dVector(pcd.cpu().numpy())
            pcd_o3d.colors = o3d.utility.Vector3dVector(rgbs.cpu().numpy())
            o3d.io.write_point_cloud(pcd_path, pcd_o3d)

    render_poses = data_dict['poses'][data_dict['img_to_cam'][data_dict['i_train']]].float()
    Ks = data_dict['Ks'][data_dict['img_to_cam'][data_dict['i_train']]]
    HW = data_dict['HW'][data_dict['i_train']]
    times = data_dict['times'].float()

    # Process the views in order of increasing time.
    sorted_indices = torch.argsort(times)
    render_poses = render_poses[sorted_indices]
    HW = HW[sorted_indices.cpu().numpy()]
    Ks = Ks[sorted_indices.cpu().numpy()]
    times = times[sorted_indices]
    canonical_t = torch.tensor(canonical_t)

    seen_times = set()
    path_dict = {}

    xyz_min = torch.tensor([np.inf, np.inf, np.inf])
    xyz_max = -xyz_min

    canonical_sampling_freq = None
    other_sampling_freq = None
    stepsize = render_viewpoints_kwargs['render_kwargs']['stepsize']

    for i in tqdm(range(len(times))):
        # Get time; several views may share one time step, export each only once
        t = times[i]
        t_value = t.item()
        if t_value in seen_times:
            continue
        seen_times.add(t_value)
        t = t.unsqueeze(0)

        # Get render parameters and get point cloud
        c2w = render_poses[i]
        H, W = HW[i]
        K = Ks[i]
        _, _, viewdirs = tineuvox.get_rays_of_a_view(
                H, W, K, c2w, render_viewpoints_kwargs['ndc'],
                inverse_y=render_viewpoints_kwargs['inverse_y'], flip_x=render_viewpoints_kwargs['flip_x'], flip_y=render_viewpoints_kwargs['flip_y'])
        # A single mean viewing direction stands in for the whole view.
        viewdir = viewdirs.mean(dim=0).mean(dim=0).reshape((1, 3))

        if i == 0:
            # Calibrate the sampling frequencies once from a full-resolution extraction.
            points, _, _, _, _, _, _ = model.get_grid_as_point_cloud(stepsize=stepsize, time_sel=t, viewdir=viewdir, threshold=threshold, sampling_freq=1, N_batch=2**21)
            other_sampling_freq = (other_pcd_num/len(points)) ** (1/3) # Naturally, other time points can give different number of points, just a rough approximation
            canonical_sampling_freq = (canonical_pcd_num/len(points)) ** (1/3)

        is_canonical = bool(t == canonical_t)
        sampling_freq = canonical_sampling_freq if is_canonical else other_sampling_freq
        points, alphas, rgbs, feat, raw_feat, binary_volume, grid_xyz = model.get_grid_as_point_cloud(stepsize=stepsize, time_sel=t, viewdir=viewdir, threshold=threshold, sampling_freq=sampling_freq, N_batch=2**21)

        if is_canonical:
            # Clean the binary volume (fill small holes, keep the largest
            # connected component) and extract again restricted to that blob.
            binary_volume = binary_volume.cpu().numpy()
            binary_volume = remove_small_holes(binary_volume.astype(bool), area_threshold=2**8,)
            binary_volume = largest_k(binary_volume, connectivity=26, k=1).astype(int)
            binary_volume = torch.tensor(binary_volume, dtype=torch.bool)
            points, alphas, rgbs, feat, raw_feat, binary_volume, grid_xyz = model.get_grid_as_point_cloud(stepsize=stepsize, time_sel=t, viewdir=viewdir, threshold=threshold, sampling_freq=sampling_freq, N_batch=2**21, blob_mask=binary_volume)

        # Save data
        pcd_torch_path  = os.path.join(folder_path, f'{t_value}.tar')
        pcd_path  = os.path.join(folder_path, f'{t_value}.pcd')
        save_pcd(points, rgbs, feat, raw_feat, alphas, pcd_torch_path, t, export_torch=export in ('torch', 'both'))
        path_dict[t_value] = pcd_torch_path

        # Find min & max for bbox
        xyz_min = torch.minimum(xyz_min, points.min(dim=0)[0])
        xyz_max = torch.maximum(xyz_max, points.max(dim=0)[0])

        if is_canonical: # export skeleton
            save_pcd(points, rgbs, feat, raw_feat, alphas, pcd_path, t, export_torch=export not in ('o3d', 'both'))
            # NOTE: this archive is written relative to the current working
            # directory, not into folder_path.
            torch.save({
                'grid_xyz': grid_xyz.detach().cpu(),
                'binary_volume': binary_volume.detach().cpu()
            }, 'skeleton.tar')
            grid_xyz = grid_xyz.detach().cpu().numpy()
            binary_volume = binary_volume.detach().cpu().numpy()
            res = create_skeleton(binary_volume, grid_xyz, bone_length=bone_length)
            pcd_torch_path = os.path.join(folder_path, f'skeleton_{t_value}.tar')
            torch.save(res, pcd_torch_path)
            print(f"{len(res['bones'])} bones extracted.")

    # Add bbox to json and save
    path_dict['xyz_min'] = str(list(xyz_min.cpu().numpy()))
    path_dict['xyz_max'] = str(list(xyz_max.cpu().numpy()))
    path_dict['voxel_size'] = str(model.voxel_size.cpu().numpy())

    with open(os.path.join(path, 'pcd_paths.json'), "w") as fp:
        json.dump(path_dict, fp)

if __name__=='__main__':
    # Parse the command line and load the experiment configuration.
    parser = config_parser()
    args = parser.parse_args()
    cfg = mmcv.Config.fromfile(args.config)

    # Select the compute device; with CUDA, new tensors default to the GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    seed_everything()

    # load images / poses / camera settings / data split
    # (skipped for the warp-visualisation and bone-merging debug modes)
    data_dict = None
    if not (args.visualise_warp or args.debug_bone_merging):
        data_dict = load_everything(args=args, cfg=cfg, use_cache=args.use_cache, overwrite=args.overwrite_cache)

    # Results are read from the experiment folder; an ablation run saves into
    # its own sub-folder and overrides one entry of the point cloud training config.
    read_path = os.path.join(cfg.basedir, cfg.expname)
    if args.ablation_tag is None:
        save_path = read_path
    else:
        save_path = os.path.join(cfg.basedir, cfg.expname, 'ablation', args.ablation_tag)

        valid_embeddings = ['full', 'raw_embed', 'raw', 'rot']
        if args.ablation_tag in valid_embeddings:
            # The tag names an embedding variant.
            cfg['pcd_train_config']['embedding'] = args.ablation_tag
            print(f'Embedding: {args.ablation_tag}')
        elif args.ablation_tag not in cfg['pcd_train_config'].keys():
            raise ValueError(f"ablation_tag {args.ablation_tag} not found in config")
        else:
            # The tag names a config entry, which is switched off by zeroing it.
            cfg['pcd_train_config'][args.ablation_tag] = 0.
            print(f'Ablation tag: {args.ablation_tag} set to 0.0')

    # train
    if not args.render_only:
        if args.first_stage_only:
            stages = [1]
        else:
            stages = [2] if args.second_stage_only else [1,2]
        train(args, cfg, read_path, save_path, data_dict=data_dict, stages=stages)

    # load model for rendering
    if args.render_test or args.render_train or args.render_video or args.joint_placement or args.benchmark or args.repose_pcd or args.test or args.visualise_canonical:
        # Every mode listed above needs a loaded model and the shared rendering
        # keyword arguments assembled below.

        # NOTE(review): the suffix is appended after read_path / save_path were
        # derived, so only paths built from cfg.basedir from here on include it.
        cfg.basedir += args.basedir_append_suffix

        if not args.render_pcd:
            ckpt_path = os.path.join(cfg.basedir, cfg.expname, 'fine_last.tar')
            model_class = tineuvox.TiNeuVox
        else:
            ckpt_path = os.path.join(read_path, 'temporalpoints_last.tar')
            model_class = temporalpoints.TemporalPoints

        model = utils.load_model(model_class, ckpt_path).to(device)
        # Checkpoint file name without directory and extension; used to label
        # output folders. os.path handles the platform's separator and does not
        # rely on the extension being exactly four characters long.
        ckpt_name = os.path.splitext(os.path.basename(ckpt_path))[0]
        near = data_dict['near']
        far = data_dict['far']
        stepsize = cfg.model_and_render.stepsize
        render_viewpoints_kwargs = {
            'model': model,
            'ndc': cfg.data.ndc,
            'inverse_y': cfg.data.inverse_y,
            'flip_x': cfg.data.flip_x,
            'flip_y': cfg.data.flip_y,
            'render_kwargs': {
                'near': near,
                'far': far,
                'bg': 1 if cfg.data.white_bkgd else 0,
                'stepsize': stepsize,
                'render_depth': True,
            },
        }

        if args.degree_threshold > 0:
            # Simplify the skeleton over 300 evenly spaced times in [0, 1] and
            # keep the last value returned by simplify_skeleton as `res`.
            times = torch.linspace(0., 1., 300).unsqueeze(-1)
            out = model.simplify_skeleton(times, deg_threshold=args.degree_threshold, five_percent_heuristic=True, visualise_canonical=False)
            res = out[-1]
    
    if args.test:
        import progressive_eval

        # Evaluation results and benchmark weights both live in the experiment
        # folder. These are plain local paths; the checkpoint path variable is
        # deliberately left untouched.
        save_dir = os.path.join(cfg.basedir, cfg.expname, 'test_results')
        ckpt_dir = os.path.join(cfg.basedir, cfg.expname, 'benchmark_weights')
        progressive_eval.test(data_dict, render_viewpoints_kwargs, save_dir, ckpt_dir, mode="test")

    # render trainset and eval
    if args.render_train:
        # Render every training view, evaluate the requested metrics and
        # assemble the frames into videos.
        testsavedir = os.path.join(save_path, f'render_train_{ckpt_name}')
        os.makedirs(testsavedir, exist_ok = True)

        # NOTE(review): the other render_viewpoints call sites in this script
        # unpack four return values. Star-unpacking accepts three or more, so
        # this branch does not fail on a trailing extra value.
        rgbs, disps, weights, *_ = render_viewpoints(
                render_poses=data_dict['poses'][data_dict['img_to_cam'][data_dict['i_train']]],
                HW=data_dict['HW'][data_dict['i_train']],
                Ks=data_dict['Ks'][data_dict['img_to_cam'][data_dict['i_train']]],
                gt_imgs=[data_dict['images'][i].cpu().numpy() for i in data_dict['i_train']],
                savedir=testsavedir,
                test_times=data_dict['times'][data_dict['i_train']],
                eval_psnr=args.eval_psnr, eval_ssim=args.eval_ssim, eval_lpips_alex=args.eval_lpips_alex, eval_lpips_vgg=args.eval_lpips_vgg,
                **render_viewpoints_kwargs)

        imageio.mimwrite(os.path.join(testsavedir, 'train_video.rgb.mp4'), utils.to8b(rgbs), fps = 30, quality = 8)
        # Disparities are normalised by their maximum before conversion to 8 bit.
        imageio.mimwrite(os.path.join(testsavedir, 'train_video.disp.mp4'), utils.to8b(disps / np.max(disps)), fps = 30, quality = 8)
        if len(weights) > 0:
            imageio.mimwrite(os.path.join(testsavedir, 'video.weights.mp4'), utils.to8b(weights), fps=30, quality=8)

    # render testset and eval
    if args.render_test:
        testsavedir = os.path.join(save_path, f'render_test_{ckpt_name}')
        os.makedirs(testsavedir, exist_ok=True)

        # Collect the per-view inputs of the test split.
        # NOTE(review): poses are indexed by image id here, whereas Ks (and the
        # training branch) go through img_to_cam - confirm this is intended.
        test_ids = data_dict['i_test']
        test_inputs = dict(
            render_poses=data_dict['poses'][test_ids],
            HW=data_dict['HW'][test_ids],
            Ks=data_dict['Ks'][data_dict['img_to_cam'][test_ids]],
            gt_imgs=[data_dict['images'][i].cpu().numpy() for i in test_ids],
            savedir=testsavedir,
            test_times=data_dict['times'][test_ids],
            eval_psnr=args.eval_psnr,
            eval_ssim=args.eval_ssim,
            eval_lpips_alex=args.eval_lpips_alex,
            eval_lpips_vgg=args.eval_lpips_vgg,
        )
        rgbs, disps, _, _ = render_viewpoints(**test_inputs, **render_viewpoints_kwargs)

        # Colour and max-normalised disparity videos.
        rgb_video = os.path.join(testsavedir, 'test_video.rgb.mp4')
        imageio.mimwrite(rgb_video, utils.to8b(rgbs), fps=30, quality=8)
        disp_video = os.path.join(testsavedir, 'test_video.disp.mp4')
        imageio.mimwrite(disp_video, utils.to8b(disps / np.max(disps)), fps=30, quality=8)
        

    # render video
    if args.render_video or args.benchmark:
        # This branch has no implementation for the hyper dataset.
        if cfg.data.dataset_type == 'hyper_dataset':
            raise NotImplementedError

        # An output folder is only created (and passed on) for --render_video;
        # a pure benchmark run passes savedir=None.
        testsavedir = None
        if args.render_video:
            testsavedir = os.path.join(save_path, f'render_video_{ckpt_name}_time')
            os.makedirs(testsavedir, exist_ok=True)

        # The first image's size and intrinsics are repeated for every render pose.
        n_frames = len(data_dict['render_poses'])
        rgbs, disps, weights, flows = render_viewpoints(
                render_poses=data_dict['render_poses'],
                HW=data_dict['HW'][0][None,...].repeat(n_frames, 0),
                Ks=data_dict['Ks'][0][None,...].repeat(n_frames, 0),
                render_factor=args.render_video_factor,
                savedir=testsavedir,
                test_times=data_dict['render_times'],
                benchmark=args.benchmark,
                render_pcd_direct=args.render_pcd_direct,
                kinematic_warp=cfg['pcd_train_config']['kinematic_warp'],
                **render_viewpoints_kwargs)

        if args.render_video:
            def _write_video(file_name, frames):
                # All videos share the same frame rate and quality setting.
                imageio.mimwrite(os.path.join(testsavedir, file_name), utils.to8b(frames), fps=30, quality=8)

            _write_video('video.rgb.mp4', rgbs)
            _write_video('video.disp.mp4', disps / np.max(disps))
            if len(flows) > 0:
                _write_video('video.flows.mp4', flows / np.max(flows))
            if len(weights) > 0:
                _write_video('video.weights.mp4', weights)
        
    def save_pcd(points, colors, pcd_path):
        '''Write ``points`` with per-point ``colors`` to ``pcd_path`` using Open3D.'''
        # Imported locally so the helper works no matter which command line
        # branch (if any) has imported open3d before it is called.
        import open3d as o3d

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd.colors = o3d.utility.Vector3dVector(colors)
        o3d.io.write_point_cloud(pcd_path, pcd)

    if args.repose_pcd:
        # Render the model while its joints are driven from the rest pose
        # towards a hand-written target pose.
        import open3d as o3d

        model = render_viewpoints_kwargs['model']

        bones = model.bones
        joints = model.joints.detach().cpu().numpy()
        weights = model.get_weights().detach().cpu().numpy()

        # `res` is only assigned when the skeleton was simplified, i.e. when
        # --degree_threshold is positive; without the guard this debug print
        # raised a NameError.
        if args.degree_threshold > 0:
            print(dict(zip(range(len(res)), res.cpu().numpy())))
        start_scale = 0
        end_scale = 1
        step = len(data_dict['render_poses'])

        # The random initialisation is replaced by the preset right below; it
        # is kept so the random number stream stays as before.
        target_params = torch.randn((len(joints), 4))
        target_params /= torch.norm(target_params, dim=-1, keepdim=True)

        # NOTE: the presets address joints by hard-coded index and therefore
        # only fit a skeleton obtained with the stated degree threshold.

        ### Jumping jacks, threshold 20
        target_params = torch.zeros_like(target_params)
        # arm 1
        target_params[12] = torch.tensor([0., -1., 0.3, 2.])
        target_params[17] = torch.tensor([1., 1., 0., 1])
        # arm 2
        target_params[13] = torch.tensor([0., 1., 0., 2])
        # legs
        target_params[10] = torch.tensor([0., -1., 0., -1.])
        target_params[9] = torch.tensor([0., -1., 0., -0.2])

        # legs
        target_params[19] = torch.tensor([1., 0., 0., 2])
        target_params[20] = torch.tensor([1., 0., 0., 2])

        ### Spot, threshold 30 (alternative preset)
        # target_params = torch.zeros_like(target_params)
        # target_params[5] = torch.tensor([0., 1., 0, 2.])
        # target_params[17] = torch.tensor([0., -1., 0, 1])

        # target_params[6] = torch.tensor([0., 1., 0, 2.])
        # target_params[18] = torch.tensor([0., -1., 0, 1])

        # target_params[11] = torch.tensor([0., -1., 0, 2.])
        # target_params[26] = torch.tensor([0., -1., 0, 1])

        # target_params[12] = torch.tensor([0., -1., 0, 2.])
        # target_params[19] = torch.tensor([0., -1., 0, 1])

        # The first joint always keeps zero parameters.
        target_params[0] = 0.

        # One parameter set per rendered frame, scaled linearly from
        # start_scale to end_scale.
        target_params = target_params[None] * torch.linspace(start_scale, end_scale, step)[:,None,None]

        testsavedir = os.path.join(cfg.basedir, cfg.expname, f'render_video_repose_{args.seed}')
        os.makedirs(testsavedir, exist_ok=True)

        # Every frame uses the first render pose, image size and intrinsics.
        rgbs, disps, weights = render_repose(
                render_poses=data_dict['render_poses'][0][None,...].repeat(step, 1, 1),
                HW=data_dict['HW'][0][None,...].repeat(step, 0),
                Ks=data_dict['Ks'][0][None,...].repeat(step, 0),
                render_factor=args.render_video_factor,
                savedir=testsavedir,
                rot_params=target_params,
                eval_psnr=args.eval_psnr, eval_ssim=args.eval_ssim, eval_lpips_alex=args.eval_lpips_alex, eval_lpips_vgg=args.eval_lpips_vgg,
                **render_viewpoints_kwargs)

        imageio.mimwrite(os.path.join(testsavedir, 'train_video.rgb.mp4'), utils.to8b(rgbs), fps = 30, quality = 8)
        imageio.mimwrite(os.path.join(testsavedir, 'train_video.disp.mp4'), utils.to8b(disps / np.max(disps)), fps = 30, quality = 8)
        if len(weights) > 0:
            imageio.mimwrite(os.path.join(testsavedir, 'video.weights.mp4'), utils.to8b(weights), fps=30, quality=8)

    if args.visualise_warp:
        # Warp the canonical point cloud to every exported time step and save
        # the result coloured by the skinning weights.
        import open3d as o3d

        warped_savedir = os.path.join(cfg.basedir, cfg.expname, 'warped_pcds')
        pcd_path = os.path.join(cfg.basedir, cfg.expname, 'pcds')
        os.makedirs(warped_savedir, exist_ok=True)

        ckpt_path = os.path.join(cfg.basedir, cfg.expname, 'temporalpoints_last.tar')
        ckpt_name = ckpt_path.split('/')[-1][:-4]

        model_class = temporalpoints.TemporalPoints
        model = utils.load_model(model_class, ckpt_path).to(device)

        # Simplify the skeleton over 300 evenly spaced times in [0, 1].
        times = torch.linspace(0., 1., 300).unsqueeze(-1)
        _ = model.simplify_skeleton(times, deg_threshold=15, five_percent_heuristic=False)

        canonical_points = model.canonical_pcd
        weights = model.get_weights()
        points = canonical_points.cpu().numpy()

        # One random colour per weight column, blended by each point's weights.
        cols = torch.randn((weights.shape[-1], 3))
        cols_weights = (weights.unsqueeze(-1) * cols).sum(dim=1).detach().cpu().numpy()

        # Keep the per-time archives only: drop names containing '0.0.tar',
        # Open3D '.pcd' files and skeleton files.
        excluded_tags = ('0.0.tar', '.pcd', 'skeleton')
        pcds = [
            name for name in os.listdir(pcd_path)
            if not any(tag in name for tag in excluded_tags)
        ]

        with torch.no_grad():
            for pcd in pcds:
                # The time value is encoded in the file name.
                t = torch.tensor([float(pcd.replace('.tar', ''))])
                warped_points = model(t)['t_hat_pcd'].cpu().numpy()
                save_pcd(warped_points, cols_weights, os.path.join(warped_savedir, f'{t.item()}.pcd'))

    if args.visualise_canonical:
        # Show the skeleton and canonical point cloud of the loaded model.
        from skeletonizer import visualise_skeletonizer

        save_path = os.path.join(cfg.basedir, cfg.expname, 'canonical.pcd')
        threshold = args.degree_threshold

        def _as_numpy(tensor):
            # Detach from the graph and move to host memory.
            return tensor.detach().cpu().numpy()

        skeleton_points = _as_numpy(model.skeleton_pcd)
        weights = _as_numpy(model.get_weights())
        root = _as_numpy(model.joints[0])
        joints = _as_numpy(model.joints)
        bones = model.bones
        canonical_pcd = _as_numpy(model.canonical_pcd)

        visualise_skeletonizer(
            skeleton_points, root, joints, bones, canonical_pcd, weights,
            save=False, save_path=save_path)

    if args.export_bbox_and_cams_only:
        # Dump the scene bounding box and a small frustum outline per camera
        # to a compressed npz file, then terminate the script.
        import sys
        print('Export bbox and cameras...')
        xyz_min, xyz_max = compute_bbox_by_cam_frustrm(args=args, cfg=cfg, **data_dict)
        poses, HW, Ks = data_dict['poses'], data_dict['HW'], data_dict['Ks']
        near, far = data_dict['near'], data_dict['far']

        def _camera_outline(c2w, H, W, K):
            # Camera origin followed by the four image-corner ray directions,
            # each scaled by max(near, 5% of far) and offset by the origin.
            rays_o, rays_d, _ = tineuvox.get_rays_of_a_view(
                    H, W, K, c2w, cfg.data.ndc, flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y, inverse_y=cfg.data.inverse_y)
            origin = rays_o[0,0].cpu().numpy()
            corner_dirs = rays_d[[0,0,-1,-1],[0,-1,0,-1]].cpu().numpy()
            return np.array([origin, *(origin+corner_dirs*max(near, far*0.05))])

        cam_lst = [
            _camera_outline(c2w, H, W, K)
            for c2w, (H, W), K in zip(poses, HW[torch.arange(len(poses)).cpu()], Ks)
        ]
        np.savez_compressed(args.export_bbox_and_cams_only,
            xyz_min=xyz_min.cpu().numpy(), xyz_max=xyz_max.cpu().numpy(),
            cam_lst=np.array(cam_lst))
        print('done')
        sys.exit()

    if args.debug_bone_merging:
        # Repeatedly simplify the skeleton and dump its state before the first
        # and after every simplification pass.
        from skeletonizer import visualise_skeletonizer
        from lib.treeprune import visualise_merging

        ckpt_path = os.path.join(cfg.basedir, cfg.expname, 'temporalpoints_last.tar')
        ckpt_name = ckpt_path.split('/')[-1][:-4]

        model_class = temporalpoints.TemporalPoints
        model = utils.load_model(model_class, ckpt_path).to(device)

        def _dump_current_skeleton():
            # Snapshot the model's current skeleton state as numpy arrays and
            # hand it to the visualiser, which saves to ./bone_merge.pcd.
            skeleton_points = model.skeleton_pcd.detach().cpu().numpy()
            weights = model.get_weights().detach().cpu().numpy()
            root = model.joints[0].detach().cpu().numpy()
            joints = model.joints.detach().cpu().numpy()
            bones = model.bones
            canonical_pcd = model.canonical_pcd.detach().cpu().numpy()
            visualise_skeletonizer(
                skeleton_points, root, joints, bones, canonical_pcd, weights,
                save=True, save_path='./bone_merge.pcd')

        # Show initial
        times = torch.linspace(0., 1., 300).unsqueeze(-1)
        _dump_current_skeleton()

        for iteration in range(1, 11):
            joints, bones, new_joints, new_bones, prune, merging_rules, rotations_to_keep = model.simplify_skeleton(
                times, deg_threshold=15, five_percent_heuristic=False)
            print(f'Iteration {iteration}')
            _dump_current_skeleton()

