import os
import time
import pickle
import numpy as np

import torch
from torch.autograd import Variable
from torch.utils.data import Dataset

class DynamicsDataset(Dataset):
    """Sliding-window dataset over recorded (state, action) episodes.

    The pickle file ``data/<config['data']['data_name']>`` is expected to
    hold a list of numpy matrices, one per episode, in step order and of
    shape ``(# steps, state_dim + action_dim)`` representing ``(s, a)``.

    Every episode is cut into all contiguous windows of
    ``n_history + n_rollout`` steps. The first ``state_dim`` columns of a
    window become its observations, all remaining columns its actions.
    Episodes shorter than one window contribute no samples.

    The windows are shuffled and then split: the first
    ``int(#windows * train_valid_ratio)`` go to phase 'train', the rest to
    phase 'valid'.
    """

    _PHASES = ('train', 'valid')

    def __init__(self, config, phase, split_seed=0):
        """Load the episodes and keep the windows belonging to ``phase``.

        Args:
            config: nested mapping; reads ``config['data']['data_name']``,
                ``config['data']['state_dim']``,
                ``config['train']['n_history']``,
                ``config['train']['n_rollout']`` and
                ``config['train']['train_valid_ratio']``.
            phase: ``'train'`` or ``'valid'``.
            split_seed: seed of the private RNG that shuffles the windows
                before the split. The 'train' and 'valid' datasets must be
                built with the same value, otherwise the two splits are
                cut from different orderings and can share samples.

        Raises:
            AssertionError: if ``phase`` is not a known phase.
        """
        super(DynamicsDataset, self).__init__()

        # Fail fast, before the data file is loaded and windowed.
        if phase not in self._PHASES:
            raise AssertionError("Unknown phase %s" % phase)

        self.config = config

        # NOTE(review): pickle.load can execute arbitrary code while
        # deserializing; only point data_name at files from a trusted source.
        with open("data/%s" % self.config['data']['data_name'], 'rb') as fp:
            episodes = pickle.load(fp)

        state_dim = self.config['data']['state_dim']
        n_sample = (self.config['train']['n_history']
                    + self.config['train']['n_rollout'])

        obs, act = [], []
        for ep in episodes:
            ep = np.asarray(ep)
            # An episode with fewer than n_sample steps gives an empty range.
            for start in range(len(ep) - n_sample + 1):
                window = ep[start:start + n_sample]
                obs.append(window[:, :state_dim])
                act.append(window[:, state_dim:])
        obs = np.array(obs)
        act = np.array(act)

        # Shuffle observations and actions together. A dedicated, seeded RNG
        # is used instead of the global one so that separately constructed
        # 'train' and 'valid' datasets see the same ordering and therefore
        # end up with disjoint samples.
        order = np.random.RandomState(split_seed).permutation(len(obs))
        obs = obs[order]
        act = act[order]

        num_train = int(len(obs) * config['train']['train_valid_ratio'])

        if phase == 'train':
            self.obs = obs[:num_train]
            self.act = act[:num_train]
        else:
            self.obs = obs[num_train:]
            self.act = act[num_train:]

    def __len__(self):
        """Return the number of windows in this phase."""
        return len(self.obs)

    def __getitem__(self, idx):
        """Return the window at ``idx`` as a dict of numpy arrays.

        Keys are ``'observations'`` and ``'actions'``.
        """
        return {'observations': self.obs[idx], 'actions': self.act[idx]}
