import numpy as np

from utilities.data_utils import atleast_nd


def clip_actions(dataset, clip_to_eps: bool = True, eps: float = 1e-5):
    """Clip ``dataset["actions"]`` into the open interval ``[-(1 - eps), 1 - eps]``.

    The dict is modified in place (its ``"actions"`` entry is replaced by the
    clipped array) and the same dict object is returned. When ``clip_to_eps``
    is false the dataset is returned untouched.
    """
    if not clip_to_eps:
        return dataset
    bound = 1 - eps
    dataset["actions"] = np.clip(dataset["actions"], -bound, bound)
    return dataset


def compute_returns(traj):
    """Return the undiscounted sum of rewards over a trajectory.

    Each step of ``traj`` is unpacked positionally; its third element is the
    reward. An empty trajectory returns ``0``.
    """
    return sum(reward for _, _, reward, *_ in traj)


def split_to_trajs(dataset, max_episode_steps: int = 500):
    """Split a flat, step-indexed dataset into a list of trajectories.

    Step ``i`` closes a trajectory when any of these holds:

    * ``dataset["dones"][i] == 1.0``;
    * ``i + 1`` is a multiple of ``max_episode_steps``;
    * ``i`` is the last step.

    Args:
        dataset: Mapping whose ``"observations"``, ``"actions"``, ``"rewards"``,
            ``"dones"``, ``"task_ids"`` and ``"next_observations"`` entries are
            indexed per step along their first axis.
        max_episode_steps: Period, in global step index, at which a trajectory
            is force-split (previously hard-coded as 500). A falsy value
            (``0`` or ``None``) disables the periodic split.

    Returns:
        A list of trajectories. Each trajectory is a list of
        ``(obs, action, reward, done, task_id, next_obs)`` tuples, where
        ``done`` is the computed boundary flag (1 on a trajectory's final
        step, 0 otherwise). An empty dataset yields ``[]``.
    """
    # Per-step boundary flags; zeros_like keeps the rewards array's shape/dtype.
    dones_float = np.zeros_like(dataset["rewards"])
    n_steps = len(dones_float)
    if n_steps == 0:
        # Previously this crashed with IndexError on dones_float[-1].
        return []

    for i in range(n_steps - 1):
        # NOTE(review): the periodic split uses the *global* step index. It
        # coincides with episode time limits only if every episode starts at a
        # multiple of max_episode_steps — confirm against the dataset layout.
        periodic = bool(max_episode_steps) and (i + 1) % max_episode_steps == 0
        dones_float[i] = 1 if dataset["dones"][i] == 1.0 or periodic else 0
    # The final step always closes the last trajectory.
    dones_float[-1] = 1

    n_obs = dataset["observations"].shape[0]
    trajs = [[]]
    for i in range(n_obs):
        trajs[-1].append(
            (
                dataset["observations"][i],
                dataset["actions"][i],
                dataset["rewards"][i],
                dones_float[i],
                dataset["task_ids"][i],
                dataset["next_observations"][i],
            )
        )
        # Start a fresh trajectory after a boundary, unless nothing follows.
        if dones_float[i] == 1.0 and i + 1 < n_obs:
            trajs.append([])

    return trajs


def pad_trajs_to_dataset(
    trajs,
    max_traj_length: int,
    termination_penalty: float = None,
    include_next_obs: bool = False,
):
    """Pack variable-length trajectories into zero-padded fixed-length arrays.

    Args:
        trajs: Non-empty list of non-empty trajectories, each a list of
            ``(obs, action, reward, done, task_id, next_obs)`` tuples (the
            layout produced by ``split_to_trajs`` in this module). ``obs`` and
            ``action`` are expected to be 1-D arrays; the first trajectory's
            first step fixes ``obs_dim`` / ``act_dim``.
        max_traj_length: Size of the padded time axis; every trajectory must
            fit within it.
        termination_penalty: Currently unused; accepted for API compatibility.
        include_next_obs: If true, also emit a padded ``"next_observations"``.

    Returns:
        A dict with these entries:

        * ``"observations"``, ``"actions"`` and (optionally)
          ``"next_observations"``: float32, shape
          ``(n_trajs, max_traj_length, dim)``.
        * ``"rewards"`` and ``"dones"``: float32, shape
          ``(n_trajs, max_traj_length)``.
        * ``"task_ids"``: int32, shape ``(n_trajs, max_traj_length)``.
        * ``"traj_lengths"``: int32, shape ``(n_trajs,)``.

        Entries past a trajectory's length stay zero.

    Raises:
        ValueError: If any trajectory is longer than ``max_traj_length``.
    """
    n_trajs = len(trajs)
    print('number of trajectories: ', n_trajs)

    # Validate up front: otherwise numpy fails mid-fill with an opaque
    # "could not broadcast" ValueError after the buffers are allocated.
    for idx, traj in enumerate(trajs):
        if len(traj) > max_traj_length:
            raise ValueError(
                f"trajectory {idx} has length {len(traj)}, which exceeds "
                f"max_traj_length={max_traj_length}"
            )

    dataset = {}
    obs_dim, act_dim = trajs[0][0][0].shape[0], trajs[0][0][1].shape[0]
    dataset["observations"] = np.zeros((n_trajs, max_traj_length, obs_dim), dtype=np.float32)
    dataset["actions"] = np.zeros((n_trajs, max_traj_length, act_dim), dtype=np.float32)
    dataset["rewards"] = np.zeros((n_trajs, max_traj_length), dtype=np.float32)
    dataset["task_ids"] = np.zeros((n_trajs, max_traj_length), dtype=np.int32)
    dataset["dones"] = np.zeros((n_trajs, max_traj_length), dtype=np.float32)
    dataset["traj_lengths"] = np.zeros((n_trajs,), dtype=np.int32)
    if include_next_obs:
        dataset["next_observations"] = np.zeros((n_trajs, max_traj_length, obs_dim), dtype=np.float32)

    for idx, traj in enumerate(trajs):
        traj_length = len(traj)
        dataset["traj_lengths"][idx] = traj_length
        dataset["observations"][idx, :traj_length] = atleast_nd(
            np.stack([ts[0] for ts in traj], axis=0),
            n=2,
        )
        dataset["actions"][idx, :traj_length] = atleast_nd(
            np.stack([ts[1] for ts in traj], axis=0),
            n=2,
        )
        dataset["rewards"][idx, :traj_length] = np.stack([ts[2] for ts in traj], axis=0)
        dataset["dones"][idx, :traj_length] = np.stack([ts[3] for ts in traj], axis=0)
        dataset["task_ids"][idx, :traj_length] = np.stack([ts[4] for ts in traj], axis=0)
        if include_next_obs:
            dataset["next_observations"][idx, :traj_length] = atleast_nd(
                np.stack([ts[5] for ts in traj], axis=0),
                n=2,
            )
        # NOTE(review): termination_penalty is intentionally not applied. The
        # earlier (commented-out) version read a "terminals" key this dict
        # never has. When trajs come from split_to_trajs, "dones" also flags
        # periodic splits and the dataset's last step, so it cannot identify
        # true terminations.

    return dataset
