import os
import pickle
import warnings

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from tqdm import trange

class NYUv2(Dataset):
    """Image/depth pairs for one split of NYUv2, read from disk.

    Expected layout::

        <root>/images/<split>/...        RGB inputs
        <root>/annotations/<split>/...   depth maps

    Files in the two directories are paired by their position after sorting
    by filename, so both directories must contain matching file sets.

    Args:
        root: Dataset root directory.
        split: Sub-directory name under ``images/`` and ``annotations/``.
        transform: Optional callable applied to the input PIL image. When
            ``None`` the PIL image is returned unchanged.
        target_transform: Optional callable applied to the depth PIL image.
            It should return a tensor. When ``None`` the depth map is
            converted to a ``float32`` tensor directly.
    """

    def __init__(self, root, split, transform=None, target_transform=None):
        super().__init__()
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        self.image_dir = os.path.join(self.root, 'images', self.split)
        self.depth_dir = os.path.join(self.root, 'annotations', self.split)

        self.images = sorted(os.listdir(self.image_dir))
        self.depths = sorted(os.listdir(self.depth_dir))

        # Pairs are matched by sorted index, so differing counts mean the
        # pairing is wrong (or indexing past the shorter list will fail).
        if len(self.images) != len(self.depths):
            warnings.warn(
                f"NYUv2({self.split!r}): found {len(self.images)} images in "
                f"{self.image_dir} but {len(self.depths)} depth maps in "
                f"{self.depth_dir}; image/depth pairs are matched by sorted "
                f"index and may be misaligned.",
                stacklevel=2,
            )

    @staticmethod
    def _load_image(path):
        """Open the image at *path*, fully load it and close the file handle."""
        # Image.open is lazy and keeps the file open until pixel data is
        # read; copying inside the `with` forces a load into memory so the
        # handle can be closed deterministically.
        with Image.open(path) as im:
            return im.copy()

    def __getitem__(self, idx):
        """Return ``(img, depth)`` for sample *idx*.

        ``img`` is the (optionally transformed) input image. ``depth`` is a
        float tensor equal to the raw/transformed depth values divided by
        1000 (presumably millimetres -> metres; confirm against the data).
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        depth_path = os.path.join(self.depth_dir, self.depths[idx])

        img = self._load_image(img_path)
        if self.transform is not None:
            img = self.transform(img)

        depth = self._load_image(depth_path)
        if self.target_transform is not None:
            depth = self.target_transform(depth)
        else:
            # No transform supplied: convert the PIL depth map to a tensor so
            # the scaling below still applies. float32 avoids integer-mode
            # (e.g. 16-bit) dtypes that torch may not support directly.
            depth = torch.from_numpy(np.array(depth, dtype=np.float32))
        depth = depth.float() / 1e3

        return img, depth

    def __len__(self):
        """Number of samples, i.e. the number of files in the image directory."""
        return len(self.images)
