from typing import Callable

import numpy as np
import argparse
import d3rlpy
import d4rl
import h5py

from torch import no_grad

import composuite

from d3rlpy.dataset import MDPDataset

# Keyword arguments forwarded to composuite.make() for every subtask: no
# on-screen or off-screen renderer, reward shaping on, no camera observations,
# task-ID observations on, and an `env_horizon` of 500 steps.
GLOBAL_SUBTASK_KWARGS = dict(
    has_renderer=False,
    has_offscreen_renderer=False,
    reward_shaping=True,
    use_camera_obs=False,
    use_task_id_obs=True,
    env_horizon=500,
)

# Command-line options: the CompoSuite task (robot / obj / obstacle /
# objective), which window of its offline dataset to train on, and which
# d3rlpy algorithm to run.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--nth",
    type=int,
    default=0,
    help="0-based index of the --num-dataset-samples window to train on",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for d3rlpy")
parser.add_argument(
    "--num-dataset-samples",
    type=int,
    default=1000000,
    help="number of consecutive transitions in each dataset window",
)
parser.add_argument(
    "--algo",
    type=str,
    default="iql",
    # Reject unknown algorithms immediately instead of only after the dataset
    # has been loaded and the environment has been built.
    choices=("cql", "iql", "sac", "bc"),
    help="offline RL algorithm to train",
)
parser.add_argument(
    "--n-steps",
    type=int,
    default=300000,
    help="total number of training steps (algo.fit n_steps)",
)
parser.add_argument("--robot", type=str, default="IIWA")
parser.add_argument("--obj", type=str, default="Dumbbell")
parser.add_argument("--obstacle", type=str, default="None")
parser.add_argument("--objective", type=str, default="Push")
args = parser.parse_args()

# A dataset window must start at a non-negative index and hold at least one
# sample; otherwise the slicing below would silently select garbage or nothing.
if args.nth < 0:
    parser.error("--nth must be non-negative")
if args.num_dataset_samples < 1:
    parser.error("--num-dataset-samples must be at least 1")

robot = args.robot
obj = args.obj
obstacle = args.obstacle
objective = args.objective

# Seed d3rlpy so runs are reproducible for a given --seed.
d3rlpy.seed(args.seed)

# Prepare the evaluation environment and locate the offline expert dataset.
env = composuite.make(robot, obj, obstacle, objective, **GLOBAL_SUBTASK_KWARGS)
path = f"<TO SET>/CompoSuite-offline/expert/{robot}_{obj}_{obstacle}_{objective}/data.hdf5"

# Train on the nth window of --num-dataset-samples consecutive transitions,
# i.e. rows [num_dataset_samples * nth, num_dataset_samples * (nth + 1)).
sample_start = args.num_dataset_samples * args.nth
sample_stop = args.num_dataset_samples * (args.nth + 1)

with h5py.File(path, "r") as dataset_file:
    total_samples = dataset_file["observations"].shape[0]
    # An out-of-range window would otherwise yield an empty MDPDataset.
    if not (0 <= sample_start < total_samples and sample_start < sample_stop):
        raise ValueError(
            f"--nth {args.nth} with --num-dataset-samples "
            f"{args.num_dataset_samples} selects no samples: {path} holds "
            f"{total_samples} transitions"
        )

    # Slice the HDF5 datasets directly so only the selected window is read
    # from disk, instead of loading every array in full and slicing afterwards.
    window = slice(sample_start, sample_stop)
    observations = dataset_file["observations"][window]
    actions = dataset_file["actions"][window]
    rewards = dataset_file["rewards"][window]
    terminals = dataset_file["terminals"][window]
    # The file's "timeouts" flags are handed to d3rlpy as episode_terminals.
    episode_terminals = dataset_file["timeouts"][window]

# The last window may be shorter than requested; report the real end point.
sample_stop = sample_start + len(observations)

# NOTE(review): a window boundary can cut an episode in half; confirm how
# MDPDataset treats a trailing partial episode with no terminal/timeout flag.
dataset = MDPDataset(
    observations=observations,
    actions=actions,
    rewards=rewards,
    terminals=terminals,
    episode_terminals=episode_terminals,
)
print(
    f"selected dataset size: {len(observations)} transitions "
    f"({len(dataset)} episodes) from point {sample_start} to point {sample_stop}"
)

# Build the algorithm named by --algo with d3rlpy's default hyperparameters
# (no constructor arguments). Each entry is a lambda, so only the selected
# class is actually looked up on d3rlpy.algos.
_ALGO_FACTORIES = {
    "cql": lambda: d3rlpy.algos.CQL(),
    "iql": lambda: d3rlpy.algos.IQL(),
    "sac": lambda: d3rlpy.algos.SAC(),
    "bc": lambda: d3rlpy.algos.BC(),
}
_algo_factory = _ALGO_FACTORIES.get(args.algo)
if _algo_factory is None:
    raise ValueError("Invalid algorithm")
algo = _algo_factory()

# Step count of the current episode: zeroed by the patched env.reset() and
# incremented by the patched env.step() (see modified_reset / modified_step).
GLOBAL_STEP_COUNTER: int = 0

def modified_reset(gym_env):
    """Patch ``gym_env.reset`` in place so it returns only the observation.

    The replacement zeroes the module-level ``GLOBAL_STEP_COUNTER`` and then
    calls the original ``reset``, which must return an ``(observation, info)``
    pair. The ``info`` part is dropped, matching the old Gym reset API.

    Args:
        gym_env: environment object whose ``reset`` attribute is replaced.
    """
    wrapped_reset = gym_env.reset

    def reset_wrapper(*args, **kwargs):
        global GLOBAL_STEP_COUNTER
        GLOBAL_STEP_COUNTER = 0

        first_observation, _reset_info = wrapped_reset(*args, **kwargs)
        return first_observation

    gym_env.reset = reset_wrapper


def modified_step(gym_env, horizon=500):
    """Patch ``gym_env.step`` in place to the old Gym 4-tuple API.

    The original ``step`` must return ``(obs, reward, terminated, truncated,
    info)``. The replacement returns ``(obs, reward, done, info)``, where
    ``done = terminated or truncated``. Each call increments the module-level
    ``GLOBAL_STEP_COUNTER``, which the patched reset zeroes. When the env
    reports truncation, or the counter reaches a multiple of ``horizon``, the
    step is treated as truncated: ``info["TimeLimit.truncated"]`` is set to
    ``True`` and ``done`` is ``True``.

    Args:
        gym_env: environment object whose ``step`` attribute is replaced.
        horizon: step count at which an episode is flagged as time-limit
            truncated. The default of 500 matches ``env_horizon`` in
            ``GLOBAL_SUBTASK_KWARGS``.

    Raises:
        ValueError: if ``horizon`` is not positive.
    """
    if horizon < 1:
        raise ValueError(f"horizon must be positive, got {horizon}")

    original_step = gym_env.step

    def step_wrapper(*args, **kwargs):
        global GLOBAL_STEP_COUNTER
        GLOBAL_STEP_COUNTER += 1

        obs, rew, terminated, truncated, info = original_step(*args, **kwargs)

        # The step count is checked as well because the wrapped env may report
        # reaching its horizon as a plain termination rather than a truncation.
        if truncated or GLOBAL_STEP_COUNTER % horizon == 0:
            info["TimeLimit.truncated"] = True
            truncated = True

        # In the old Gym API, ``done`` covers both termination and truncation;
        # previously ``truncated`` was dropped and never ended the episode.
        done = terminated or truncated
        return obs, rew, done, info

    gym_env.step = step_wrapper


# Adapt `env` to the old Gym API used by the evaluation code: reset() returns
# only the observation and step() returns a 4-tuple.
for _install_patch in (modified_reset, modified_step):
    _install_patch(env)

def custom_eval_scorer(
    env,
    n_trials: int = 10,
    epsilon: float = 0.0,
    max_episode_steps: int = 500,
) -> Callable[..., float]:
    """Return a scorer that measures the mean episode return on ``env``.

    The returned callable has the signature ``scorer(algo, *args) -> float``.
    Any extra positional arguments (e.g. evaluation episodes passed by the
    trainer) are ignored. Each trial resets ``env`` and rolls out
    ``algo.predict`` until the environment reports ``done`` or
    ``max_episode_steps`` steps have been taken in that trial. With
    probability ``epsilon``, an ``env.action_space.sample()`` action is used
    instead.

    ``env`` must follow the old Gym API: ``reset()`` returns only the
    observation and ``step()`` returns ``(observation, reward, done, info)``.

    .. code-block:: python

        scorer = custom_eval_scorer(env, n_trials=5)
        mean_episode_return = scorer(algo)

    Args:
        env: gym-styled environment (old 4-tuple ``step`` API).
        n_trials: number of evaluation episodes to average over.
        epsilon: probability of taking a random action (epsilon-greedy noise).
        max_episode_steps: maximum number of steps per trial.

    Returns:
        scorer function returning the mean undiscounted return over
        ``n_trials`` episodes.

    Raises:
        ValueError: if ``n_trials`` is smaller than 1 (the mean would be NaN).
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    def scorer(algo, *args) -> float:
        episode_rewards = []
        for _ in range(n_trials):
            observation = env.reset()
            cur_rewards = 0.0
            # The step budget is per trial. A counter shared across trials
            # would cut every episode after the first one short.
            for _step in range(max_episode_steps):
                if np.random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    with no_grad():
                        action = algo.predict([observation])[0]

                observation, reward, done, _ = env.step(action)
                cur_rewards += reward
                if done:
                    break

            episode_rewards.append(cur_rewards)

        return float(np.mean(episode_rewards))

    return scorer


# Train offline on the selected dataset window. The "environment" scorer rolls
# the policy out once (n_trials=1, no exploration noise) in `env`.
log_dir = (
    f"composuite_logs/{args.algo}/nth_{args.nth}/{args.num_dataset_samples}/"
    f"{robot}_{obj}_{obstacle}_{objective}"
)
eval_scorers = {
    "environment": custom_eval_scorer(env, n_trials=1, epsilon=0.0),
}
algo.fit(
    dataset,
    n_steps=args.n_steps,
    n_steps_per_epoch=1000,
    scorers=eval_scorers,
    # The custom scorer ignores these episodes. NOTE(review): presumably they
    # are passed only so d3rlpy runs the scorers at all — confirm.
    eval_episodes=dataset,
    save_interval=10,
    with_timestamp=False,
    logdir=log_dir,
)
