
import os

import numpy as np
import joblib
import torch
from torch.utils.data import Dataset, DataLoader


def CUDA(var):
    """Move ``var`` to the GPU via ``var.cuda()`` if CUDA is available.

    When no CUDA device is available, ``var`` is returned unchanged.
    """
    if not torch.cuda.is_available():
        return var
    return var.cuda()

def segment_trajectories(data, success_only=True, min_episode=1000, max_episode=2000):
    """Split a flat rollout buffer into per-episode trajectories.

    ``data`` holds parallel arrays under ``'actions'``, ``'states'``,
    ``'rewards'`` and ``'dones'``. A step with a truthy ``'dones'`` value
    closes the current episode. Episodes are counted from 1 in the order
    their ``done`` flags appear.

    Args:
        data: mapping with the four keys above; each value must support
            ``[start:stop]`` slicing along its first axis.
        success_only: if true, keep only episodes whose reward at the
            ``done`` step is truthy.
        min_episode: first episode number (inclusive) to keep, or ``None``
            for no lower bound.
        max_episode: episode number (exclusive) at which to stop keeping
            episodes, or ``None`` for no upper bound. The defaults
            reproduce the original hard-coded window ``[1000, 2000)``.

    Returns:
        list of dicts with keys ``'actions'``, ``'states'``, ``'rewards'``
        and ``'dones'``. Each dict is the slice of ``data`` for exactly one
        kept episode. Steps after the last ``done`` are dropped.
    """
    keys = ('actions', 'states', 'rewards', 'dones')
    trajectories = []
    start_idx = 0
    episode = 0
    for i, done in enumerate(data['dones']):
        if not done:
            continue
        episode += 1
        end = i + 1
        in_window = (min_episode is None or episode >= min_episode) and (
            max_episode is None or episode < max_episode
        )
        if in_window and (not success_only or bool(data['rewards'][i])):
            trajectories.append({key: data[key][start_idx:end] for key in keys})
        # Always advance past a finished episode, even a skipped or failed
        # one. Otherwise its steps would be glued onto the next kept
        # trajectory.
        start_idx = end
    return trajectories

class TrajectoryDataset(Dataset):
    """Dataset yielding one random fixed-length window per trajectory.

    Each index corresponds to one trajectory. Every access draws a fresh
    window of ``context_len`` consecutive steps uniformly at random, so
    repeated accesses to the same index generally return different windows.

    Args:
        trajectories: sequence of dicts with at least the keys
            ``'actions'``, ``'states'`` and ``'rewards'`` (for example the
            output of ``segment_trajectories``). Each value must support
            slicing along its first axis and conversion by ``torch.tensor``.
        context_len: number of consecutive steps per sampled window.
    """

    def __init__(self, trajectories, context_len=3):
        self.trajectories = trajectories
        self.context_len = context_len

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        """Return a random ``context_len``-step window of trajectory ``idx``.

        Returns:
            dict of float32 tensors ``'actions'``, ``'states'`` and
            ``'rewards'``, all sliced to the same window of steps.

        Raises:
            ValueError: if the trajectory is shorter than ``context_len``.
        """
        trajectory = self.trajectories[idx]
        length = len(trajectory['rewards'])
        # Explicit check rather than ``assert``, which ``python -O`` strips.
        if length < self.context_len:
            raise ValueError(
                f"trajectory {idx} has {length} steps, "
                f"shorter than context_len={self.context_len}"
            )
        # randint's upper bound is exclusive, so the last valid start is
        # length - context_len.
        si = np.random.randint(0, length - self.context_len + 1)
        window = slice(si, si + self.context_len)
        # Rewards are returned as-is (not normalised).
        return {
            key: torch.tensor(trajectory[key][window], dtype=torch.float32)
            for key in ('actions', 'states', 'rewards')
        }


class StepwiseDataset(Dataset):
    """Dataset yielding one random fixed-length window per trajectory.

    Each index corresponds to one trajectory. Every access draws a fresh
    window of ``context_len`` consecutive steps uniformly at random.

    NOTE(review): the sampling logic mirrors ``TrajectoryDataset``. The name
    suggests per-step sampling may have been intended — confirm.

    Args:
        trajectories: sequence of dicts with at least the keys
            ``'actions'``, ``'states'`` and ``'rewards'``. Each value must
            support slicing along its first axis and conversion by
            ``torch.tensor``.
        context_len: number of consecutive steps per sampled window.
    """

    def __init__(self, trajectories, context_len=3):
        self.trajectories = trajectories
        self.context_len = context_len

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        """Return a random ``context_len``-step window of trajectory ``idx``.

        Returns:
            dict of float32 tensors ``'actions'``, ``'states'`` and
            ``'rewards'``, all sliced to the same window of steps.

        Raises:
            ValueError: if the trajectory is shorter than ``context_len``.
        """
        trajectory = self.trajectories[idx]
        length = len(trajectory['rewards'])
        # Explicit check rather than ``assert``, which ``python -O`` strips.
        if length < self.context_len:
            raise ValueError(
                f"trajectory {idx} has {length} steps, "
                f"shorter than context_len={self.context_len}"
            )
        # randint's upper bound is exclusive, so the last valid start is
        # length - context_len.
        si = np.random.randint(0, length - self.context_len + 1)
        window = slice(si, si + self.context_len)
        # Rewards are returned as-is (not normalised).
        return {
            key: torch.tensor(trajectory[key][window], dtype=torch.float32)
            for key in ('actions', 'states', 'rewards')
        }



class Dataset_EBM(Dataset):
    """Flat dataset of individual observations pooled from saved trajectories.

    Loads ``{task}_{type}_1000.npy`` from ``file_path``. The file must hold
    a pickled sequence of dicts, each with an ``'obs'`` array. All ``'obs'``
    arrays are concatenated along axis 0, so each item is one row of the
    result, returned as a float32 tensor.

    Args:
        file_path: directory containing the ``.npy`` file.
        type: dataset flavour used in the file name (e.g. ``"expert"``).
            It shadows the builtin ``type`` but is kept for callers that
            pass it by keyword.
        task: task name used in the file name (e.g. ``"pickc"``).

    Raises:
        ValueError: if the file contains no trajectories.
    """

    def __init__(self, file_path, type="expert", task="pickc"):
        self.file_path = file_path
        path = os.path.join(self.file_path, f'{task}_{type}_1000.npy')
        # NOTE(review): allow_pickle=True can execute arbitrary code from the
        # file; only load trusted data.
        trajs = np.load(path, allow_pickle=True)
        if len(trajs) == 0:
            raise ValueError(f"no trajectories found in {path}")
        self.obs = np.concatenate([traj['obs'] for traj in trajs], axis=0)

    def __getitem__(self, index):
        """Return observation row ``index`` as a float32 tensor."""
        state = self.obs[index]
        return torch.from_numpy(state).float()

    def __len__(self):
        """Return the total number of observation rows across all trajectories."""
        return len(self.obs)


if __name__ == '__main__':
    # Smoke run: load a saved rollout dict, split it into episodes and
    # iterate once over the resulting window dataset.
    raw = np.load("./data/test_data/unlock_IID.npy", allow_pickle=True).item()
    reward_steps = np.where(raw['rewards'])
    done_steps = np.where(raw['dones'])
    print(reward_steps, done_steps)
    episodes = segment_trajectories(raw)
    loader = DataLoader(TrajectoryDataset(episodes), batch_size=1, shuffle=True)
    for sample in loader:
        # Drop the singleton batch dimension added by the DataLoader.
        actions, states, rewards = (
            sample[key].squeeze(0) for key in ('actions', 'states', 'rewards')
        )
    