import os
import math
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from .ray_utils import *


class MultiscaleDataset(Dataset):
    """Multi-scale Blender-style NeRF dataset described by a single ``metadata.json``.

    ``metadata.json`` in ``datadir`` holds one entry per split. Each entry is a dict
    of parallel per-image lists (``file_path``, ``cam2world``, ``height``, ``width``,
    ``focal``, ``lossmult``, ``label``). Every image is turned into per-pixel rays,
    RGB targets alpha-blended onto a white background, a per-pixel scale value
    derived from its ``label``, and a per-pixel loss multiplier.

    Args:
        datadir: Directory containing ``metadata.json`` and the image files.
        split: Key of ``metadata.json`` to load (e.g. ``"train"`` / ``"test"``).
        downsample: Accepted for interface compatibility; not used by this class.
        is_stack: If False, per-image rays/RGBs/scales are concatenated into flat
            per-pixel tensors. If True they stay per-image lists, and RGBs are
            reshaped to ``(h, w, 3)``.
        N_vis: Controls image subsampling. Only values yielding an evaluation
            interval of 1 are supported (e.g. a negative value keeps all images).
        n_scales: Number of resolution levels (at most 5). Selects the range of
            the per-label scale values.
    """

    def __init__(self, datadir, split="train", downsample=1, is_stack=False, N_vis=-1, n_scales=4):
        super(MultiscaleDataset, self).__init__()
        assert n_scales <= 5

        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.n_scales = n_scales

        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        # Flips the camera y and z axes (Blender -> OpenCV camera convention).
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])

        # Images are composited onto white in _transform, matching white_bg.
        self.white_bg = True
        self.near_far = [2.0, 6.0]

        self._read_meta()

    def _read_meta(self):
        """Load ``metadata.json`` for ``self.split`` and build the per-pixel data.

        Populates ``meta``, ``image_paths``, ``poses``, ``all_rays``, ``all_rgbs``,
        ``all_scales``, ``all_lossmults``, ``all_heights`` and ``all_widths``.
        """
        with open(os.path.join(self.root_dir, "metadata.json"), "r", encoding="utf-8") as f:
            self.meta = json.load(f)[self.split]

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_scales = []
        self.all_lossmults = []
        self.all_heights = []
        self.all_widths = []

        # Hard-coded scale values: one evenly spaced value per distinct label,
        # decreasing with the label index (label 0 gets the largest value).
        n_labels = len(np.unique(self.meta["label"]))
        if self.n_scales <= 4:
            img_scales = torch.linspace(1, -1, n_labels)
        else:
            img_scales = torch.linspace(0.75, -0.75, n_labels)

        n_images = len(self.meta["file_path"])
        img_eval_interval = 1 if self.N_vis < 0 else n_images // self.N_vis
        assert img_eval_interval == 1

        idxs = list(range(0, n_images, img_eval_interval))
        for i in tqdm(idxs, position=0, leave=True, desc=f'Loading data {self.split} ({len(idxs)})'):
            # camera pose (follows opencv coordinate system)
            c2w = torch.tensor(np.array(self.meta["cam2world"][i]) @ self.blender2opencv, dtype=torch.float32)
            self.poses.append(c2w)

            # rgb image, flattened to (h*w, 3); `with` closes the file handle promptly
            img_path = os.path.join(self.root_dir, self.meta["file_path"][i])
            with Image.open(img_path) as pil_img:
                img = self._transform(pil_img)
            self.image_paths.append(img_path)
            self.all_rgbs.append(img)

            # image dimensions (may differ per image across scales)
            height = self.meta["height"][i]
            width = self.meta["width"][i]
            self.all_heights.append(height)
            self.all_widths.append(width)
            n_pixels = height * width

            # focal length (same value used for fx and fy)
            focal = self.meta["focal"][i]

            # rays from directions; principal point presumably at the image center
            # (see ray_utils.get_ray_directions)
            directions = get_ray_directions(height, width, (focal, focal))
            rays_o, rays_d = get_rays(directions, c2w)  # assumes both are (h*w, 3)
            self.all_rays.append(torch.cat([rays_o, rays_d], 1))  # (h*w, 6)

            # loss multipliers: the metadata value is required to be a perfect square
            lossmult = self.meta["lossmult"][i]
            assert math.sqrt(lossmult) % 1 == 0
            lossmult_t = torch.tensor(lossmult, dtype=torch.float32)
            self.all_lossmults.append(torch.broadcast_to(lossmult_t, (n_pixels, 1)))

            # scales: the label's scale value repeated for every pixel of the image
            label = self.meta["label"][i]
            self.all_scales.append(torch.broadcast_to(img_scales[label], (n_pixels, 1)))

        self.poses = torch.stack(self.poses)
        # NOTE(review): loss multipliers are concatenated even when is_stack=True, so in
        # that mode they are indexed per pixel while rays/rgbs/scales are per image.
        self.all_lossmults = torch.cat(self.all_lossmults)
        self.all_heights = torch.tensor(self.all_heights, dtype=torch.int64)
        self.all_widths = torch.tensor(self.all_widths, dtype=torch.int64)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (sum of h*w over frames, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (sum of h*w over frames, 3)
            self.all_scales = torch.cat(self.all_scales, 0)  # (sum of h*w over frames, 1)
        else:
            # one (h, w, 3) tensor per frame
            self.all_rgbs = [
                img.view(int(h), int(w), 3)
                for img, h, w in zip(self.all_rgbs, self.all_heights, self.all_widths)
            ]

    def _transform(self, x):
        """Convert an 8-bit image to a flat ``(h*w, 3)`` float tensor in ``[0, 1]``.

        RGBA input is alpha-composited onto a white background. RGB input is
        treated as fully opaque.

        Args:
            x: A PIL image or array-like with shape ``(h, w, 3)`` or ``(h, w, 4)``.

        Returns:
            A float32 tensor of shape ``(h*w, 3)``.

        Raises:
            ValueError: If ``x`` is not a 3- or 4-channel image.
        """
        x = np.array(x, dtype=np.float32) / 255.
        if x.ndim != 3 or x.shape[-1] not in (3, 4):
            raise ValueError(f"expected an RGB or RGBA image, got array of shape {x.shape}")
        if x.shape[-1] == 4:
            # Only composite when an alpha channel is really present; otherwise an
            # RGB image would have its blue channel misused as alpha.
            rgb, alpha = x[..., :3], x[..., 3:]
            x = rgb * alpha + (1. - alpha)
        return torch.from_numpy(x).view(-1, 3)

    def __len__(self):
        """Return the number of indexable samples: rays when flat, images when stacked."""
        return len(self.all_rays)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        The train split yields rays, rgbs, scales and loss multipliers. Other splits
        yield rays, rgbs, scales and the image height/width.
        """
        if self.split == "train":
            sample = {
                "rays": self.all_rays[idx],
                "rgbs": self.all_rgbs[idx],
                "scales": self.all_scales[idx],
                "lossmults": self.all_lossmults[idx]
            }
        else:
            sample = {
                "rays": self.all_rays[idx],
                "rgbs": self.all_rgbs[idx],
                "scales": self.all_scales[idx],
                "height": self.all_heights[idx],
                "width": self.all_widths[idx]
            }
        return sample


if __name__ == "__main__":
    # Manual smoke test: load a split and print a short summary instead of
    # stopping in a debugger. Because of the relative `from .ray_utils import *`,
    # this file must be run as a module: python -m <package>.<module> [datadir]
    import sys

    default_datadir = "/workspace/dataset/nerf_synthetic_multiscale/lego"
    datadir = sys.argv[1] if len(sys.argv) > 1 else default_datadir
    dataset = MultiscaleDataset(
        datadir=datadir,
        split="test",
        is_stack=True,
        N_vis=-1,
    )

    print(f"loaded {len(dataset.all_rays)} images from {datadir}")
    print(f"poses: {tuple(dataset.poses.shape)}")
    print(f"image sizes (h, w): {sorted(set(zip(dataset.all_heights.tolist(), dataset.all_widths.tolist())))}")



