import torch
import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
import random
import os
import pickle
import argparse
import imageio

# Tensors built by this module are placed on the GPU when one is visible to
# torch, and on the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class TrajDataset(torch.utils.data.Dataset):
    """Trajectory dataset that is loaded eagerly and kept on ``device``.

    Each record in the pickle file(s) is a mapping that must provide the keys
    ``state``, ``action``, ``pose``, ``angle``, ``rollin_obs``,
    ``rollin_next_obs``, ``rollin_poses``, ``rollin_angles``, ``rollin_us``,
    ``rollin_rs``, ``rollin_next_poses`` and ``rollin_next_angles``.
    ``rollin_obs`` and ``rollin_next_obs`` are paths handed to ``np.load``;
    every other value is converted to a float tensor.

    Args:
        path: Path of one pickle file, or a list/tuple of such paths whose
            records are concatenated in order.
        config: Mapping providing ``'H'`` (length of the permutation applied
            to the roll-in fields) and ``'shuffle'`` (whether ``__getitem__``
            permutes them).
        transform: Callable applied to ``state`` and to every loaded frame;
            it must return a tensor. When ``None``, ``torch.as_tensor`` is
            used.
        sample: When true, frame files are not loaded in ``__init__``;
            indexing such a dataset is not implemented.
        mode: Accepted for interface compatibility; not used.

    Raises:
        KeyError: If ``config`` lacks ``'H'`` or a record lacks a required key.
        RuntimeError: From ``torch.stack`` when no records were loaded.
    """

    # (name in ``self.ds``, key in the trajectory record) for roll-in data.
    _ROLLIN_FIELDS = (
        ('rollin_poses', 'rollin_poses'),
        ('rollin_angles', 'rollin_angles'),
        ('rollin_actions', 'rollin_us'),
        ('rollin_rewards', 'rollin_rs'),
        ('rollin_next_poses', 'rollin_next_poses'),
        ('rollin_next_angles', 'rollin_next_angles'),
    )
    # (name in ``self.ds``, key in the trajectory record) for query data.
    _QUERY_FIELDS = (
        ('actions', 'action'),
        ('poses', 'pose'),
        ('angles', 'angle'),
    )
    # (key in the returned sample, name in ``self.ds``) for every field that
    # is permuted together when shuffling.
    _ROLLIN_ITEM_KEYS = (
        ('rollin_obs', 'rollin_observations'),
        ('rollin_poses', 'rollin_poses'),
        ('rollin_angles', 'rollin_angles'),
        ('rollin_actions', 'rollin_actions'),
        ('rollin_rewards', 'rollin_rewards'),
        ('rollin_next_obs', 'rollin_next_observations'),
        ('rollin_next_poses', 'rollin_next_poses'),
        ('rollin_next_angles', 'rollin_next_angles'),
    )

    def __init__(self, path, config, transform=None, sample=False, mode=None):
        super().__init__()
        self.config = config
        # The transform is applied unconditionally below, so the default
        # ``None`` needs a usable fallback instead of failing on first call.
        self.transform = torch.as_tensor if transform is None else transform
        self.filepath = path
        self.H = config['H']
        self.sample = sample

        self.trajs = self._load_trajs(path)

        array_fields = self._ROLLIN_FIELDS + self._QUERY_FIELDS
        columns = {name: [] for name, _ in array_fields}
        rollin_filepaths = []
        states = []
        rollin_observations = []
        rollin_next_observations = []

        print(f"Sampling? {self.sample}")

        for i, traj in enumerate(self.trajs):
            if i % 500 == 0:
                print(i)

            rollin_filepaths.append(traj['rollin_obs'])
            states.append(self.transform(traj['state']).float().to(device))
            for name, key in array_fields:
                columns[name].append(traj[key])

            if not self.sample:
                rollin_observations.append(
                    self._load_frames(traj['rollin_obs']))
                rollin_next_observations.append(
                    self._load_frames(traj['rollin_next_obs']))

        self.ds = {'rollin_filepaths': rollin_filepaths}
        for name, _ in self._ROLLIN_FIELDS:
            self.ds[name] = self._to_device(columns[name])
        self.ds['states'] = torch.stack(states).float().to(device)
        for name, _ in self._QUERY_FIELDS:
            self.ds[name] = self._to_device(columns[name])

        if not self.sample:
            self.ds['rollin_observations'] = (
                torch.stack(rollin_observations).float().to(device))
            self.ds['rollin_next_observations'] = (
                torch.stack(rollin_next_observations).float().to(device))

    @staticmethod
    def _load_trajs(path):
        """Return the records stored in one pickle file or in several."""
        # NOTE: pickle can run arbitrary code while loading; only point this
        # at trusted files.
        if not isinstance(path, (list, tuple)):
            with open(path, 'rb') as file:
                return pickle.load(file)

        # Fraction of each file to keep; every file is currently used in full.
        frac_per_path = [1.0] * len(path)
        trajs = []
        for frac, filepath in zip(frac_per_path, path):
            with open(filepath, 'rb') as file:
                loaded = pickle.load(file)
            loaded = loaded[:int(frac * len(loaded))]
            print("Using {} trajs from {}".format(len(loaded), filepath))
            trajs += loaded
        return trajs

    def _load_frames(self, filepath):
        """Load the array at ``filepath`` and stack its transformed entries."""
        return torch.stack([self.transform(obs) for obs in np.load(filepath)])

    @staticmethod
    def _to_device(values):
        """Convert a list of per-trajectory values to a float tensor on ``device``."""
        return torch.tensor(np.array(values)).float().to(device)

    def __len__(self):
        """Return the number of trajectories in the dataset."""
        return len(self.ds['states'])

    def __getitem__(self, i):
        """Return the ``i``-th trajectory as a dict of tensors.

        When ``config['shuffle']`` is truthy, one random permutation of
        length ``H`` is drawn and applied to every roll-in field, so the
        fields stay aligned with each other.

        Raises:
            NotImplementedError: If the dataset was built with ``sample=True``.
        """
        if self.sample:
            raise NotImplementedError(
                "indexing a dataset built with sample=True is not implemented")

        rollin = {
            out_key: self.ds[ds_key][i]
            for out_key, ds_key in self._ROLLIN_ITEM_KEYS
        }
        if self.config['shuffle']:
            permutation = torch.randperm(self.H)
            rollin = {key: value[permutation] for key, value in rollin.items()}

        return {
            'states': self.ds['states'][i],
            'actions': self.ds['actions'][i],
            **rollin,
            'poses': self.ds['poses'][i],
            'angles': self.ds['angles'][i],
        }


if __name__ == '__main__':
    # Smoke test: build the training dataset and fetch its first sample.
    n_envs = 1000
    n_hists = 1
    n_samples = 1
    H = 10
    dim = 4
    # TrajDataset reads both 'shuffle' and 'H' from the config; without 'H'
    # construction fails with a KeyError.
    config = {'shuffle': True, 'H': H}
    path_train = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_train.pkl'
    path_test = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_test.pkl'
    # A transform is required because the dataset applies it to every state
    # and frame.
    # NOTE(review): plain tensor conversion is used here as a placeholder;
    # confirm which transform the training pipeline actually expects.
    ds = TrajDataset(path_train, config, transform=torch.as_tensor)
    ds[0]

