import os

from gym import error, spaces
from gym.utils import seeding
import numpy as np
from os import path
import gym

try:
    import mujoco_py
except ImportError as e:
    raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e))

from d4rl.kitchen.adept_envs.simulation.sim_robot import MujocoSimRobot, RenderMode



class MujocoEnv(gym.Env):
    """
    This is a simplified version of the gym MujocoEnv class.

    Some differences are:
     - Do not automatically set the observation/action space (unless
       ``automatically_set_spaces=True`` is passed to the constructor).
     - The simulation (``self.sim`` / ``self.data``) comes from a
       ``MujocoSimRobot`` built with ``use_dm_backend=True``, while
       ``self.model`` is a separately loaded ``mujoco_py`` model that is used
       for sizes (``nq``, ``nv``, ``nu``), the timestep and actuator ranges.

    Subclasses implement ``reset_model`` (and ``step`` when
    ``automatically_set_spaces`` is used).
    """
    def __init__(self, model_path, frame_skip, device_id=-1, automatically_set_spaces=False):
        """Load the model, build the simulation and optionally infer spaces.

        Args:
            model_path: Absolute path to the model file, or a path relative to
                the ``assets`` directory next to this module.
            frame_skip: Default number of simulator steps per
                ``do_simulation`` call; also scales ``dt``.
            device_id: Device id passed to the offscreen render context in
                ``initialize_camera``. ``-1`` means "use the ``gpu_id``
                environment variable if it is set".
            automatically_set_spaces: If True, take one zero-action ``step`` to
                size the observation vector and set ``action_space`` and
                ``observation_space``.

        Raises:
            IOError: If the resolved model file does not exist.
        """
        # isabs() is equivalent to startswith("/") on POSIX and also handles
        # drive-letter paths on Windows.
        if os.path.isabs(model_path):
            fullpath = model_path
        else:
            fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
        if not path.exists(fullpath):
            raise IOError("File %s does not exist" % fullpath)
        self.frame_skip = frame_skip
        # mujoco_py model: provides nq/nv/nu, opt.timestep and actuator ranges.
        self.model = mujoco_py.load_model_from_path(fullpath)
        self.sim_robot = MujocoSimRobot(
            fullpath,
            use_dm_backend=True,
            camera_settings={}
        )
        self.sim = self.sim_robot.sim
        self.data = self.sim.data
        self.viewer = None

        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }
        if device_id == -1 and 'gpu_id' in os.environ:
            device_id = int(os.environ['gpu_id'])
        self.device_id = device_id
        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()

        # Seed before the probing step() below: a subclass step() that samples
        # from self.np_random would otherwise fail with a missing attribute.
        self.seed()

        if automatically_set_spaces:
            observation, _reward, done, _info = self.step(np.zeros(self.model.nu))
            assert not done
            self.obs_dim = observation.size

            bounds = self.model.actuator_ctrlrange.copy()
            low = bounds[:, 0]
            high = bounds[:, 1]
            self.action_space = spaces.Box(low=low, high=high)

            high = np.inf*np.ones(self.obs_dim)
            low = -high
            self.observation_space = spaces.Box(low, high)

    def seed(self, seed=None):
        """Create ``self.np_random`` from ``seed`` and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    # methods to override:
    # ----------------------------

    def reset_model(self):
        """
        Reset the robot degrees of freedom (qpos and qvel).
        Implement this in each subclass.
        """
        raise NotImplementedError

    def viewer_setup(self):
        """
        This method is called when the viewer is initialized and, if a viewer
        exists, after every reset.
        Optionally implement this method if you need to tinker with the camera
        position and so forth.
        """
        pass

    # -----------------------------

    def reset(self):
        """Reset the simulation, then the subclass state, and return its observation."""
        self.sim.reset()
        ob = self.reset_model()
        if self.viewer is not None:
            self.viewer_setup()
        return ob

    def set_state(self, qpos, qvel):
        """Overwrite joint positions/velocities and recompute derived quantities.

        Args:
            qpos: Array of shape ``(self.model.nq,)``.
            qvel: Array of shape ``(self.model.nv,)``.
        """
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        # The backend sim is given a single flat [qpos, qvel] vector.
        # NOTE(review): if a model has actuator activations (na > 0) the
        # backend may expect them appended too — confirm for new models.
        state = np.concatenate([qpos, qvel])
        self.sim.set_state(state)
        self.sim.forward()

    @property
    def dt(self):
        """Simulated seconds per ``do_simulation`` call with the default frame skip."""
        return self.model.opt.timestep * self.frame_skip

    def do_simulation(self, ctrl, n_frames=None):
        """Apply ``ctrl`` (when given) and advance the simulation.

        Args:
            ctrl: Control vector written into ``sim.data.ctrl``; ``None``
                keeps the current controls.
            n_frames: Number of simulator steps; defaults to ``frame_skip``.
        """
        if n_frames is None:
            n_frames = self.frame_skip
        if self.sim.data.ctrl is not None and ctrl is not None:
            self.sim.data.ctrl[:] = ctrl
        for _ in range(n_frames):
            self.sim.step()

    def render(
        self,
        mode="human",
        width=64,
        height=64,
        camera_id=-1,
    ):
        """Renders the environment.

        Args:
            mode: The type of rendering to use.
                - 'human': Renders to a graphical window.
                - 'rgb_array': Returns the RGB image as an np.ndarray.
                - 'depth_array': Returns the depth image as an np.ndarray.
            width: The width of the rendered image. This only affects offscreen
                rendering.
            height: The height of the rendered image. This only affects
                offscreen rendering.
            camera_id: The ID of the camera to use. By default, this is the free
                camera. If specified, only affects offscreen rendering.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if mode == "human":
            self.sim_robot.renderer.render_to_window()
        elif mode == "rgb_array":
            assert width and height
            return self.sim_robot.renderer.render_offscreen(
                width, height, mode=RenderMode.RGB, camera_id=camera_id
            )
        elif mode == "depth_array":
            assert width and height
            return self.sim_robot.renderer.render_offscreen(
                width, height, mode=RenderMode.DEPTH, camera_id=camera_id
            )
        else:
            raise NotImplementedError(mode)

    def close(self):
        """Release the legacy viewer and the sim robot's rendering resources.

        Safe to call more than once, and safe if ``__init__`` did not finish.
        """
        viewer = getattr(self, "viewer", None)
        if viewer is not None:
            # mujoco_py >= 1.5 viewers have no ``finish`` method (gym dropped
            # this call), so only invoke it when the viewer provides it.
            finish = getattr(viewer, "finish", None)
            if callable(finish):
                finish()
            self.viewer = None
        # Close the window/renderer owned by the sim robot, if it supports it.
        robot_close = getattr(getattr(self, "sim_robot", None), "close", None)
        if callable(robot_close):
            robot_close()

    def _get_viewer(self):
        """Lazily create the legacy ``mujoco_py`` viewer and run ``viewer_setup``."""
        # NOTE(review): MjViewer expects a mujoco_py MjSim, but self.sim comes
        # from a dm-backend MujocoSimRobot — confirm this path still works.
        if self.viewer is None:
            self.viewer = mujoco_py.MjViewer(self.sim)
            self.viewer_setup()
        return self.viewer

    def get_body_com(self, body_name):
        """Return the world position of the body named ``body_name``."""
        return self.data.get_body_xpos(body_name)

    def state_vector(self):
        """Return ``qpos`` and ``qvel`` concatenated into one flat array."""
        return np.concatenate([
            self.sim.data.qpos.flat,
            self.sim.data.qvel.flat
        ])

    def get_image(self, width=84, height=84, camera_name=None):
        """Render an image directly through ``self.sim.render``."""
        # NOTE(review): ``camera_name=`` is the mujoco_py MjSim.render keyword;
        # a dm_control Physics.render takes ``camera_id`` instead — confirm
        # which backend self.sim is before relying on this method.
        return self.sim.render(
            width=width,
            height=height,
            camera_name=camera_name,
        )

    def initialize_camera(self, init_fctn):
        """Attach an offscreen render context and let ``init_fctn`` configure its camera.

        Args:
            init_fctn: Callable that receives the render context's ``cam``.
        """
        sim = self.sim
        # NOTE(review): MjRenderContextOffscreen is a mujoco_py API; confirm it
        # accepts the dm-backend sim created in __init__.
        viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=self.device_id)
        init_fctn(viewer.cam)
        sim.add_render_context(viewer)
