import os
from torch.utils.data import Dataset, DataLoader
import numpy as np

import cv2
import torch
from time import sleep, time
from torchvision.transforms import Resize
from multiprocessing import Pool
import pandas as pd


class Eventvot(Dataset):
    """Random event voxel clips from the sequences listed in a CSV file.

    The first CSV column holds one sequence directory per row. Each directory
    must contain a non-CSV file whose name includes ``voxel_crop`` and that
    ``np.load`` can read. The array is indexed as ``[frame, channel, y, x]``
    (assumed (T, C, H, W); the channel axis is not interpreted here).

    A sample is a random window of ``cond_length + seq_length`` consecutive
    frames. The window is scaled to [-1, 1] by its peak absolute value over the
    full frame and then randomly cropped to ``crop_size``.
    """

    def __init__(self, data_path, split='train', args=None):
        """
        Args:
            data_path: path of the CSV listing sequence directories.
            split: not read by this class; kept for interface compatibility.
            args: not read by this class; kept for interface compatibility.
        """
        self.data_path = data_path
        self.fps = 60  # not read within this class
        self.crop_size = [128, 128]  # [height, width] of the random spatial crop
        self.num_bins = 3  # not read within this class
        self.width = 180  # horizontal extent the crop offset is sampled within
        self.height = 180  # vertical extent the crop offset is sampled within
        self.device = 'cpu'  # not read within this class
        self.cond_length = 3  # leading frames returned as "image"
        self.seq_length = 25  # following frames returned as "pixel_values"
        self.csv_data = pd.read_csv(self.data_path)

    def __len__(self):
        return len(self.csv_data)

    def __getname__(self):
        return 'vot'

    @staticmethod
    def _find_voxel_file(seq_dir):
        """Return the path of the ``voxel_crop`` (non-CSV) file in ``seq_dir``."""
        # Sorted so the choice is deterministic when several files match.
        candidates = sorted(
            name for name in os.listdir(seq_dir)
            if 'csv' not in name and 'voxel_crop' in name
        )
        if not candidates:
            # Deliberately not IndexError: Python's legacy sequence-iteration
            # protocol treats IndexError from __getitem__ as end-of-data and
            # would silently truncate a plain `for sample in dataset` loop.
            raise FileNotFoundError(f"no 'voxel_crop' file found in {seq_dir}")
        return os.path.join(seq_dir, candidates[0])

    def __getitem__(self, index):
        """Sample one random clip from the sequence at CSV row ``index``.

        Returns:
            dict with
              - "pixel_values": the last ``seq_length`` frames of the window,
                cropped, with values in [-1, 1];
              - "image": the first ``cond_length`` frames of the window,
                cropped, rescaled to [0, 1];
              - "dataset": ``'vot'``.

        Raises:
            FileNotFoundError: the sequence directory has no voxel file.
            ValueError: the recording has fewer than
                ``cond_length + seq_length`` frames.
        """
        seq_dir = self.csv_data.iloc[index, 0]
        event_pth = self._find_voxel_file(seq_dir)

        # Memory-map the recording so only the sampled window is read from
        # disk rather than the whole sequence.
        event_data = np.load(event_pth, mmap_mode='r')
        num_frames = event_data.shape[0]
        clip_len = self.cond_length + self.seq_length
        if num_frames < clip_len:
            raise ValueError(
                f"{seq_dir}: recording has {num_frames} frames, "
                f"need at least {clip_len}")

        # randint's upper bound is exclusive; +1 keeps the last full window
        # (and the last crop offset) reachable.
        start = np.random.randint(0, num_frames - clip_len + 1)
        crop_h, crop_w = self.crop_size
        crop_x = np.random.randint(0, self.width - crop_w + 1)
        crop_y = np.random.randint(0, self.height - crop_h + 1)

        # np.array copies the window out of the read-only memmap into a
        # regular writable array before handing it to torch.
        clip = torch.from_numpy(np.array(event_data[start:start + clip_len]))

        # Peak magnitude is taken over the full frame, before cropping. An
        # all-zero window is left unscaled instead of turning into NaNs.
        max_norm = clip.abs().max()
        if max_norm == 0:
            max_norm = torch.ones_like(max_norm)
        clip = clip / max_norm
        clip = clip[:, :, crop_y:crop_y + crop_h, crop_x:crop_x + crop_w]

        # Conditioning frames are shifted from [-1, 1] to [0, 1].
        image = (clip[:self.cond_length] + 1) / 2

        return {"pixel_values": clip[self.cond_length:], "image": image, 'dataset': self.__getname__()}




if __name__ == '__main__':
    # Smoke test: build the dataset from a local CSV, fetch one sample, and
    # report its tensor shapes, the dataset size and the elapsed wall time.
    start_time = time()
    dataset = Eventvot('train_local_hr.csv')

    sample = dataset[1]
    print(sample['pixel_values'].shape, sample['image'].shape)
    print(len(dataset))

    elapsed = time() - start_time
    print("Done:", f"Time: {elapsed:.3f}s")