import os
import time

import gym
import numpy as np
from gym.spaces import Dict, Discrete

from painting_locobot_env.resources.Locobot_interface import LoCoBotInterface

class PaintingLocobotEnv(gym.Env):
    """Goal-conditioned Gym environment backed by a ``LoCoBotInterface``.

    Actions are indices in ``Discrete(5)`` and are forwarded unchanged to
    ``LoCoBotInterface.step``. Observations are dicts (see ``_get_obs``) built
    from the robot's ``current_pos`` and a fixed goal state. ``step`` always
    returns a reward of 0 and ``done=False``. The success and shaped-distance
    hooks are currently stubs that return 0.
    """

    # Directory containing the pre-recorded goal arrays. This is an absolute,
    # machine-specific path; pass ``goal_dir`` to ``init_goals`` to override it.
    GOAL_DIR = "/home/locobot/Desktop/new_interface/goalrelabel_locobot_fullgcsl/gcsl/envs/locobot/goals/"
    # The single fixed goal used by every observation and by ``sample_goal``.
    GOAL_STATE = (4, 2)

    def __init__(self):
        """Build the spaces, create the robot interface and cache a first image.

        Note: this constructs ``LoCoBotInterface``, fetches camera images and
        sleeps for 2 seconds, so it is slow and needs the robot stack available.
        """
        self.action_space = Discrete(5)

        # 2-D state/goal space bounded to [0, 10] per coordinate. (An earlier
        # version of this environment used 128x128x3 image observations.)
        self.obs_space = gym.spaces.box.Box(
            low=np.array([0, 0], dtype=np.float32),
            high=np.array([10, 10], dtype=np.float32))

        self.state_space = self.obs_space

        # Every key of the goal-conditioned observation dict shares one space.
        self.observation_space = Dict([
            ('observation', self.obs_space),
            ('desired_goal', self.obs_space),
            ('achieved_goal', self.obs_space),
            ('state_observation', self.obs_space),
            ('state_desired_goal', self.obs_space),
            ('state_achieved_goal', self.obs_space),
        ])

        self.locobot = LoCoBotInterface()

        self.init_goals()

        # Fetching ROS images is slow, so the latest frame is cached in
        # ``self.current_image`` to minimize calls. The first frame is
        # discarded and followed by a pause -- presumably so the camera stream
        # can settle before the real capture (TODO confirm).
        self.locobot.get_image_rgb()
        time.sleep(2)
        self.current_state = None
        self.update_state()

        # Success distance threshold. NOTE(review): not yet used by
        # ``compute_success``, which is still a stub.
        self.threshold = 0.1

        self.timestep = 0

    def init_goals(self, goal_dir=None):
        """Set the fixed goal state and load the goal image from disk.

        Args:
            goal_dir: directory containing ``goal_U_image.npy``. Defaults to
                ``GOAL_DIR``.
        """
        self.goal_state = np.array(self.GOAL_STATE)
        if goal_dir is None:
            goal_dir = self.GOAL_DIR
        self.goal_image = np.load(os.path.join(goal_dir, "goal_U_image.npy"))

    def step(self, action):
        """Send ``action`` to the robot, refresh the cached image and advance time.

        Returns:
            ``(obs, reward, done, info)``: ``reward`` is always 0, ``done`` is
            always False and ``info`` is an empty dict.
        """
        self.locobot.step(action)
        self.update_state()
        self.timestep += 1
        reward = 0
        done = False
        return self._get_obs(), reward, done, {}

    def go_rest(self):
        """Forward to ``LoCoBotInterface.go_rest``."""
        self.locobot.go_rest()

    def update_state(self):
        """Refresh the cached camera frame from ``LoCoBotInterface.get_image_rgb``."""
        self.current_image = self.locobot.get_image_rgb()

    def reset(self):
        """Zero the step counter, reset the robot and return the first observation.

        The interface is driven through ``go_rest``, ``reset`` and ``go_start``
        in that order before the cached image is refreshed.
        """
        self.timestep = 0
        self.locobot.go_rest()
        self.locobot.reset()
        self.locobot.go_start()

        self.update_state()
        return self._get_obs()

    def render(self, mode="rgb_array"):
        """Return the most recently cached camera image.

        ``mode`` is accepted for compatibility with the classic Gym
        ``render(mode=...)`` calling convention; it is otherwise ignored.
        """
        return self.current_image

    def render_image(self):
        """Return the most recently cached camera image."""
        return self.current_image

    def seed(self, seed=None):
        """Seed ``self.np_random`` via Gym's seeding utility; return ``[seed]``."""
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

    def interact(self):
        """No-op hook."""
        return

    def observation(self, state):
        """Identity mapping: the state is already the observation."""
        return state

    def extract_goal(self, state):
        """Identity mapping: the state is already the goal representation."""
        return state

    def _get_obs(self):
        """Build the goal-conditioned observation dict.

        Observation and achieved-goal entries come from
        ``self.locobot.current_pos``. Desired-goal entries come from
        ``self.goal_state``. Every entry is an independent array, so callers
        may mutate the returned values without affecting the environment's
        goal.
        """
        state_obs = np.array(self.locobot.current_pos)
        # Copy the goal: handing out ``self.goal_state`` itself would let a
        # caller mutate the environment's goal in place.
        intended_state_goal = self.goal_state.copy()
        achieved_state_goal = state_obs.copy()

        return dict(
            observation=state_obs.copy(),
            desired_goal=intended_state_goal.copy(),
            achieved_goal=achieved_state_goal.copy(),
            state_observation=state_obs,
            state_desired_goal=intended_state_goal,
            state_achieved_goal=achieved_state_goal,
        )

    def get_shaped_distance(self, state1, state2):
        """Stub shaped distance: always 0."""
        return 0

    def compute_shaped_distance(self, state1, state2):
        """Delegate to ``get_shaped_distance``."""
        return self.get_shaped_distance(state1, state2)

    def sample_goal(self):
        """Return a fresh copy of the (single, fixed) goal state."""
        return self.goal_state.copy()

    def plot_trajectories(self):
        """No-op hook."""
        return

    def compute_success(self, state, goal):
        """Stub success check: always 0. ``self.threshold`` is not consulted yet."""
        return 0