import gym
import numpy as np
from gym.spaces import Box

from sapien.core import Pose
from mani_skill2.envs.pick_and_place.pick_cube import PickCubeEnv
from mani_skill2.utils.registration import register_env

from .img_sources import make_img_source


@register_env("ReachDistracted-v0", max_episode_steps=50, override=True)
class ReachDistracted(PickCubeEnv):
    """Reaching task derived from ``PickCubeEnv``.

    The goal is pinned to a fixed position, the cube actor is moved to
    ``[-10, -10, -10]``, task completion never ends the episode, and the
    dense reward is a scaled negative distance between the TCP and the goal.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _initialize_task(self):
        """Set the goal to a fixed position and move the goal site there."""
        self.goal_pos = np.array([0.0, 0.0, 0.6])
        self.goal_site.set_pose(Pose(self.goal_pos))

    def _initialize_agent(self):
        """Reset the agent to a noisy variant of a nominal joint configuration.

        Gaussian noise (std 0.2) drawn from the episode RNG is added to every
        entry except the last two, which stay at 0.04.
        NOTE(review): the last two entries are presumably the gripper finger
        joints -- confirm against the robot definition.
        """
        # fmt: off
        # EE at [0.615, 0, 0.17]
        qpos = np.array(
            [0.0, np.pi / 8, 0, -np.pi * 5 / 8, 0, np.pi * 3 / 4, np.pi / 4, 0.04, 0.04]
        )
        # fmt: on
        qpos[:-2] += self._episode_rng.normal(0, 0.2, len(qpos) - 2)
        self.agent.reset(qpos)
        self.agent.robot.set_pose(Pose([-0.615, 0, 0]))

    def _initialize_actors(self):
        """Move the cube to ``[-10, -10, -10]``.

        NOTE(review): presumably this keeps the cube out of the workspace and
        camera view, since this task only involves reaching -- confirm.
        """
        self.obj_init_pose = Pose([-10, -10, -10])
        self.obj.set_pose(self.obj_init_pose)

    def get_done(self, info, **kwargs):
        """Always return ``False`` so task completion never ends the episode."""
        return False

    def compute_dense_reward(self, info, **kwargs):
        """Return ``-5`` times the Euclidean distance from the TCP to the goal."""
        # The norm already yields a non-negative scalar, so it is used directly
        # instead of being passed through a second (no-op) norm.
        tcp_to_goal = np.linalg.norm(self.tcp.pose.p - self.goal_pos)
        return -5 * tcp_to_goal


class ManiSkillDistractedWrapper(gym.Wrapper):
    """Replace black pixels of the base-camera RGB image with video frames.

    The wrapped environment must use the ``"rgbd"`` observation mode. The
    wrapper returns only the base-camera RGB image, transposed to
    channel-first layout, and declares a ``(3, 64, 64)`` ``uint8`` observation
    space.
    NOTE(review): this assumes the base camera renders 64x64 images that match
    the background source's ``img_shape`` -- confirm against the camera config.
    """

    def __init__(self, env):
        super().__init__(env)
        assert env.obs_mode == "rgbd", (
            f"expected obs_mode 'rgbd', got {env.obs_mode!r}"
        )
        self._observation_space = Box(
            low=0, high=255, shape=(3, 64, 64), dtype=np.uint8
        )

        # Source of distractor background frames.
        self._bg_source = make_img_source(
            src_type="video",
            img_shape=(64, 64),
            resource_files="../kinetics-downloader/dataset/train/driving_car/*.mp4",
            total_frames=1000,
            grayscale=False,
        )

    def observation(self, observation):
        """Return the base-camera RGB image with the background swapped in.

        Args:
            observation: Observation dict from the wrapped environment; only
                ``observation["image"]["base_camera"]["rgb"]`` is read.

        Returns:
            A new channel-first array. The input observation is not modified.
        """
        rgb = observation["image"]["base_camera"]["rgb"]
        # Pixels that are exactly zero in all three colour channels are
        # treated as background.
        mask = np.all(rgb[..., :3] == 0, axis=-1)
        bg = self._bg_source.get_image()
        # Blend on a copy so the array owned by the wrapped env's observation
        # dict is left untouched.
        blended = rgb.copy()
        blended[mask] = bg[mask]
        return blended.transpose(2, 0, 1).copy()

    def reset(self, **kwargs):
        """Reset the env (always with ``reconfigure=True``) and the background.

        Extra keyword arguments are forwarded to the wrapped ``reset``. Passing
        ``reconfigure`` explicitly raises ``TypeError`` because it is already
        supplied here.
        """
        obs = self.env.reset(reconfigure=True, **kwargs)
        self._bg_source.reset()
        return self.observation(obs)

    def step(self, action):
        """Step the wrapped env and return the processed image observation."""
        obs, reward, done, info = self.env.step(action)
        return self.observation(obs), reward, done, info
