import os
import cv2
import glob
import json
from cv2 import transform
import tqdm
import numpy as np
import random
import gzip
from scipy.spatial.transform import Slerp, Rotation
import os.path as osp

import trimesh

import torch
from torch.utils.data import DataLoader,Dataset

from .utils import get_rays

# Lookup table from a human-readable object category name to an 8-digit id string.
# NOTE(review): the ids look like ShapeNet synset ids -- confirm against the data layout.
class_dict = {
    'car': '02958343',
    'chair': '03001627'
}


# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]):
    """Re-express a 4x4 NeRF pose matrix in the instant-ngp axis convention.

    The first three rows of ``pose`` are taken in the order (1, 2, 0); within
    each row the second and third entries are negated, and the translation
    entry is multiplied by ``scale`` and shifted by the matching ``offset``
    component.

    Args:
        pose: indexable [4, 4] matrix (only rows 0..2 are read).
        scale: multiplier applied to the translation column.
        offset: three values added to the scaled translation, indexed in
            output-row order.
    Return:
        new_pose: [4, 4] float32 numpy array whose last row is (0, 0, 0, 1).
    """
    # for the fox dataset, 0.33 scales camera radius to ~ 2
    source_rows = (1, 2, 0)  # output row k is built from input row source_rows[k]
    rows = [
        [pose[src, 0], -pose[src, 1], -pose[src, 2], pose[src, 3] * scale + offset[dst]]
        for dst, src in enumerate(source_rows)
    ]
    rows.append([0, 0, 0, 1])
    return np.array(rows, dtype=np.float32)



def visualize_poses(poses, size=0.1):
    """Show the given poses as wireframe cameras in an interactive trimesh scene.

    The scene also contains a coordinate-axis marker and the outline of a
    2 x 2 x 2 box. Every pose contributes nine line segments: four from the
    camera position to the frustum corners, four joining the corners, and one
    of length 3 along the viewing direction.

    Args:
        poses: [B, 4, 4] pose matrices.
        size: scale applied to each pose axis when placing the frustum corners.
    """
    world_axes = trimesh.creation.axis(axis_length=4)
    bbox_outline = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
    bbox_outline.colors = np.array([[512, 512, 512]] * len(bbox_outline.entities))
    scene_items = [world_axes, bbox_outline]

    # signs of the first two pose axes for corners a, b, c, d (in that order)
    corner_signs = ((1, 1), (-1, 1), (-1, -1), (1, -1))

    for pose in poses:
        origin = pose[:3, 3]
        step_x = size * pose[:3, 0]
        step_y = size * pose[:3, 1]
        step_z = size * pose[:3, 2]

        a, b, c, d = [origin + sx * step_x + sy * step_y + step_z for sx, sy in corner_signs]

        # unit vector from the camera position towards the centre of the corner quad
        view_dir = (a + b + c + d) / 4 - origin
        view_dir = view_dir / (np.linalg.norm(view_dir) + 1e-8)
        tip = origin + view_dir * 3

        segments = np.array([
            [origin, a], [origin, b], [origin, c], [origin, d],
            [a, b], [b, c], [c, d], [d, a],
            [origin, tip],
        ])
        scene_items.append(trimesh.load_path(segments))

    trimesh.Scene(scene_items).show()


def rand_poses(size, device, radius=1, theta_range=[np.pi/3, 2*np.pi/3], phi_range=[0, 2*np.pi]):
    ''' generate random poses from an orbit camera

    Camera centres are placed on a sphere of the given radius around the
    origin, with theta and phi each drawn uniformly from their range, and
    every camera is oriented so that its forward axis points at the origin.

    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        radius: camera radius
        theta_range: [min, max], should be in [0, pi]
        phi_range: [min, max], should be in [0, 2 * pi]
    Return:
        poses: [size, 4, 4] float tensor. The columns of the upper-left 3x3
            block are (right, up, forward); the last column holds the camera
            centre.
    '''

    def normalize(vectors):
        # the epsilon keeps a zero-length vector from producing NaN
        return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)

    # thetas are drawn before phis; keep this order so seeded runs stay reproducible
    thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
    phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]

    sin_thetas = torch.sin(thetas)
    centers = torch.stack([
        radius * sin_thetas * torch.sin(phis),
        radius * torch.cos(thetas),
        radius * sin_thetas * torch.cos(phis),
    ], dim=-1) # [B, 3]

    # lookat: build an orthonormal (right, up, forward) frame facing the origin
    forward_vector = - normalize(centers)
    # world up is -y here; the tensor is created directly on the target device
    world_up = torch.tensor([0.0, -1.0, 0.0], dtype=torch.float, device=device).expand(size, 3)
    right_vector = normalize(torch.cross(forward_vector, world_up, dim=-1))
    up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    return poses


class NeRFDataset:
    """Single-scene dataset built from `transforms*.json` files under `opt.path`.

    The constructor loads every pose (and, except for the synthesized colmap
    test path, every image) plus the camera intrinsics. `collate` turns a list
    of frame indices into a batch of rays, and `dataloader` wraps that in a
    torch `DataLoader`.
    """

    def __init__(self, opt, device, type='train', downscale=1, n_test=10):
        """Load poses, images and intrinsics for one split.

        Args:
            opt: options object. Attributes read here: path, preload, scale,
                offset, bound, fp16, num_rays, rand_pose, error_map and
                color_space (patch_size is read later, in `collate`).
            device: device used for preloaded tensors and for generated rays.
            type: split name. 'train', 'all' and 'trainval' count as training;
                'val' and 'test' get special handling in colmap mode.
            downscale: divisor applied to the image size and to the intrinsics.
            n_test: number of interpolation intervals for the synthesized colmap
                test path (n_test + 1 poses are generated).
        Raises:
            NotImplementedError: neither transforms.json nor
                transforms_train.json exists under `opt.path`.
            RuntimeError: the transform file has neither a focal length nor a
                camera angle entry.
        """
        super().__init__()
        
        self.opt = opt
        self.device = device
        self.type = type # train, val, test
        self.downscale = downscale
        self.root_path = opt.path
        self.preload = opt.preload # preload data into GPU
        self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box.
        self.offset = opt.offset # camera offset
        self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16 # if preload, load into fp16.

        self.training = self.type in ['train', 'all', 'trainval']
        # -1 is passed straight to get_rays outside training
        # NOTE(review): presumably -1 means "use every pixel" -- confirm in get_rays.
        self.num_rays = self.opt.num_rays if self.training else -1

        self.rand_pose = opt.rand_pose

        # auto-detect transforms.json and split mode.
        if os.path.exists(os.path.join(self.root_path, 'transforms.json')):
            self.mode = 'colmap' # manually split, use view-interpolation for test.
        elif os.path.exists(os.path.join(self.root_path, 'transforms_train.json')):
            self.mode = 'blender' # provided split
        else:
            raise NotImplementedError(f'[NeRFDataset] Cannot find transforms*.json under {self.root_path}')

        # load nerf-compatible format data.
        if self.mode == 'colmap':
            with open(os.path.join(self.root_path, 'transforms.json'), 'r') as f:
                transform = json.load(f)
        elif self.mode == 'blender':
            # load all splits (train/valid/test), this is what instant-ngp in fact does...
            if type == 'all':
                # every *.json in the root is merged; the top-level keys (size, intrinsics)
                # come from whichever file glob returns first.
                transform_paths = glob.glob(os.path.join(self.root_path, '*.json'))
                transform = None
                for transform_path in transform_paths:
                    with open(transform_path, 'r') as f:
                        tmp_transform = json.load(f)
                        if transform is None:
                            transform = tmp_transform
                        else:
                            transform['frames'].extend(tmp_transform['frames'])
            # load train and val split
            elif type == 'trainval':
                with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f:
                    transform = json.load(f)
                with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f:
                    transform_val = json.load(f)
                transform['frames'].extend(transform_val['frames'])
            # only load one specified split
            else:
                with open(os.path.join(self.root_path, f'transforms_{type}.json'), 'r') as f:
                    transform = json.load(f)

        else:
            raise NotImplementedError(f'unknown dataset mode: {self.mode}')

        # load image size
        if 'h' in transform and 'w' in transform:
            self.H = int(transform['h']) // downscale
            self.W = int(transform['w']) // downscale
        else:
            # no size in the transform file: fall back to a fixed 512 x 512.
            # NOTE(review): because H/W are never None, the "read size from the first
            # image" branch in the loading loop below cannot trigger -- confirm intent.
            self.H = self.W = 512
        
        # read images
        frames = transform["frames"]
        #frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort...
        
        # for colmap, manually interpolate a test set.
        if self.mode == 'colmap' and type == 'test':
            
            # choose two random poses, and interpolate between.
            f0, f1 = np.random.choice(frames, 2, replace=False)
            pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
            pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
            rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]]))
            slerp = Slerp([0, 1], rots)

            self.poses = []
            self.images = None # no ground-truth images exist for interpolated poses
            for i in range(n_test + 1):
                # sine easing: ratio goes from 0 to 1, moving slowly near both end poses
                ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5
                pose = np.eye(4, dtype=np.float32)
                pose[:3, :3] = slerp(ratio).as_matrix() # rotation: spherical interpolation
                pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3] # translation: linear interpolation
                self.poses.append(pose)

        else:
            # for colmap, manually split a valid set (the first frame).
            if self.mode == 'colmap':
                if type == 'train':
                    frames = frames[1:]
                elif type == 'val':
                    frames = frames[:1]
                # else 'all' or 'trainval' : use all frames
            
            self.poses = []
            self.images = []
            for f in tqdm.tqdm(frames, desc=f'Loading {type} data'):
                # only the basename of file_path is kept: images are expected directly under root_path
                f_path = os.path.join(self.root_path, os.path.basename(f['file_path']))
                if self.mode == 'colmap' and '.' not in os.path.basename(f_path):
                    f_path += '.png' # so silly...

                # there are non-exist paths in fox...
                if not os.path.exists(f_path):
                    continue
                
                pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4]
                pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)

                image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] or [H, W, 4]
                if self.H is None or self.W is None: 
                    self.H = image.shape[0] // downscale
                    self.W = image.shape[1] // downscale

                # add support for the alpha channel as a mask.
                if image.shape[-1] == 3: 
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                else:
                    image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)

                if image.shape[0] != self.H or image.shape[1] != self.W:
                    image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA)
                    
                image = image.astype(np.float32) / 255 # [H, W, 3/4]

                self.poses.append(pose)
                self.images.append(image)
            
        self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4]
        if self.images is not None:
            self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C]
        
        # calculate mean radius of all camera poses
        self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()
        #print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}')

        # initialize error_map
        if self.training and self.opt.error_map:
            self.error_map = torch.ones([self.images.shape[0], 512 * 512], dtype=torch.float) # [B, 512 * 512], flattened for easy indexing, fixed resolution...
        else:
            self.error_map = None

        # [debug] uncomment to view all training poses.
        # visualize_poses(self.poses.numpy())

        # [debug] uncomment to view examples of randomly generated poses.
        # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy())

        if self.preload:
            self.poses = self.poses.to(self.device)
            if self.images is not None:
                # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ?
                if self.fp16 and self.opt.color_space != 'linear':
                    dtype = torch.half
                else:
                    dtype = torch.float
                self.images = self.images.to(dtype).to(self.device)
            if self.error_map is not None:
                self.error_map = self.error_map.to(self.device)

        # load intrinsics
        if 'fl_x' in transform or 'fl_y' in transform:
            # when only one focal length is given it is used for both axes
            fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale
            fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale
        elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
            # blender, assert in radians. already downscaled since we use H/W
            fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
            fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
            if fl_x is None: fl_x = fl_y
            if fl_y is None: fl_y = fl_x
        else:
            raise RuntimeError('Failed to load focal length, please check the transforms.json!')

        # principal point defaults to the image centre
        cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2)
        cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)
    
        self.intrinsics = np.array([fl_x, fl_y, cx, cy])


    def collate(self, index):
        """Build one batch of rays (plus ground-truth pixels when available).

        Args:
            index: list of frame indices produced by the DataLoader. Indices at
                or beyond `len(self.poses)` request a random pose.
        Return:
            dict with 'H', 'W', 'rays_o', 'rays_d'; also 'images' when
            ground-truth images are loaded, and 'index' / 'inds_coarse' when an
            error map is in use. The random-pose branch returns only the first
            four keys.
        """

        B = len(index) # a list of length 1

        # random pose without gt images.
        # taken for every batch when rand_pose == 0, otherwise only for the extra indices
        # that `dataloader` appends after the real frames.
        if self.rand_pose == 0 or index[0] >= len(self.poses):

            poses = rand_poses(B, self.device, radius=self.radius)

            # sample a low-resolution but full image for CLIP
            # s shrinks the image so that rH * rW is roughly num_rays
            s = np.sqrt(self.H * self.W / self.num_rays) # only in training, assert num_rays > 0
            rH, rW = int(self.H / s), int(self.W / s)
            rays = get_rays(poses, self.intrinsics / s, rH, rW, -1)

            return {
                'H': rH,
                'W': rW,
                'rays_o': rays['rays_o'],
                'rays_d': rays['rays_d'],    
            }

        poses = self.poses[index].to(self.device) # [B, 4, 4]

        error_map = None if self.error_map is None else self.error_map[index]
        
        rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, error_map, self.opt.patch_size)

        results = {
            'H': self.H,
            'W': self.W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
        }

        if self.images is not None:
            images = self.images[index].to(self.device) # [B, H, W, 3/4]
            if self.training:
                # keep only the pixels at the flattened indices in rays['inds']
                C = images.shape[-1]
                images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
            results['images'] = images
        
        # need inds to update error_map
        if error_map is not None:
            results['index'] = index
            results['inds_coarse'] = rays['inds_coarse']
            
        return results

    def dataloader(self):
        """Return a batch-size-1 DataLoader over frame indices that uses `collate`.

        During training with `rand_pose > 0`, `len(poses) // rand_pose` extra
        indices are appended; `collate` maps those to random poses. The loader
        is shuffled only in training mode.
        """
        size = len(self.poses)
        if self.training and self.rand_pose > 0:
            size += size // self.rand_pose # index >= size means we use random pose.
        loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        loader._data = self # an ugly fix... we need to access error_map & poses in trainer.
        loader.has_gt = self.images is not None
        return loader

class MetaNeRFInversionDataset(Dataset):
    """Per-object dataset: each item loads one view of one rendered object.

    An item reads `metadata.json` from the object's directory, picks the view
    at `self.pose_index`, loads its render and segmentation images (skipped in
    test mode) and returns rays for that view.

    NOTE(review): `self.object_index` is read but never assigned inside this
    class, and `self.pose_index` is only assigned here in multiview or test
    mode -- presumably the caller sets them before indexing; confirm.
    """

    def __init__(self, objects, opt, device, type='train', downscale=1, n_test=10) -> None:
        """Store options and pick the data root; no images are loaded here.

        Args:
            objects: sequence of object directory names, indexed by
                `self.object_index` in `__getitem__`.
            opt: options object. Attributes read here: path, preload, scale,
                offset, bound, fp16, num_rays, clip_mapping and
                multiview_inversion (invert, patch_size, finetune and workspace
                are read by other methods).
            device: stored on the instance; not used by the visible methods.
            type: split name; 'train', 'all' and 'trainval' count as training,
                'test' enables test mode.
            downscale: divisor applied to image size and intrinsics.
            n_test: number of interpolation intervals used by the blender test
                branch of `__getitem__`.
        """
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type # train, val, test
        self.downscale = downscale
        # prefer <path>/data_full when it exists, otherwise fall back to <path>/ABO_rendered
        if os.path.exists(os.path.join(opt.path, "data_full")):
            self.root_path = os.path.join(opt.path, "data_full")
            print(f'Data path exists, root path is {self.root_path}')
            self.listing_path = os.path.join(opt.path, "abo_listings/listings/metadata")
        else:
            self.root_path = os.path.join(opt.path, "ABO_rendered")
            self.listing_path = os.path.join(opt.path, "ABO_listings/listings/metadata")
        self.n_test = n_test
        # self.class_choice = class_choice
        self.preload = opt.preload # preload data into GPU
        self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box.
        self.offset = opt.offset # camera offset
        self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16 # if preload, load into fp16.
        self.error_map = None
        self.training = self.type in ['train', 'all', 'trainval']
        self.num_rays = self.opt.num_rays if self.training else -1
        self.clip_mapping = opt.clip_mapping
        self.test = self.type in ['test']
        # if self.test:
        #     self.scale = 0.5
        self.multiview_inversion = opt.multiview_inversion # invert from multiple input poses

        # hard-coded list of view indices from which the multiview subset is drawn
        all_object_poses = [4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 18, 19, 21, 22, 23, 25, 26, 28, 29, 30, 31, 33, 35, 36, 37, 48, 49, 50, 52, 53, 54, 56, 57, 58, 59, 77, 82]
        num_multiview_poses = 3 # number of multiview poses to invert from
        self.multiview_poses = random.sample(all_object_poses, num_multiview_poses)
        # self.multiview_poses = random.shuffle(all_object_poses)[:num_multiview_poses]

        split_type = "test" # NOTE(review): assigned but never read in this class
        # mode is fixed to 'colmap', so the 'blender' branches in __getitem__ only run
        # if something outside this class changes self.mode.
        self.mode = 'colmap'

        self.objects = objects
    
    def __len__(self):
        """Return 91 outside training, 100 when `opt.invert` is set, else the number of objects."""
        if not self.training:
            # NOTE(review): presumably the number of views per object -- confirm.
            return 91
        if self.opt.invert:
            return 100

        return len(self.objects)

    def get_random_rays(self,images,poses,intrinsics,H,W,index=None):
        """Generate rays for the selected pose(s) and gather the matching pixels.

        Args:
            images: image tensor or None. Outside test mode it is cloned
                unconditionally, so it must not be None there.
            poses: pose tensor; `poses[index]` is what gets passed to `get_rays`.
            intrinsics: array of (fl_x, fl_y, cx, cy).
            H, W: image height and width.
            index: list of pose indices; defaults to [0].
        Return:
            dict with 'H', 'W', 'rays_o', 'rays_d', 'patch_size', 'poses' (the
            full input, not indexed), 'intrinsics' and 'num_rays'; plus 'images'
            when `images` is given and 'img_original' outside test mode.
        """
        if not self.test:
            # keep the full image: `images` is reduced to sampled pixels below
            images_original = images.clone()
        else:
            images_original = None
        if index is None:
            index = [0]
        B = len(index) # a list of length 1

        rays = get_rays(poses[index], intrinsics, H, W, self.num_rays, patch_size=self.opt.patch_size)
        results = {
            'H': H,
            'W': W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'patch_size': self.opt.patch_size,
        }

        if images is not None:
            images = images[index]
            if self.training:
                # keep only the pixels at the flattened indices in rays['inds']
                C = images.shape[-1]
                images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
            results['images'] = images
        
        results['poses'] = poses
        results['intrinsics'] = intrinsics
        if images_original is not None:
            image_clip = images_original[index].squeeze(0)
            # multiply the first three channels by the fourth (mask) channel
            image_clip = image_clip[..., :3] * image_clip[..., 3:]

            results['img_original'] = image_clip
        results['num_rays'] = self.num_rays

        return results
    
    '''
    in the __getitem__ function, index is the index of the object of a specific class
    self.index is the pose and index is the item
    '''
    def __getitem__(self, index):
        """Load one view of object `self.object_index` and return its rays.

        Args:
            index: item index. In the visible code it is only used in test
                mode, where it becomes the pose index.
        Return:
            (results, self.object_index), where `results` is the dict built by
            `get_random_rays`.
        """
        # else:
        data_path = os.path.join(self.root_path, self.objects[self.object_index])

        '''
        if multiview inversion is used -> sample random pose index
        '''
        if self.multiview_inversion:
            self.pose_index = self.multiview_poses[random.randint(0, len(self.multiview_poses)-1)]
            # print(f'Using multiview inversion, sampled pose index : {self.pose_index}') # Debug

        if self.test:
            # data_path = './'
            # in test mode the item index selects the pose (overrides the multiview choice)
            self.pose_index = index
        
        # load nerf-compatible format data.
        if self.mode == 'colmap':
            # if self.test:
            #     transforms_file = 'transforms.json'
            # else:
            transforms_file = 'metadata.json'
            with open(os.path.join(data_path, transforms_file), 'r') as f:
                transform = json.load(f)
        elif self.mode == 'blender':
            # load all splits (train/valid/test), this is what instant-ngp in fact does...
            if self.type == 'all':
                transform_paths = glob.glob(os.path.join(data_path, '*.json'))
                transform = None
                for transform_path in transform_paths:
                    with open(transform_path, 'r') as f:
                        tmp_transform = json.load(f)
                        if transform is None:
                            transform = tmp_transform
                        else:
                            transform['frames'].extend(tmp_transform['frames'])
            # load train and val split
            elif self.type == 'trainval':
                with open(os.path.join(data_path, f'transforms_train.json'), 'r') as f:
                    transform = json.load(f)
                with open(os.path.join(data_path, f'transforms_val.json'), 'r') as f:
                    transform_val = json.load(f)
                transform['frames'].extend(transform_val['frames'])
            # only load one specified split
            else:
                with open(os.path.join(data_path, f'transforms_{self.type}.json'), 'r') as f:
                    transform = json.load(f)

        else:
            raise NotImplementedError(f'unknown dataset mode: {self.mode}')

        # load image size
        if 'h' in transform and 'w' in transform:
            H = int(transform['h']) // self.downscale
            W = int(transform['w']) // self.downscale
        else:
            # no size in the metadata: fall back to a fixed 512 x 512.
            # H = W = int(512 * self.scale_factor)
            H = W = 512
        
        # read images
        # if self.test:
        #     frames = transform['frames']
        # else:
        frames = transform["views"]

        if self.mode == 'blender' and self.type == 'test':
            # choose two random poses, and interpolate between.
            f0, f1 = np.random.choice(frames, 2, replace=False)
            pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
            pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
            rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]]))
            slerp = Slerp([0, 1], rots)

            poses = []
            images = None
            for i in range(self.n_test + 1):
                # sine easing: ratio goes from 0 to 1, moving slowly near both end poses
                ratio = np.sin(((i / self.n_test) - 0.5) * np.pi) * 0.5 + 0.5
                pose = np.eye(4, dtype=np.float32)
                pose[:3, :3] = slerp(ratio).as_matrix()
                pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3]
                poses.append(pose)

        else:
            poses = []
            images = []
            if self.test:
                images = None # test mode returns rays only, no ground-truth image
            # _iter tracks the position in `frames`; only the frame at self.pose_index is loaded
            _iter = 0
            for f in frames:
                if _iter != self.pose_index:
                    _iter += 1
                    continue
                # the render may live in sub-folder 0, 1 or 2; try them in order
                f_path = os.path.join(data_path,'render','0',f'render_{_iter}.jpg')
                if not os.path.exists(f_path):
                    f_path = os.path.join(data_path,'render','1',f'render_{_iter}.jpg')
                if not os.path.exists(f_path):
                    f_path = os.path.join(data_path,'render','2',f'render_{_iter}.jpg')
                if not os.path.exists(f_path):
                    if not self.test:
                        # NOTE(review): terminates the whole process on a missing render
                        print('path not found')
                        exit()
                if self.opt.finetune:
                    # finetune replaces the render path with a hard-coded inference output path
                    f_path = os.path.join('fzzq9',str(self.object_index),f'ngp_ep1780_00{_iter}_inference.png')
                seg_path = os.path.join(data_path,'segmentation',f'segmentation_{_iter}.jpg')
                if self.mode == 'colmap' and '.' not in os.path.basename(f_path):
                        f_path += '.png' # so silly...
                # if self.test:
                #     pose = np.array(f['transform_matrix'], dtype=np.float32).reshape(4,4) # [4, 4]
                # else:
                pose = np.array(f['pose'], dtype=np.float32).reshape(4,4) # [4, 4]
                pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)

                if not self.test:
                    image_without_mask = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] or [H, W, 4]
                    mask = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED)
                    # append the segmentation mask as an extra last channel
                    # assumes the mask is single-channel [H, W] -- TODO confirm
                    mask = np.expand_dims(mask, axis=-1)
                    image = np.concatenate([image_without_mask.astype(np.float32),mask.astype(np.float32)],axis=-1)
                    # cv2.imwrite('test_table.png', image)
                    # save the input pose
                    cv2.imwrite(osp.join(self.opt.workspace, f'input_pose_{self.pose_index}.png'), image)
                    # H and W are always set above, so this branch never runs as written
                    if H is None or W is None:
                        H = image.shape[0] // self.downscale
                        W = image.shape[1] // self.downscale

                    # add support for the alpha channel as a mask.
                    if image.shape[-1] == 3: 
                        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                    else:
                        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)

                    if image.shape[0] != H or image.shape[1] != W:
                        image = cv2.resize(image, (W, H), interpolation=cv2.INTER_AREA)
                        
                    image = image.astype(np.float32) / 255.0 # [H, W, 3/4]
                    images.append(image)

                poses.append(pose)
                _iter += 1
            
        poses = torch.from_numpy(np.stack(poses, axis=0)) # [N, 4, 4]

        if images is not None:
            images = torch.from_numpy(np.stack(images, axis=0)) # [N, H, W, C]
    
            # sanity check: only prints a marker, loading continues regardless
            if images.isnan().any() or images.isinf().any():
                print("NAAAAAAAANAAAAAAN")

        # calculate mean radius of all camera poses
        # NOTE(review): computed but not used afterwards in this method
        radius = poses[:, :3, 3].norm(dim=-1).mean(0).item()

        # load intrinsics
        if 'fl_x' in transform or 'fl_y' in transform:
            # when only one focal length is given it is used for both axes
            fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / self.downscale
            fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / self.downscale
        elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
            # blender, assert in radians. already downscaled since we use H/W
            fl_x = W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
            fl_y = H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
            if fl_x is None: fl_x = fl_y
            if fl_y is None: fl_y = fl_x
        else:
            # no focal information in the metadata: use a hard-coded focal length.
            # NOTE(review): this value is not divided by self.downscale -- confirm intent.
            fl_x = fl_y = 443.40496826171875

        # principal point defaults to the image centre
        cx = (transform['cx'] / self.downscale) if 'cx' in transform else (W / 2)
        cy = (transform['cy'] / self.downscale) if 'cy' in transform else (H / 2)
    
        intrinsics = np.array([fl_x, fl_y, cx, cy])
        results = self.get_random_rays(images,poses,intrinsics,H,W)
        return results, self.object_index

class MetaNeRFDataset(Dataset):
    """Single-view ray dataset over rendered ABO objects of one product type.

    Every item is one object folder below ``root_path``. ``__getitem__`` draws
    a random rendered view of that object and returns the dictionary built by
    ``get_random_rays`` together with the object index.

    Args:
        opt: options namespace. This class reads ``path``, ``preload``,
            ``scale``, ``offset``, ``bound``, ``fp16``, ``num_rays``,
            ``clip_mapping``, ``rand_pose`` and ``patch_size``.
        device: stored on the instance; not used further by this class.
        type: split label. 'train', 'all' and 'trainval' enable random ray
            subsampling; anything else requests ``num_rays = -1``.
        downscale: integer divisor applied to the image size and intrinsics.
        global_pose_index: stored on the instance; not used further here.
        n_test: number of interpolation steps in the blender/test branch.
        class_choice: product type to keep, compared case-insensitively with
            the first ``product_type`` value of the object's listing record.

    Raises:
        FileNotFoundError: if the split file or a listing shard is missing.
    """

    # Suffixes of the listings_<x>.json.gz metadata shards, in load order.
    _LISTING_SHARDS = "0123456789abcdef"
    # Entries of the render root that are bookkeeping files, not objects.
    _NON_OBJECT_ENTRIES = (
        'README.md',
        'test_sample_idx.json',
        'train_sample_idx.json',
        'train_test_split.csv',
    )
    # View indices are drawn from range(_NUM_VIEWS).
    _NUM_VIEWS = 91
    # Sub-folders of <object>/render searched, in order, for a view's image.
    _RENDER_SUBDIRS = ('0', '1', '2')

    def __init__(self, opt, device, type='train', downscale=1, global_pose_index=0, n_test=10, class_choice = 'chair') -> None:
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type # train, val, test
        self.downscale = downscale
        if os.path.exists(os.path.join(opt.path, "data")):
            self.root_path = os.path.join(opt.path, "data")
            self.listing_path = os.path.join(opt.path, "abo_listings/listings/metadata")
        else:
            self.root_path = os.path.join(opt.path, "ABO_rendered")
            self.listing_path = os.path.join(opt.path, "ABO_listings/listings/metadata")
        self.n_test = n_test
        self.class_choice = class_choice
        self.preload = opt.preload # preload data into GPU
        self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box.
        self.offset = opt.offset # camera offset
        self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16 # if preload, load into fp16.
        self.error_map = None
        self.training = self.type in ['train', 'all', 'trainval']
        self.num_rays = self.opt.num_rays if self.training else -1
        self.clip_mapping = opt.clip_mapping

        self.global_pose_index = global_pose_index

        # Only rows containing "TRAIN" are used, whatever `type` is; the first
        # CSV field of each row is the object id.
        with open(os.path.join(self.root_path, "train_test_split.csv")) as f:
            split = {row.split(",")[0] for row in f.read().splitlines() if "TRAIN" in row}

        self.rand_pose = opt.rand_pose
        self.mode = 'colmap'

        # Build a fresh list instead of calling remove() while iterating:
        # in-place removal skips the entry that follows each removed one.
        self.objects = sorted(
            name for name in os.listdir(self.root_path)
            if name not in self._NON_OBJECT_ENTRIES
            and os.path.exists(os.path.join(self.root_path, name, 'metadata.json'))
        )

        self.metadata = self._read_listings(self.listing_path)

        # Map item_id -> first record carrying it. Records without an item_id
        # are reported and skipped, so they can no longer shift the lookup.
        listing_by_id = {}
        n_with_id = 0
        for record in self.metadata:
            try:
                item_id = record['item_id']
            except (KeyError, TypeError):
                print(record)
                continue
            n_with_id += 1
            listing_by_id.setdefault(item_id, record)

        print(f"Total objects in ABO Dataset: {n_with_id}")
        print(f"Total rendered objects {len(self.objects)}")

        wanted = self.class_choice.lower()
        filtered_objects = []
        for val in self.objects:
            record = listing_by_id.get(val)
            if record is None or val not in split:
                continue
            if record['product_type'][0]['value'].lower() == wanted:
                filtered_objects.append(val)

        self.objects = filtered_objects
        print(f"Total rendered Chairs: {len(filtered_objects)}")

    @classmethod
    def _read_listings(cls, listing_path):
        """Return every JSON-lines record of the listing shards, in shard order."""
        records = []
        for shard in cls._LISTING_SHARDS:
            shard_path = os.path.join(listing_path, f"listings_{shard}.json.gz")
            with gzip.open(shard_path, mode="r") as f:
                records.extend(json.loads(line) for line in f)
        return records

    def __len__(self):
        """Number of objects that survived the class / split filtering."""
        return len(self.objects)

    def get_random_rays(self,images,poses,intrinsics,H,W,index=None):
        """Build the per-item result dictionary for the selected pose(s).

        Args:
            images: tensor indexed as ``images[index]`` whose last axis holds
                the channels, or None when no pixels are available.
            poses: camera poses, indexed as ``poses[index]``.
            intrinsics: array ``[fl_x, fl_y, cx, cy]``.
            H, W: image height and width handed to ``get_rays``.
            index: list of pose indices; defaults to ``[0]``.

        Returns:
            dict with 'H', 'W', 'rays_o', 'rays_d', 'patch_size', 'poses',
            'intrinsics' and 'num_rays'. When ``images`` is given it also has
            'images' (gathered at ``rays['inds']`` while training) and
            'img_original' (first three channels of the selected image
            multiplied by the remaining channel(s)).
        """
        if index is None:
            index = [0]
        B = len(index) # a list of length 1

        rays = get_rays(poses[index], intrinsics, H, W, self.num_rays, patch_size=self.opt.patch_size)
        results = {
            'H': H,
            'W': W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'patch_size': self.opt.patch_size,
        }

        image_clip = None
        if images is not None:
            # Indexing with a list already yields a new tensor, so the full
            # image survives the gather below without an explicit clone.
            images = images[index]
            image_clip = images.squeeze(0)
            image_clip = image_clip[..., :3] * image_clip[..., 3:]
            if self.training:
                C = images.shape[-1]
                images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
            results['images'] = images

        results['poses'] = poses
        results['intrinsics'] = intrinsics

        if image_clip is not None:
            results['img_original'] = image_clip
        results['num_rays'] = self.num_rays

        return results

    def _load_transform(self, data_path):
        """Load the camera/frame description of one object for ``self.mode``."""
        if self.mode == 'colmap':
            with open(os.path.join(data_path, 'metadata.json'), 'r') as f:
                return json.load(f)

        if self.mode == 'blender':
            # load all splits (train/valid/test), this is what instant-ngp in fact does...
            if self.type == 'all':
                transform = None
                for transform_path in glob.glob(os.path.join(data_path, '*.json')):
                    with open(transform_path, 'r') as f:
                        tmp_transform = json.load(f)
                    if transform is None:
                        transform = tmp_transform
                    else:
                        transform['frames'].extend(tmp_transform['frames'])
                return transform
            # load train and val split
            if self.type == 'trainval':
                with open(os.path.join(data_path, 'transforms_train.json'), 'r') as f:
                    transform = json.load(f)
                with open(os.path.join(data_path, 'transforms_val.json'), 'r') as f:
                    transform_val = json.load(f)
                transform['frames'].extend(transform_val['frames'])
                return transform
            # only load one specified split
            with open(os.path.join(data_path, f'transforms_{self.type}.json'), 'r') as f:
                return json.load(f)

        raise NotImplementedError(f'unknown dataset mode: {self.mode}')

    def _interpolated_test_poses(self, frames):
        """Return ``n_test + 1`` poses interpolated between two random frames."""
        f0, f1 = np.random.choice(frames, 2, replace=False)
        pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
        pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
        rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]]))
        slerp = Slerp([0, 1], rots)

        poses = []
        for i in range(self.n_test + 1):
            # sine easing of the interpolation ratio from 0 to 1
            ratio = np.sin(((i / self.n_test) - 0.5) * np.pi) * 0.5 + 0.5
            pose = np.eye(4, dtype=np.float32)
            pose[:3, :3] = slerp(ratio).as_matrix()
            pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3]
            poses.append(pose)
        return poses

    def _load_view(self, data_path, frame, view, H, W):
        """Read view ``view`` of an object: its converted pose and float image.

        The render is concatenated with its segmentation mask as an extra
        last channel, colour-converted, resized to ``(H, W)`` if needed and
        divided by 255.

        Raises:
            FileNotFoundError: if the render or the mask cannot be found/read.
        """
        f_path = None
        for sub in self._RENDER_SUBDIRS:
            candidate = os.path.join(data_path, 'render', sub, f'render_{view}.jpg')
            if os.path.exists(candidate):
                f_path = candidate
                break
        if f_path is None:
            # Raise instead of exit(): a missing file must not kill the process.
            raise FileNotFoundError(
                f"render_{view}.jpg not found under {os.path.join(data_path, 'render')}")
        seg_path = os.path.join(data_path, 'segmentation', f'segmentation_{view}.jpg')

        pose = np.array(frame['pose'], dtype=np.float32).reshape(4,4) # [4, 4]
        pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)

        # cv2.imread signals failure by returning None rather than raising.
        image_without_mask = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4]
        if image_without_mask is None:
            raise FileNotFoundError(f"could not read image: {f_path}")
        mask = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"could not read segmentation mask: {seg_path}")
        mask = np.expand_dims(mask, axis=-1)
        image = np.concatenate([image_without_mask.astype(np.float32), mask.astype(np.float32)], axis=-1)

        # add support for the alpha channel as a mask.
        if image.shape[-1] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)

        if image.shape[0] != H or image.shape[1] != W:
            image = cv2.resize(image, (W, H), interpolation=cv2.INTER_AREA)

        image = image.astype(np.float32) / 255.0 # [H, W, 3/4]
        return pose, image

    def _intrinsics(self, transform, H, W):
        """Return ``np.array([fl_x, fl_y, cx, cy])`` for the loaded transform."""
        if 'fl_x' in transform or 'fl_y' in transform:
            fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / self.downscale
            fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / self.downscale
        elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
            # blender, assert in radians. already downscaled since we use H/W
            fl_x = W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
            fl_y = H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
            if fl_x is None: fl_x = fl_y
            if fl_y is None: fl_y = fl_x
        else:
            # fallback focal length used when the transform carries none
            fl_x = fl_y = 443.40496826171875

        cx = (transform['cx'] / self.downscale) if 'cx' in transform else (W / 2)
        cy = (transform['cy'] / self.downscale) if 'cy' in transform else (H / 2)

        return np.array([fl_x, fl_y, cx, cy])

    def __getitem__(self, index, flag=None, view_pose_dir=None):
        """Return ``(results, index)`` for one random view of object ``index``.

        ``results`` is the dictionary from ``get_random_rays`` plus a
        'filename' entry of the form ``<object_id>_<view>``. ``flag`` and
        ``view_pose_dir`` are accepted but not used.

        Raises:
            ValueError: if the drawn view index is not present in the
                object's view list.
            FileNotFoundError: if the view's render or mask is missing.
            NotImplementedError: if ``self.mode`` is not 'colmap'/'blender'.
        """
        # self.index is the view (pose) drawn for this call; index is the item.
        self.index = torch.randint(0, self._NUM_VIEWS, size=(1,))
        object_id = self.objects[index]
        data_path = os.path.join(self.root_path, object_id)

        # load nerf-compatible format data.
        transform = self._load_transform(data_path)

        # load image size
        if 'h' in transform and 'w' in transform:
            H = int(transform['h']) // self.downscale
            W = int(transform['w']) // self.downscale
        else:
            H = W = 512

        frames = transform["views"]

        if self.mode == 'blender' and self.type == 'test':
            # choose two random poses, and interpolate between.
            poses = self._interpolated_test_poses(frames)
            images = None
        else:
            # Only the single drawn view is loaded.
            view = int(self.index.item())
            if view >= len(frames):
                raise ValueError(
                    f"view {view} requested but {object_id} has only {len(frames)} views")
            pose, image = self._load_view(data_path, frames[view], view, H, W)
            poses = [pose]
            images = [image]

        poses = torch.from_numpy(np.stack(poses, axis=0)) # [N, 4, 4]

        if images is not None:
            images = torch.from_numpy(np.stack(images, axis=0)) # [N, H, W, C]

            if images.isnan().any() or images.isinf().any():
                print("NAAAAAAAANAAAAAAN")

        intrinsics = self._intrinsics(transform, H, W)
        results = self.get_random_rays(images, poses, intrinsics, H, W)

        results["filename"] = f"{object_id}_{self.index.item()}"
        return results, index
        
        

class Inversion_Dataset(Dataset):
    """Fixed-length (100 items) dataset that always serves pose index 0.

    NOTE(review): this class is unfinished. ``__getitem__`` still depends on
    names that are not defined anywhere in the class (``self.root_path``,
    ``object_id``, ``H``, ``W`` and, for the test split, ``intrinsics``); they
    are marked inline and must be supplied before it can run.

    Args:
        opt: options namespace; ``num_rays``, ``patch_size``, ``scale`` and
            ``offset`` are read from it.
        device: stored on the instance; not used further by this class.
        type: 'train' or 'test'; selects the ``__getitem__`` branch.
        downscale, global_pose_index, n_test, class_choice: accepted but
            currently ignored.
    """

    def __init__(self, opt, device, type='train', downscale=1, global_pose_index=0, n_test=10, class_choice = 'chair') -> None:
        super().__init__()
        self.opt = opt
        self.device = device
        self.type = type
        # get_random_rays reads self.training; it was never assigned before.
        self.training = self.type in ['train', 'all', 'trainval']

    def __len__(self):
        """Fixed nominal length."""
        return 100

    def get_random_rays(self,images,poses,intrinsics,H,W,index=None):
        """Build the per-item result dictionary for the selected pose(s).

        Args:
            images: tensor indexed as ``images[index]`` whose last axis holds
                the channels, or None when no pixels are available.
            poses: camera poses, indexed as ``poses[index]``.
            intrinsics: passed through to ``get_rays`` and into the result.
            H, W: image height and width handed to ``get_rays``.
            index: list of pose indices; defaults to ``[0]``.

        Returns:
            dict with 'H', 'W', 'rays_o', 'rays_d', 'patch_size', 'poses',
            'intrinsics' and 'num_rays'. When ``images`` is given it also has
            'images' (gathered at ``rays['inds']`` while training) and
            'img_original' (first three channels of the selected image
            multiplied by the remaining channel(s)).
        """
        if index is None:
            index = [0]
        B = len(index) # a list of length 1

        # Report the same ray count that is requested from get_rays;
        # self.num_rays was read here before but never assigned.
        num_rays = self.opt.num_rays
        rays = get_rays(poses[index], intrinsics, H, W, num_rays, patch_size=self.opt.patch_size)
        results = {
            'H': H,
            'W': W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'patch_size': self.opt.patch_size,
        }

        image_clip = None
        if images is not None:
            # Both call sites currently pass images=None, so everything that
            # touches the pixels has to live behind this check.
            images = images[index]
            image_clip = images.squeeze(0)
            image_clip = image_clip[..., :3] * image_clip[..., 3:]
            if self.training:
                C = images.shape[-1]
                images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
            results['images'] = images

        results['poses'] = poses
        results['intrinsics'] = intrinsics

        if image_clip is not None:
            results['img_original'] = image_clip
        results['num_rays'] = num_rays

        return results

    def __getitem__(self, index, flag=None, view_pose_dir=None):
        """Return ``(results, 0)``; the requested ``index`` is ignored.

        ``flag`` and ``view_pose_dir`` are accepted but not used.

        Raises:
            NotImplementedError: if ``self.type`` is neither 'train' nor 'test'.
        """
        # Every item maps to pose 0.
        index = 0
        # NOTE(review): `self.root_path` and `object_id` are not defined in
        # this class; this line fails until they are provided.
        data_path = os.path.join(self.root_path, object_id)

        if self.type == "train":
            # NOTE(review): the loaded transform is not used yet; presumably
            # H, W and the intrinsics are meant to come from it. Confirm.
            with open(os.path.join(data_path, f'transforms_{self.type}.json'), 'r') as f:
                transform = json.load(f)

            # Placeholder values; the trailing comments give the intended layout.
            poses = np.identity(4)  # 1x4x4
            intrinsics = np.identity(4) # [fx,fy,cx,cy]

            # NOTE(review): `H` and `W` are undefined here.
            results = self.get_random_rays(None, poses, intrinsics, H, W)

        elif self.type == "test":
            with open(os.path.join(data_path, f'transforms_{self.type}.json'), 'r') as f:
                transform = json.load(f)

            poses = [
                nerf_matrix_to_ngp(
                    np.array(f['transform_matrix'], dtype=np.float32).reshape(4,4),
                    scale=self.opt.scale,
                    offset=self.opt.offset,
                )
                for f in transform["frames"]
            ]
            # get_random_rays indexes poses with a list of indices, so hand it
            # a stacked [N, 4, 4] tensor and a one-element index list.
            poses = torch.from_numpy(np.stack(poses, axis=0))

            # NOTE(review): `intrinsics`, `H` and `W` are undefined here.
            results = self.get_random_rays(None, poses, intrinsics, H, W, index=[index])

        else:
            raise NotImplementedError(f'unsupported dataset type: {self.type}')

        return results, index