import os
from random import random, uniform
os.environ['MUJOCO_GL'] = 'egl'
os.environ['EGL_DEVICE_ID'] = '0'


import copy
import pickle as pkl
import sys
import time

import numpy as np
from queue import Queue
import hydra
import torch
import utils
from logger import Logger
from replay_buffer import ReplayBuffer, HindsightExperienceReplayWrapperVer2
from video import VideoRecorder
import matplotlib.pyplot as plt
import seaborn as sns
from hgg.hgg import goal_distance
from visualize.visualize_2d import *
torch.backends.cudnn.benchmark = True

class UniformFeasibleGoalSampler:
    """Draw goals uniformly from an environment's context box.

    Samples can optionally be restricted to goals that ``is_feasible`` accepts.
    ``is_feasible`` is a hand-written description of each environment's layout
    and is only used for visualization / uniform-goal evaluation in this work.

    Args:
        env_name: name of a supported environment (a key of ``_CONTEXT_BOUNDS``).

    Raises:
        NotImplementedError: if ``env_name`` has no registered bounds.
    """

    # env name -> (lower corner, upper corner) of the uniform sampling box.
    _CONTEXT_BOUNDS = {
        'sawyer_peg_pick_and_place': ((-0.6, 0.2, 0.01478), (0.6, 1.0, 0.4)),
        'sawyer_peg_push': ((-0.6, 0.2, 0.01478), (0.6, 1.0, 0.02)),
        'Point4WayComplexVer2Maze-v0': ((-18, -18), (18, 18)),
        'Point4WayFarmlandMaze-v0': ((-18, -18), (18, 18)),
        'AntMazeComplex2Way-v0': ((-6, -10), (6, 10)),
        'Point2WaySpiralMaze-v0': ((-14, -18), (14, 18)),
    }

    # Environments whose feasible region is simply the open sampling box
    # (these also use no margin).
    _BOX_ONLY_ENVS = ('sawyer_peg_pick_and_place', 'sawyer_peg_push')

    # The four square blocks shared by the 4-way mazes, as
    # (x_lo, x_hi, y_lo, y_hi) open rectangles.
    _QUADRANT_BLOCKS = (
        (-14.5, -1.5, -14.5, -1.5),
        (1.5, 14.5, -14.5, -1.5),
        (1.5, 14.5, 1.5, 14.5),
        (-14.5, -1.5, 1.5, 14.5),
    )

    # env name -> ((|x| limit, |y| limit), wall rectangles).  A point is
    # infeasible when a coordinate exceeds its limit or when it lies strictly
    # inside any (x_lo, x_hi, y_lo, y_hi) rectangle.
    _MAZE_LAYOUTS = {
        'Point4WayComplexVer2Maze-v0': ((17.5, 17.5), _QUADRANT_BLOCKS + (
            # "last corners" -- NOTE(review): each of these rectangles lies
            # inside one of the quadrant blocks above, so they never change
            # the result; confirm against the actual Ver2 maze layout.
            (-6.5, -1.5, -14.5, -1.5),
            (1.5, 14.5, -6.5, -1.5),
            (1.5, 6.5, 1.5, 14.5),
            (-14.5, -1.5, 1.5, 6.5),
        )),
        'Point4WayFarmlandMaze-v0': ((17.5, 17.5), _QUADRANT_BLOCKS),
        'Point2WaySpiralMaze-v0': ((13.5, 17.5), (
            (-13.5, 10.5, -14.5, -9.5),
            (5.5, 10.5, -14.5, 6.5),
            (-2.5, 10.5, 1.5, 6.5),
            (-10.5, 13.5, 9.5, 14.5),
            (-10.5, -5.5, -6.5, 14.5),
            (-10.5, 2.5, -6.5, -1.5),
        )),
        'AntMazeComplex2Way-v0': ((5.5, 9.5), (
            (-5.5, 2.5, -6.5, -1.5),
            (-2.5, 5.5, 1.5, 6.5),
        )),
    }

    def __init__(self, env_name):
        self.env_name = env_name
        if env_name not in self._CONTEXT_BOUNDS:
            raise NotImplementedError
        lower, upper = self._CONTEXT_BOUNDS[env_name]
        # Fresh arrays per instance so callers may mutate them safely.
        self.LOWER_CONTEXT_BOUNDS = np.array(lower)
        self.UPPER_CONTEXT_BOUNDS = np.array(upper)
        self.margin = None if env_name in self._BOX_ONLY_ENVS else 0.5

    # only for visualization in this work
    def is_feasible(self, context):
        """Return True when ``context`` lies in the free space of the env.

        Raises:
            NotImplementedError: if ``self.env_name`` has no layout.
        """
        name = self.env_name
        if name in self._BOX_ONLY_ENVS:
            inside = np.logical_and(self.LOWER_CONTEXT_BOUNDS < context,
                                    context < self.UPPER_CONTEXT_BOUNDS)
            return True if np.all(inside) else False

        if name not in self._MAZE_LAYOUTS:
            raise NotImplementedError
        (x_lim, y_lim), walls = self._MAZE_LAYOUTS[name]

        # Check x before touching context[1], mirroring short-circuit order.
        x = context[0]
        if x < -x_lim or x > x_lim:
            return False
        y = context[1]
        if y < -y_lim or y > y_lim:
            return False

        for x_lo, x_hi, y_lo, y_hi in walls:
            if np.all((np.logical_and(x_lo < x, x < x_hi),
                       np.logical_and(y_lo < y, y < y_hi))):
                return False
        return True

    def sample(self, num_sample=1, sample_feasible=True):
        """Draw ``num_sample`` uniform goals; shape ``[dim]`` when ``num_sample == 1``.

        With ``sample_feasible`` (only supported for a single sample), redraw
        until ``is_feasible`` accepts the goal.
        """
        low, high = self.LOWER_CONTEXT_BOUNDS, self.UPPER_CONTEXT_BOUNDS
        shape = (num_sample, low.shape[-1])

        def draw():
            # squeeze -> [dim] for one sample, [num_sample, dim] otherwise
            return np.random.uniform(low, high, size=shape).squeeze()

        candidate = draw()
        if sample_feasible:
            assert num_sample==1, 'In current implementation, better to use for loop for multiple feasible samples'
            while not self.is_feasible(candidate):
                candidate = draw()
        return candidate

    

def get_object_states_only_from_goal(env_name, goal):
    """Slice the object-state part (indices 4:7 of the last axis) out of ``goal``.

    Only ``'sawyer_door'`` and ``'sawyer_peg'`` are supported. Any other
    environment name, including ``'tabletop_manipulation'``, raises
    ``NotImplementedError``.
    """
    supported = ('sawyer_door', 'sawyer_peg')
    if env_name not in supported:
        raise NotImplementedError
    object_slice = slice(4, 7)
    return goal[..., object_slice]

def get_original_final_goal(env_name, env=None):
    """Return a fresh copy of the default (non-noisy) final goal of a sawyer env.

    Both supported envs share the same (x, y) = (-0.3, 0.4). Only the height
    differs: 0.02 for push and 0.2 for pick-and-place. ``env`` is accepted
    for interface compatibility and is not used.

    Raises:
        NotImplementedError: for any other ``env_name``.
    """
    if env_name == 'sawyer_peg_push':
        height = 0.02
    elif env_name == 'sawyer_peg_pick_and_place':
        height = 0.2
    else:
        raise NotImplementedError
    # A newly built array, so callers can mutate the result freely.
    return np.array([-0.3, 0.4, height])

def _tile_goals_with_noise(targets, num_goals, scale, frozen_dims=()):
    """Tile ``targets`` up to ``num_goals`` rows and add zero-mean Gaussian noise.

    Args:
        targets: ``[num_targets, goal_dim]`` target goals, repeated in order.
        num_goals: total number of rows; must be a multiple of ``num_targets``.
        scale: standard deviation of the noise added to every entry.
        frozen_dims: goal dimensions (columns) that receive no noise.

    Returns:
        ``[num_goals, goal_dim]`` float array.
    """
    targets = np.array(targets, dtype=float)
    num_targets = len(targets)
    assert num_goals % num_targets == 0
    goals = np.tile(targets, (num_goals // num_targets, 1))
    noise = np.random.normal(loc=np.zeros_like(goals), scale=scale * np.ones_like(goals))
    for dim in frozen_dims:
        # Zero a column (one coordinate of every goal), not a row:
        # noise has shape [num_goals, goal_dim].
        noise[:, dim] = 0
    return goals + noise


def get_final_goals_w_noise(env_name, num_goals, env=None):
    """Sample ``num_goals`` noisy copies of the environment's final goal(s).

    For multi-target settings the target goals are tiled in order, so
    ``num_goals`` must be a multiple of the number of targets. There are
    3 targets for multi-target sawyer, 4 for the 4-way mazes and 2 for the
    2-way mazes.

    Args:
        env_name: environment name.
        num_goals: number of goals to return.
        env: only read for the sawyer envs. Its ``multi_target`` attribute
            selects between one and three target goals.

    Returns:
        ``(final_goal_states, None)``. ``final_goal_states`` has shape
        ``[num_goals, goal_dim]``; the second element is always ``None``.

    Raises:
        NotImplementedError: for an unsupported ``env_name``.
        AssertionError: if ``num_goals`` is not a multiple of the target count.
    """
    if env_name in ['sawyer_peg_push', 'sawyer_peg_pick_and_place']:
        z = 0.02 if env_name == 'sawyer_peg_push' else 0.2
        if env.multi_target:
            targets = [[-0.3, 0.4, z], [-0.3, 0.8, z], [0.4, 0.4, z]]
        else:
            targets = [[-0.3, 0.4, z]]
        # Push goals keep z fixed at 0.02 (the push sampling box tops out at
        # z=0.02), so no noise is applied to coordinate 2 for push.
        frozen_dims = (2,) if env_name == 'sawyer_peg_push' else ()
        final_goal_states = _tile_goals_with_noise(targets, num_goals, 0.05, frozen_dims)
    # if multi target goals, outputs are multiplied by the number of target goals
    elif env_name in ['Point4WayComplexMaze-v0', 'Point4WayFarmlandMaze-v0']:
        final_goal_states = _tile_goals_with_noise(
            [[16., 16.], [-16., -16.], [16., -16.], [-16., 16.]], num_goals, 0.5)
    elif env_name == 'Point4WayComplexVer2Maze-v0':
        final_goal_states = _tile_goals_with_noise(
            [[8., 16.], [-8., -16.], [16., -8.], [-16., 8.]], num_goals, 0.5)
    elif env_name in ['Point2WaySpiralMaze-v0']:
        final_goal_states = _tile_goals_with_noise(
            [[12., 16.], [-12., -16.]], num_goals, 0.5)
    elif env_name in ['AntMazeComplex2Way-v0']:
        final_goal_states = _tile_goals_with_noise(
            [[4., 8.], [-4., -8.]], num_goals, 0.5)
    else:
        raise NotImplementedError

    return final_goal_states, None


# Per-environment settings, looked up with ``cfg.env`` when the Workspace is
# built.

# Episode horizon, passed to the TimeLimit wrappers via cfg.max_episode_timesteps.
max_episode_timesteps_dict = {
    'sawyer_peg_pick_and_place': 200,
    'sawyer_peg_push': 200,
    'Point4WayComplexVer2Maze-v0': 100,
    'Point4WayFarmlandMaze-v0': 100,
    'Point2WaySpiralMaze-v0': 100,
    'AntMazeComplex2Way-v0': 300,
}

# Copied into cfg.num_seed_steps.
num_seed_steps_dict = {
    'sawyer_peg_pick_and_place': 2000,
    'sawyer_peg_push': 2000,
    'Point4WayComplexVer2Maze-v0': 2000,
    'Point2WaySpiralMaze-v0': 2000,
    'Point4WayFarmlandMaze-v0': 2000,
    'AntMazeComplex2Way-v0': 4000,
}

# Copied into cfg.num_random_steps.
num_random_steps_dict = {
    'sawyer_peg_pick_and_place': 2000,
    'sawyer_peg_push': 2000,
    'Point4WayComplexVer2Maze-v0': 2000,
    'Point4WayFarmlandMaze-v0': 2000,
    'Point2WaySpiralMaze-v0': 2000,
    'AntMazeComplex2Way-v0': 4000,
}

# Copied into cfg.randomwalk_random_noise.
randomwalk_random_noise_dict = {
    'sawyer_peg_pick_and_place': 0.1,
    'sawyer_peg_push': 0.1,
    'Point4WayComplexVer2Maze-v0': 2.5,
    'Point4WayFarmlandMaze-v0': 2.5,
    'Point2WaySpiralMaze-v0': 2.5,
    'AntMazeComplex2Way-v0': 2.5,
}

# Copied into cfg.d2c_kwargs.noise_scale.
d2c_noise_scale_dict = {
    'sawyer_peg_pick_and_place': 0.025,
    'sawyer_peg_push': 0.025,
    'Point4WayComplexVer2Maze-v0': 0.5,
    'Point4WayFarmlandMaze-v0': 0.5,
    'Point2WaySpiralMaze-v0': 0.5,
    'AntMazeComplex2Way-v0': 1.0,
}

# Number of curriculum goals: 20 per target goal. The multiplier is the
# number of final target goals used by get_final_goals_w_noise (3 sawyer
# multi-target, 4 for the 4-way mazes, 2 for the 2-way mazes).
num_curriculum_dict = {
    'sawyer_peg_pick_and_place': 20 * 3,
    'sawyer_peg_push': 20 * 3,
    'Point4WayComplexVer2Maze-v0': 20 * 4,
    'Point4WayFarmlandMaze-v0': 20 * 4,
    'Point2WaySpiralMaze-v0': 20 * 2,
    'AntMazeComplex2Way-v0': 20 * 2,
}



class Workspace(object):
    def __init__(self, cfg):
        """Build the envs, agent, replay buffers and (optionally) the HGG curriculum.

        Mutates ``cfg`` in place. Per-environment settings are filled in from
        the module-level ``*_dict`` tables, and agent/network dimensions are
        derived from the constructed environment.

        Args:
            cfg: hydra config for the run.

        Raises:
            ValueError: if ``cfg.sparse_reward_type`` is neither ``'negative'``
                nor ``'positive'``.
            NotImplementedError: if ``cfg.env`` is not a sawyer env and
                ``cfg.goal_env`` is false, so no env-construction path applies.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')
        self.model_dir = utils.make_dir(self.work_dir, 'model')
        self.buffer_dir = utils.make_dir(self.work_dir, 'buffer')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency_step,
                             action_repeat=cfg.action_repeat,
                             agent='D2C_rl')

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        # Per-environment settings from the module-level tables.
        cfg.max_episode_timesteps = max_episode_timesteps_dict[cfg.env]
        cfg.num_seed_steps = num_seed_steps_dict[cfg.env]
        cfg.num_random_steps = num_random_steps_dict[cfg.env]
        cfg.randomwalk_random_noise = randomwalk_random_noise_dict[cfg.env]

        assert cfg.done_on_success
        assert cfg.rl_reward_type in ['sparse', 'd2c']

        if cfg.env in ['sawyer_peg_push', 'sawyer_peg_pick_and_place']:
            cfg.goal_env = False
            multi_target = bool(cfg.multi_target)
            from envs import sawyer_peg_pick_and_place, sawyer_peg_push, sawyer_peg_pick_and_place_wall, sawyer_peg_push_wall
            if cfg.env == 'sawyer_peg_pick_and_place':
                if cfg.sawyer_wall_env:
                    env = sawyer_peg_pick_and_place_wall.SawyerPegPickAndPlaceWallV2(reward_type='sparse', multi_target=multi_target)
                    eval_env = sawyer_peg_pick_and_place_wall.SawyerPegPickAndPlaceWallV2(reward_type='sparse', multi_target=multi_target)
                else:
                    env = sawyer_peg_pick_and_place.SawyerPegPickAndPlaceV2(reward_type='sparse', multi_target=multi_target)
                    eval_env = sawyer_peg_pick_and_place.SawyerPegPickAndPlaceV2(reward_type='sparse', multi_target=multi_target)
            elif cfg.env == 'sawyer_peg_push':
                if cfg.sawyer_wall_env:
                    env = sawyer_peg_push_wall.SawyerPegPushWallV2(reward_type='sparse', close_gripper=False, multi_target=multi_target)
                    eval_env = sawyer_peg_push_wall.SawyerPegPushWallV2(reward_type='sparse', close_gripper=False, multi_target=multi_target)
                else:
                    env = sawyer_peg_push.SawyerPegPushV2(reward_type='sparse', close_gripper=False, multi_target=multi_target)
                    eval_env = sawyer_peg_push.SawyerPegPushV2(reward_type='sparse', close_gripper=False, multi_target=multi_target)

            from gym.wrappers.time_limit import TimeLimit
            env = TimeLimit(env, max_episode_steps=cfg.max_episode_timesteps)
            eval_env = TimeLimit(eval_env, max_episode_steps=cfg.max_episode_timesteps)

            if cfg.use_residual_randomwalk:
                from env_utils import ResidualGoalWrapper
                env = ResidualGoalWrapper(env, env_name=cfg.env)
                eval_env = ResidualGoalWrapper(eval_env, env_name=cfg.env)

            # NOTE(review): these offsets differ from the goal-env branch below
            # for the same sparse_reward_type -- presumably the base sparse
            # rewards of the two env families differ; confirm.
            if cfg.sparse_reward_type == 'negative':
                reward_offset = -1.0
            elif cfg.sparse_reward_type == 'positive':
                reward_offset = 0.0
            else:
                raise ValueError(f'Unknown sparse_reward_type: {cfg.sparse_reward_type!r}')

            from env_utils import StateWrapper, DoneOnSuccessWrapper
            if cfg.done_on_success:
                # Goals are treated as relative only when the residual
                # random-walk wrapper is active.
                relative_goal_env = bool(cfg.use_residual_randomwalk)
                env = DoneOnSuccessWrapper(env, relative_goal_env=relative_goal_env, reward_offset=reward_offset, earl_env=False)
                eval_env = DoneOnSuccessWrapper(eval_env, relative_goal_env=relative_goal_env, reward_offset=reward_offset, earl_env=False)

            from env_utils import WraptoGoalEnv
            self.env = StateWrapper(WraptoGoalEnv(env, env_name=cfg.env))
            self.eval_env = StateWrapper(WraptoGoalEnv(eval_env, env_name=cfg.env))

            obs_spec = self.env.observation_spec()
            action_spec = self.env.action_spec()

        elif cfg.goal_env:  # e.g. Fetch, Ant
            import gym
            from env_utils import StateWrapper, HERGoalEnvWrapper, DoneOnSuccessWrapper, ResidualGoalWrapper
            if cfg.env in ['AntMazeComplex2Way-v0']:
                from gym.wrappers.time_limit import TimeLimit
                from envs.AntEnv.envs.antenv import EnvWithGoal
                from envs.AntEnv.envs.antenv.create_maze_env import create_maze_env
                self.env = TimeLimit(EnvWithGoal(create_maze_env(cfg.env, cfg.seed, env_path=cfg.env_path), cfg.env), max_episode_steps=cfg.max_episode_timesteps)
                self.eval_env = TimeLimit(EnvWithGoal(create_maze_env(cfg.env, cfg.seed, env_path=cfg.env_path), cfg.env), max_episode_steps=cfg.max_episode_timesteps)

                self.env.set_attribute(evaluate=False, distance_threshold=1.0, horizon=cfg.max_episode_timesteps, early_stop=False)
                self.eval_env.set_attribute(evaluate=True, distance_threshold=1.0, horizon=cfg.max_episode_timesteps, early_stop=False)

                if cfg.use_residual_randomwalk:
                    self.env = ResidualGoalWrapper(self.env, env_name=cfg.env)
                    self.eval_env = ResidualGoalWrapper(self.eval_env, env_name=cfg.env)
            elif cfg.env in ['Point2WaySpiralMaze-v0', 'Point4WayComplexVer2Maze-v0', 'Point4WayFarmlandMaze-v0']:
                from gym.wrappers.time_limit import TimeLimit
                # Imported only for its side effects; presumably it registers
                # the maze envs with gym so gym.make can find them.
                import mujoco_maze
                self.env = TimeLimit(gym.make(cfg.env), max_episode_steps=cfg.max_episode_timesteps)
                self.eval_env = TimeLimit(gym.make(cfg.env), max_episode_steps=cfg.max_episode_timesteps)

                if cfg.use_residual_randomwalk:
                    self.env = ResidualGoalWrapper(self.env, env_name=cfg.env)
                    self.eval_env = ResidualGoalWrapper(self.eval_env, env_name=cfg.env)
            else:
                self.env = gym.make(cfg.env)
                self.eval_env = gym.make(cfg.env)

            if cfg.sparse_reward_type == 'negative':
                reward_offset = 0.0
            elif cfg.sparse_reward_type == 'positive':
                reward_offset = 1.0
            else:
                raise ValueError(f'Unknown sparse_reward_type: {cfg.sparse_reward_type!r}')

            if cfg.done_on_success:
                # Goals are treated as relative only when the residual
                # random-walk wrapper is active.
                relative_goal_env = bool(cfg.use_residual_randomwalk)
                self.env = DoneOnSuccessWrapper(self.env, reward_offset=reward_offset, relative_goal_env=relative_goal_env)
                self.eval_env = DoneOnSuccessWrapper(self.eval_env, reward_offset=reward_offset, relative_goal_env=relative_goal_env)

            self.env = StateWrapper(HERGoalEnvWrapper(self.env, env_name=cfg.env))
            self.eval_env = StateWrapper(HERGoalEnvWrapper(self.eval_env, env_name=cfg.env))

            obs_spec = self.env.observation_spec()
            action_spec = self.env.action_spec()

        else:
            # Previously this fell through and crashed with an unbound
            # `action_spec`; fail early with a clear message instead.
            raise NotImplementedError(
                f'Unsupported env {cfg.env!r}: expected a sawyer env or cfg.goal_env=True')

        cfg.agent.action_shape = action_spec.shape
        cfg.agent.action_range = [
            float(action_spec.low.min()),
            float(action_spec.high.max())
        ]

        self.max_episode_timesteps = cfg.max_episode_timesteps
        if cfg.use_d2c:
            # With goal conditioning the D2C feature is the concatenation of two goals.
            if cfg.d2c_kwargs.goal_condition:
                cfg.d2c_cfg.feature_dim = self.env.goal_dim*2
            else:
                cfg.d2c_cfg.feature_dim = self.env.goal_dim
            assert cfg.use_uncertainty_for_randomwalk == 'd2c'
            assert 'd2c' in cfg.hgg_cost_type
        else:
            # NOTE : just temporary code for running default hgg
            cfg.d2c_cfg = None
        cfg.d2c_kwargs.noise_scale = d2c_noise_scale_dict[cfg.env]

        # Hard-coded: overrides whatever value the config provides.
        cfg.agent.critic_target_tau = 0.01

        if cfg.env in ['sawyer_peg_push', 'sawyer_peg_pick_and_place']:
            cfg.critic.repr_dim = self.env.obs_dim + self.env.goal_dim  # [obs(ag), dg]
            cfg.actor.repr_dim = self.env.obs_dim + self.env.goal_dim  # [obs(ag), dg]
        else:
            cfg.critic.repr_dim = self.env.obs_dim + self.env.goal_dim*2  # [obs, ag, dg]
            cfg.actor.repr_dim = self.env.obs_dim + self.env.goal_dim*2  # [obs, ag, dg]

        cfg.agent.d2c_feature_dim = self.env.goal_dim
        cfg.agent.goal_dim = self.env.goal_dim

        cfg.agent.obs_shape = obs_spec.shape
        # exploration agent uses intrinsic reward
        self.expl_agent = hydra.utils.instantiate(cfg.agent)

        buffer_obs_spec = obs_spec.shape
        self.expl_buffer = ReplayBuffer(buffer_obs_spec, action_spec.shape,
                                        cfg.replay_buffer_capacity,
                                        self.device)

        n_sampled_goal = 4
        self.randomwalk_buffer = None
        if cfg.use_residual_randomwalk:
            self.randomwalk_buffer = ReplayBuffer(buffer_obs_spec, action_spec.shape,
                                                  cfg.randomwalk_buffer_capacity,
                                                  self.device)

            self.randomwalk_buffer = HindsightExperienceReplayWrapperVer2(self.randomwalk_buffer,
                                                                          n_sampled_goal=n_sampled_goal,
                                                                          wrapped_env=self.env,
                                                                          env_name=cfg.env,
                                                                          consider_done_true=cfg.done_on_success,
                                                                          )

        self.expl_buffer = HindsightExperienceReplayWrapperVer2(self.expl_buffer,
                                                                n_sampled_goal=n_sampled_goal,
                                                                wrapped_env=self.env,
                                                                env_name=cfg.env,
                                                                consider_done_true=cfg.done_on_success,
                                                                )

        if cfg.use_hgg:
            from hgg.hgg import TrajectoryPool, MatchSampler
            cfg.hgg_kwargs.match_sampler_kwargs.num_episodes = num_curriculum_dict[cfg.env]

            self.hgg_achieved_trajectory_pool = TrajectoryPool(**cfg.hgg_kwargs.trajectory_pool_kwargs)
            # NOTE(review): both goal_env and goal_eval_env receive
            # self.eval_env -- confirm this is intended.
            self.hgg_sampler = MatchSampler(goal_env=self.eval_env,
                                            goal_eval_env=self.eval_env,
                                            env_name=cfg.env,
                                            achieved_trajectory_pool=self.hgg_achieved_trajectory_pool,
                                            agent=self.expl_agent,
                                            **cfg.hgg_kwargs.match_sampler_kwargs
                                            )
            if 'vf' in self.hgg_sampler.cost_type:
                self.hgg_sampler.set_networks(critic=self.expl_agent.critic, policy=self.expl_agent.actor)

            from hgg.hgg import SimpleBipartiteMatching
            self.bipartite_matching = SimpleBipartiteMatching(env_name=cfg.env,
                                                              num_episodes=cfg.hgg_kwargs.match_sampler_kwargs.num_episodes,
                                                              goal_dim=self.env.goal_dim,
                                                              hgg_gcc_path=cfg.hgg_kwargs.match_sampler_kwargs.hgg_gcc_path,
                                                              )

        self.eval_video_recorder = VideoRecorder(self.work_dir if cfg.save_video else None, dmc_env=False, env_name=cfg.env)
        self.train_video_recorder = VideoRecorder(self.work_dir if cfg.save_video else None, dmc_env=False, env_name=cfg.env)
        self.train_video_recorder.init(enabled=False)
        self.step = 0

        self.uniform_goal_sampler = UniformFeasibleGoalSampler(env_name=cfg.env)


    def get_agent(self):
        """Return the exploration agent instantiated from ``cfg.agent`` in ``__init__``."""
        agent = self.expl_agent
        return agent
        

    def get_buffer(self):
        """Return the HER-wrapped exploration replay buffer built in ``__init__``."""
        buffer = self.expl_buffer
        return buffer
        

    def evaluate(self, eval_uniform_goal=False):
        """Run ``cfg.num_eval_episodes`` evaluation episodes and log metrics.

        The first pass evaluates on the eval env's own goals. When
        ``eval_uniform_goal`` is true, a second pass resets every episode to a
        goal from ``self.uniform_goal_sampler`` and logs under
        ``*_uniform_goal`` keys. Actions come from ``agent.act(..., sample=False)``.

        Logged per pass: mean episode reward, success rate (``is_successful``
        on the final observation) and mean final achieved-to-desired goal
        distance. Metrics are dumped once at the end with ``ty='eval'``.

        Args:
            eval_uniform_goal: also evaluate on uniformly sampled feasible goals.
        """
        agent = self.get_agent()
        num_passes = 2 if eval_uniform_goal else 1
        num_episodes = self.cfg.num_eval_episodes

        for pass_idx in range(num_passes):
            uniform_goal = pass_idx == 1

            total_reward = 0
            total_success = 0
            total_distance = 0
            for episode in range(num_episodes):
                if uniform_goal:
                    sampled_goal = self.uniform_goal_sampler.sample()
                    obs = self.eval_env.reset(goal=sampled_goal)
                else:
                    obs = self.eval_env.reset()

                # Only the first episode of each pass is recorded.
                self.eval_video_recorder.init(enabled=(episode == 0))
                episode_reward = 0
                episode_step = 0
                done = False
                while not done:
                    with utils.eval_mode(agent):
                        action = agent.act(obs, spec=self.eval_env.action_spec(), sample=False)
                    obs, reward, done, info = self.eval_env.step(action)
                    self.eval_video_recorder.record(self.eval_env)
                    episode_reward += reward
                    episode_step += 1

                    # With residual random-walk goals, also stop at the
                    # horizon or once the current goal is reached.
                    if self.cfg.use_residual_randomwalk:
                        if (episode_step % self.max_episode_timesteps == 0) or info.get('is_current_goal_success'):
                            done = True

                if self.eval_env.is_successful(obs):
                    total_success += 1.0
                total_reward += episode_reward

                obs_dict = self.eval_env.convert_obs_to_dict(obs)
                total_distance += np.linalg.norm(obs_dict['achieved_goal'] - obs_dict['desired_goal'], axis=-1)

                video_name = f'uniform_goal_{self.step}.mp4' if uniform_goal else f'{self.step}.mp4'
                self.eval_video_recorder.save(video_name)

            avg_episode_reward = total_reward / num_episodes
            avg_episode_success_rate = total_success / num_episodes
            avg_episode_distance_to_goal = total_distance / num_episodes

            if uniform_goal:
                # In the maze-task envs, reset(goal=...) discards the original
                # final goals, so they must be restored explicitly.
                if self.cfg.env in ['Point2WaySpiralMaze-v0', 'AntMazeComplex2Way-v0', 'Point4WayComplexVer2Maze-v0', 'Point4WayFarmlandMaze-v0']:
                    self.eval_env.reset_to_original_goals()
                else:
                    self.eval_env.reset(goal=get_original_final_goal(self.cfg.env))

            suffix = '_uniform_goal' if uniform_goal else ''
            self.logger.log(f'eval/episode_reward{suffix}', avg_episode_reward, self.step)
            self.logger.log(f'eval/episode_success_rate{suffix}', avg_episode_success_rate, self.step)
            self.logger.log(f'eval/episode_distance_to_goal{suffix}', avg_episode_distance_to_goal, self.step)

        self.logger.dump(self.step, ty='eval')
        
    

    def run(self):
        """Training entry point; delegates all work to ``_run`` and returns None."""
        self._run()
        return None
    
    def _run(self) -> None:
        """Run the training loop until ``self.step`` exceeds ``cfg.num_train_steps``.

        Each iteration of the ``while`` loop is one environment step:

        * When ``done`` is set (episode boundary): log the finished episode,
          optionally update the HGG sampler, reset ``self.env`` to a new goal
          (drawn from ``self.hgg_sampler`` when ``cfg.use_hgg``), log
          goal-distance metrics, and save curriculum-goal plots/pickles and
          D2C / value-function visualizations on selected episodes.
        * Every ``cfg.eval_frequency`` steps call ``self.evaluate``; save the
          agent and buffers at their configured frequencies.
        * Pick an action (uniform random for the first ``cfg.num_random_steps``
          steps or in ``'rand_action'`` random-walk mode, otherwise
          ``agent.act``), call ``agent.update`` once, then step the env.
        * Store the transition. At the end of an episode, commit it with
          ``store_episode``. With HGG, also add the achieved-goal trajectory
          to ``self.hgg_achieved_trajectory_pool`` if it moved far enough.
        * With ``cfg.use_residual_randomwalk``, swap in new residual goals
          after the original goal is reached, and force ``done`` at the
          horizon.
        """
        episode, episode_reward, episode_step = 0, 0, 0
        start_time = time.time()
        # Rolling window of the last 20 episode-success flags (bounded FIFO).
        # NOTE(review): filled below but never read in this method.
        recent_episode_success = Queue(20)

        if self.cfg.use_hgg:
            # Bounded FIFO of recently used HGG training goals (for logging/plots).
            recent_sampled_goals = Queue(num_curriculum_dict[self.cfg.env]) # self.cfg.hgg_kwargs.match_sampler_kwargs.num_episodes


        # NOTE(review): previous_goals is never used below.
        previous_goals = None
        # Start with done=True so the first iteration performs an episode reset.
        done = True
        info = {}

        if self.cfg.use_d2c:
            # Give the agent the (noisy) final goals as its target states.
            final_goal_images, final_goal_corresponding_states = get_final_goals_w_noise(self.cfg.env, num_curriculum_dict[self.cfg.env], env=self.eval_env)
            agent = self.get_agent()
            agent.final_goal_states = final_goal_images.copy()

        if self.cfg.use_hgg:
            temp_obs = self.eval_env.reset()

            # Seed the goal history with an initial achieved goal so the
            # logging/plots below never see an empty queue.
            recent_sampled_goals.put(self.eval_env.convert_obs_to_dict(temp_obs)['achieved_goal'].copy())

            # Build (initial, desired) goal pairs for the HGG sampler: one
            # initial achieved goal (from a fresh eval_env reset) per desired goal.
            initial_goals = []
            desired_goals, desired_goals_corresponding_states = get_final_goals_w_noise(self.cfg.env, num_curriculum_dict[self.cfg.env], env=self.eval_env) # [num_target_goal*num_epi, dim]

            # The corresponding states are not needed here; drop them.
            desired_goals_corresponding_states = None

            for i in range(len(desired_goals)):
                temp_obs = self.eval_env.convert_obs_to_dict(self.eval_env.reset())
                goal_a = temp_obs['achieved_goal'].copy()
                initial_goals.append(goal_a.copy())

            assert len(desired_goals)==len(initial_goals)
            # self.hgg_sampler.length = len(desired_goals)

        while self.step <= self.cfg.num_train_steps:

            # ---------------- episode boundary: bookkeeping + reset ----------------
            if done:
                if self.step > 0:
                    # queue.Queue.put blocks on a full bounded queue, so evict
                    # the oldest entry first.
                    if recent_episode_success.full():
                        recent_episode_success.get()

                    # `obs` is the last observation of the episode that just ended.
                    recent_episode_success.put(float(self.env.is_successful(obs)))


                    # hgg update
                    if self.cfg.use_hgg :
                        # Refresh the HGG goal pool every `hgg_sampler_update_frequency` episodes.
                        if episode % self.cfg.hgg_kwargs.hgg_sampler_update_frequency ==0 :
                            # NOTE(review): hgg_start_time is never used.
                            hgg_start_time = time.time()
                            hgg_sampler = self.hgg_sampler
                            hgg_sampler.update(initial_goals, desired_goals)


                # `episode` was already incremented at the end of the previous
                # reset, so the episode just finished is episode-1.
                self.train_video_recorder.save(f'train_episode_{episode-1}.mp4')
                if self.step > 0:
                    # Throughput of the finished episode (env steps per second).
                    fps = episode_step / (time.time() - start_time)
                    self.logger.log('train/fps', fps, self.step)
                    start_time = time.time()
                    self.logger.log('train/episode_reward', episode_reward, self.step)
                    self.logger.log('train/episode', episode, self.step)

                if self.cfg.use_hgg:
                    hgg_sampler = self.hgg_sampler
                    n_iter = 0
                    # Draw a random goal from the HGG pool and reset to it. If the
                    # reset observation already counts as a success, draw again.
                    # Give up after 10 draws and keep the last one.
                    while True:

                        sampled_goal = hgg_sampler.sample(np.random.randint(len(hgg_sampler.pool))).copy()
                        obs = self.env.reset(goal = sampled_goal)

                        if not self.env.is_successful(obs):
                            break
                        n_iter +=1
                        if n_iter==10:
                            break

                    if recent_sampled_goals.full():
                        recent_sampled_goals.get()
                    recent_sampled_goals.put(sampled_goal)

                    assert (sampled_goal == self.env.goal.copy()).all()


                else:
                    agent = self.get_agent()
                    obs = self.env.reset()

                final_goal = self.env.goal.copy()

                if self.cfg.use_hgg:

                    if not self.cfg.multi_target:
                        # Single target: distances of the current / recent training
                        # goals to the env's original final goal.
                        self.logger.log('train/episode_finalgoal_dist', np.linalg.norm(final_goal), self.step)
                        original_final_goal = get_original_final_goal(self.cfg.env)
                        self.logger.log('train/episode_dist_from_curr_g_to_example_g', np.linalg.norm(final_goal-original_final_goal), self.step)
                        sampled_goals_for_log = np.array(recent_sampled_goals.queue)
                        self.logger.log('train/average_dist_from_curr_g_to_example_g', np.linalg.norm(original_final_goal[None, :]-sampled_goals_for_log, axis =-1).mean(), self.step)

                    else: # just euclidean bipartite matching for computing distance to goals
                        sampled_goals_for_log = self.hgg_sampler.pool.copy() # [hgg.num_episodes(20), dim]

                        # Wrap each pool goal as a length-1 "trajectory" (leading axis of 1).
                        achieved_pool = []
                        for data in sampled_goals_for_log:
                            achieved_pool.append(data[None, :])

                        self.bipartite_matching.update(achieved_pool=achieved_pool, desired_goals=desired_goals)

                        # presumably total_cost holds the per-pair matching costs; TODO confirm
                        self.logger.log('train/average_dist_from_curr_g_to_example_g', self.bipartite_matching.total_cost.mean(0), self.step)



                # Recording is off by default; re-enabled below on video episodes.
                self.train_video_recorder.init(enabled=False)

                # How often (in episodes) to dump the curriculum-goal scatter plot:
                # every episode for Point envs up to episode 1000, then every 3rd;
                # every 25th episode for all other envs.
                if 'Point' in self.cfg.env:
                    if episode <= 1000: # for dense visualize
                        curr_goal_save_freq = 1
                    else:
                        curr_goal_save_freq = 3
                else:
                    curr_goal_save_freq = 25

                if self.cfg.use_hgg and episode % curr_goal_save_freq == 0 :
                    # Scatter the recent training goals (first two coordinates) with
                    # env-specific axis limits; save the plot and the raw goals.
                    sampled_goals_for_vis = np.array(recent_sampled_goals.queue)
                    fig = plt.figure()
                    sns.set_style("darkgrid")
                    ax1 = fig.add_subplot(1,1,1)
                    ax1.scatter(sampled_goals_for_vis[:, 0], sampled_goals_for_vis[:, 1])
                    if self.cfg.env in ['sawyer_peg_push','sawyer_peg_pick_and_place']:
                        plt.xlim(-0.6,0.6)
                        plt.ylim(0.2,1.0)
                    elif self.cfg.env in ['Point4WayComplexVer2Maze-v0','Point4WayFarmlandMaze-v0']:
                        plt.xlim(-18,18)
                        plt.ylim(-18,18)
                    elif self.cfg.env in ['Point2WaySpiralMaze-v0']:
                        plt.xlim(-14, 14)
                        plt.ylim(-18, 18)
                    elif self.cfg.env in ['AntMazeComplex2Way-v0']:
                        plt.xlim(-6, 6)
                        plt.ylim(-10,10)
                    else:
                        raise NotImplementedError
                    plt.savefig(self.train_video_recorder.save_dir+'/train_curr_goals_episode_'+str(episode)+'.jpg')
                    plt.close()
                    with open(self.train_video_recorder.save_dir+'/train_curr_goals_episode_'+str(episode)+'.pkl', 'wb') as f:
                        pkl.dump(sampled_goals_for_vis, f)


                # Video + visualization episodes: every `train_episode_video_freq`
                # episodes plus a few fixed early ones.
                if episode % self.cfg.train_episode_video_freq == 0 or episode in [5, 25,50,75,100]:
                    self.train_video_recorder.init(enabled=True)

                    # With 0 the loop below runs exactly once (k == 0).
                    visualize_num_iter = 0
                    scatter_states = self.env.convert_obs_to_dict(obs.copy())['achieved_goal'][None, :]

                    for k in range(visualize_num_iter+1):
                        if self.cfg.use_d2c:
                            # Two D2C plots: one without a uniform goal sampler, one
                            # with it ('rand_goal' file name).
                            visualize_d2c_all_together(agent=agent,
                                            scatter_states = scatter_states.squeeze(),
                                            env_name = self.cfg.env,
                                            savedir_w_name = self.train_video_recorder.save_dir + '/d2c_prob_visualize_train_episode_'+str(episode)+'_s'+str(k),
                                            device=self.device,
                                            multi_target=self.cfg.multi_target,
                                            env=self.eval_env,
                                            )
                            visualize_d2c_all_together(agent=agent,
                                            scatter_states = scatter_states.squeeze(),
                                            env_name = self.cfg.env,
                                            savedir_w_name = self.train_video_recorder.save_dir + '/d2c_prob_visualize_train_rand_goal_'+str(episode)+'_s'+str(k),
                                            device=self.device,
                                            uniform_goal_sampler=self.uniform_goal_sampler,
                                            multi_target=self.cfg.multi_target,
                                            env=self.eval_env,
                                            )

                        if 'vf' in self.cfg.hgg_cost_type:
                            # NOTE(review): on the very first episode with use_hgg=True
                            # and use_d2c=False, `agent` has not been bound yet
                            # (it is only set at the D2C init above, in the non-HGG
                            # reset branch, or after this block), so this would
                            # raise UnboundLocalError.

                            init_state_obs = self.env.convert_obs_to_dict(obs.copy())['observation']

                            # Initial state layout matches the one used for the HGG
                            # pool below: observation only for sawyer envs,
                            # [observation, achieved_goal] otherwise.
                            if self.cfg.env in ['sawyer_peg_push', 'sawyer_peg_pick_and_place']:
                                temp_initial_state = init_state_obs
                            else:
                                init_state_ag = self.env.convert_obs_to_dict(obs.copy())['achieved_goal']
                                temp_initial_state = np.concatenate([init_state_obs, init_state_ag], axis=-1)

                            visualize_vf(agent=agent,
                                        initial_state = temp_initial_state,
                                        scatter_states = scatter_states.squeeze(),
                                        env_name = self.cfg.env,
                                        savedir_w_name = self.train_video_recorder.save_dir + '/vf_visualize_train_episode_'+str(episode)+'_s'+str(k),
                                        device= self.device,
                                        )
                # Per-episode counters for the episode that starts now.
                episode_reward = 0
                episode_step = 0
                episode += 1
                episode_observes = [obs]

                self.logger.log('train/episode', episode, self.step)

            # ---------------- per-step work ----------------
            agent = self.get_agent()
            replay_buffer = self.get_buffer()
            # evaluate agent periodically
            if self.step % self.cfg.eval_frequency == 0:
                print('eval started...')
                self.logger.log('eval/episode', episode - 1, self.step)
                self.evaluate(eval_uniform_goal=False)


                # Still inside the eval branch: also plot a 128-sample batch of
                # random-walk transitions (goals vs. achieved states) once the
                # random-walk buffer holds more than 128 entries or is full.
                if self.cfg.use_residual_randomwalk and (self.randomwalk_buffer.idx > 128 or self.randomwalk_buffer.full):

                    temp_obs, _, _, _, _, _ = self.randomwalk_buffer.sample_without_relabeling(128, agent.discount, sample_only_state = False)
                    temp_obs = temp_obs.detach().cpu().numpy()
                    temp_obs_dict = self.env.convert_obs_to_dict(temp_obs)

                    temp_dg = temp_obs_dict['desired_goal']
                    temp_ag = temp_obs_dict['achieved_goal']

                    fig = plt.figure()
                    sns.set_style("darkgrid")

                    ax1 = fig.add_subplot(1,1,1)
                    ax1.scatter(temp_dg[:, 0], temp_dg[:, 1], label = 'goals')
                    ax1.scatter(temp_ag[:, 0], temp_ag[:, 1], label = 'achieved states', color = 'red')

                    if self.cfg.env in ['sawyer_peg_push','sawyer_peg_pick_and_place']:
                        x_min, x_max = -0.6, 0.6
                        y_min, y_max = 0.2, 1.0
                    elif self.cfg.env in ['Point4WayComplexVer2Maze-v0','Point4WayFarmlandMaze-v0']:
                        x_min, x_max = -18,18
                        y_min, y_max = -18,18
                    elif self.cfg.env in ['Point2WaySpiralMaze-v0']:
                        x_min, x_max = -14,14
                        y_min, y_max = -18,18
                    elif self.cfg.env in ['AntMazeComplex2Way-v0']:
                        x_min, x_max = -6,6
                        y_min, y_max = -10,10
                    else:
                        raise NotImplementedError
                    plt.xlim(x_min,x_max)
                    plt.ylim(y_min,y_max)

                    ax1.legend(loc ="best") # 'upper right' # , prop={'size': 20}
                    plt.savefig(self.eval_video_recorder.save_dir+'/randomwalk_goalandstates_'+str(self.step)+'.jpg')
                    plt.close()


            # save agent periodically
            if self.cfg.save_model and self.step % self.cfg.save_frequency == 0:
                utils.save(
                    self.expl_agent,
                    os.path.join(self.model_dir, f'expl_agent_{self.step}.pt'))
            if self.cfg.save_buffer and (self.step % self.cfg.buffer_save_frequency == 0) :
                utils.save(self.expl_buffer.replay_buffer, os.path.join(self.buffer_dir, f'buffer_{self.step}.pt'))
                if self.cfg.use_residual_randomwalk:
                    utils.save(self.randomwalk_buffer.replay_buffer, os.path.join(self.buffer_dir, f'randomwalk_buffer_{self.step}.pt'))

                if self.cfg.use_hgg:
                    utils.save(self.hgg_achieved_trajectory_pool,  os.path.join(self.buffer_dir, f'hgg_achieved_trajectory_pool_{self.step}.pt'))


            # sample action for data collection
            if self.step < self.cfg.num_random_steps:
                # Warm-up: uniform random actions within the action spec.
                spec = self.env.action_spec()
                action = np.random.uniform(spec.low, spec.high,
                                        spec.shape)
            elif self.cfg.use_residual_randomwalk and self.cfg.randomwalk_method == 'rand_action':
                # NOTE(review): `self.cfg.self.env` looks like a typo. Unless the
                # config defines a `self` key, this line raises rather than
                # checking anything. Presumably `self.env.is_residual_goal` was
                # meant, but that would fail whenever the env is still chasing its
                # original goal, because this branch is not gated on it. Confirm
                # the intended condition.
                assert self.cfg.self.env.is_residual_goal
                spec = self.env.action_spec()
                action = np.random.uniform(spec.low, spec.high,
                                        spec.shape)
            else:
                with utils.eval_mode(agent):
                    action = agent.act(obs, spec = self.env.action_spec(), sample=True)

            # NOTE(review): `start` is never used.
            start = time.time()
            # One gradient update per env step, done before stepping the env.
            logging_dict = agent.update(replay_buffer, self.randomwalk_buffer, self.step, self.env, uniform_goal_sampler = self.uniform_goal_sampler)

            if self.step % self.cfg.logging_frequency== 0:
                if logging_dict is not None: # when step = 0
                    for key, val in logging_dict.items():
                        self.logger.log('train/'+key, val, self.step)


            next_obs, reward, done, info = self.env.step(action)

            episode_reward += reward
            episode_observes.append(next_obs)

            # True on the last step of the episode: horizon reached or env done.
            last_timestep = True if (episode_step+1) % self.max_episode_timesteps == 0 or done else False


            self.train_video_recorder.record(self.env)



            # Store the transition. With residual random walk, the stored "done"
            # flag is the env's current-goal success. Transitions collected while
            # pursuing a residual goal go to the random-walk buffer.
            if self.cfg.use_residual_randomwalk:
                if self.env.is_residual_goal:
                    self.randomwalk_buffer.add(obs, action, reward, next_obs, info.get('is_current_goal_success'), last_timestep)
                else:
                    replay_buffer.add(obs, action, reward, next_obs, info.get('is_current_goal_success'), last_timestep)

            else:
                replay_buffer.add(obs, action, reward, next_obs, done, last_timestep)


            if last_timestep:
                # replay_buffer.add_trajectory(episode_observes)
                replay_buffer.store_episode()

                if self.randomwalk_buffer is not None:
                    self.randomwalk_buffer.store_episode()

                # Sanity check: until a buffer wraps, the transitions stored
                # across buffers must equal the number of env steps taken.
                if self.randomwalk_buffer is not None:
                    if (not replay_buffer.full) and (not self.randomwalk_buffer.full):
                        assert self.step+1 == self.randomwalk_buffer.idx + replay_buffer.idx
                else:
                    if not replay_buffer.full:
                        assert self.step+1 == replay_buffer.idx

                if self.cfg.use_hgg:
                    temp_episode_observes = copy.deepcopy(episode_observes)
                    temp_episode_ag = []
                    # NOTE : should it be [obs, ag] ?
                    # Initial state stored alongside the trajectory ("for bias
                    # computing"): observation only for sawyer envs,
                    # [observation, achieved_goal] otherwise. Only 'd2c' cost
                    # types are supported.
                    if 'd2c' in self.hgg_sampler.cost_type:
                        temp_episode_init_obs = self.eval_env.convert_obs_to_dict(temp_episode_observes[0])['observation']
                        temp_episode_init_ag = self.eval_env.convert_obs_to_dict(temp_episode_observes[0])['achieved_goal']
                        if self.cfg.env in ['sawyer_peg_push', 'sawyer_peg_pick_and_place']:
                            temp_episode_init = temp_episode_init_obs # for bias computing
                        else:
                            temp_episode_init = np.concatenate([temp_episode_init_obs, temp_episode_init_ag], axis=-1) # for bias computing
                    else:
                        raise NotImplementedError


                    # Achieved-goal trajectory of the whole episode.
                    for k in range(len(temp_episode_observes)):
                        temp_episode_ag.append(self.eval_env.convert_obs_to_dict(temp_episode_observes[k])['achieved_goal'])

                    if getattr(self.env, 'full_state_goal', False):
                        raise NotImplementedError("You should modify the code when full_state_goal (should address achieved_goal to compute goal distance below)")


                    achieved_trajectories = [np.array(temp_episode_ag)] # list of [ts, dim]
                    achieved_init_states = [temp_episode_init] # list of [ts(1), dim]

                    # Keep only trajectories whose first and last achieved goals
                    # are more than an env-specific threshold apart.
                    selection_trajectory_idx = {}
                    for i in range(len(achieved_trajectories)):
                        # full state achieved_goal
                        if self.cfg.env in ['AntMazeComplex2Way-v0', 'Point2WaySpiralMaze-v0',  'Point4WayComplexVer2Maze-v0','Point4WayFarmlandMaze-v0']:
                            threshold = 0.2
                        elif self.cfg.env in ['sawyer_peg_push','sawyer_peg_pick_and_place']:
                            threshold = 0.02
                        else:
                            raise NotImplementedError
                        if goal_distance(achieved_trajectories[i][0], achieved_trajectories[i][-1])>threshold: # if there is a difference btw first and last timestep ?
                            selection_trajectory_idx[i] = True

                    hgg_achieved_trajectory_pool = self.hgg_achieved_trajectory_pool
                    for idx in selection_trajectory_idx.keys():
                        hgg_achieved_trajectory_pool.insert(achieved_trajectories[idx].copy(), achieved_init_states[idx].copy())



            obs = next_obs
            episode_step += 1
            self.step += 1

            # ---------------- residual random-walk goal switching ----------------
            if self.cfg.use_residual_randomwalk:
                # NOTE(review): min_step is only bound when use_d2c is set. If
                # use_uncertainty_for_randomwalk is enabled without D2C, the
                # `self.step > min_step` checks below raise UnboundLocalError.
                if self.cfg.use_d2c:
                    min_step = self.get_agent().d2c_batch_size

                if self.env.is_residual_goal:
                    # Already chasing a residual goal: pick a new one every 10
                    # residual steps or as soon as the current one is reached.
                    if (self.env.residual_goalstep % 10 == 0) or info.get('is_current_goal_success'):
                        if (self.cfg.use_uncertainty_for_randomwalk not in [None, 'none', 'None']) and self.step > min_step:
                            if self.get_agent().use_d2c:
                                if self.get_agent().d2c_gcrl:
                                    dg = self.env.convert_obs_to_dict(obs)['desired_goal']
                                else:
                                    raise NotImplementedError
                            else:
                                dg = None
                            # Agent-proposed residual goal (uncertainty-based selection).
                            residual_goal = self.get_agent().sample_randomwalk_goals(obs = obs, ag = self.env.convert_obs_to_dict(obs)['achieved_goal'], \
                                episode = episode, env=self.env, replay_buffer = self.expl_buffer, \
                                num_candidate = self.cfg.randomwalk_num_candidate, random_noise = self.cfg.randomwalk_random_noise, \
                                uncertainty_mode = self.cfg.use_uncertainty_for_randomwalk, dg = dg)
                        else:
                            # Fallback: current achieved goal plus uniform noise in
                            # [-randomwalk_random_noise, randomwalk_random_noise].
                            noise = np.random.uniform(low=-self.cfg.randomwalk_random_noise, high=self.cfg.randomwalk_random_noise, size=self.env.goal_dim)

                            if self.cfg.env in ['sawyer_peg_pick_and_place']:
                                assert self.cfg.randomwalk_random_noise <= 0.2
                                pass
                            elif self.cfg.env in ['sawyer_peg_push']:
                                assert self.cfg.randomwalk_random_noise <= 0.2
                                # No noise on the third goal coordinate for the push task.
                                noise[2] = 0
                            residual_goal = self.env.convert_obs_to_dict(obs)['achieved_goal'] + noise

                        self.env.reset_goal(residual_goal)
                        # Overwrite the goal slice of obs in place so the next
                        # stored transition uses the new goal. NOTE(review): this
                        # is the same array object appended to episode_observes.
                        obs[-self.env.goal_dim:] = residual_goal.copy()
                else:
                    if info.get('is_current_goal_success'): # original goal reached: start a residual random walk
                        self.env.original_goal_success = True
                        if (self.cfg.use_uncertainty_for_randomwalk not in [None, 'none', 'None']) and self.step > min_step:
                            if self.get_agent().use_d2c:
                                if self.get_agent().d2c_gcrl:
                                    dg = self.env.convert_obs_to_dict(obs)['desired_goal']
                                else:
                                    raise NotImplementedError
                            else:
                                dg = None
                            residual_goal = self.get_agent().sample_randomwalk_goals(obs = obs, ag = self.env.convert_obs_to_dict(obs)['achieved_goal'], \
                                episode = episode, env=self.env, replay_buffer = self.expl_buffer, \
                                num_candidate = self.cfg.randomwalk_num_candidate, random_noise = self.cfg.randomwalk_random_noise, \
                                uncertainty_mode = self.cfg.use_uncertainty_for_randomwalk, dg = dg)
                        else:
                            noise = np.random.uniform(low=-self.cfg.randomwalk_random_noise, high=self.cfg.randomwalk_random_noise, size=self.env.goal_dim)

                            if self.cfg.env in ['sawyer_peg_pick_and_place']:
                                assert self.cfg.randomwalk_random_noise <= 0.2
                                pass
                            elif self.cfg.env in ['sawyer_peg_push']:
                                assert self.cfg.randomwalk_random_noise <= 0.2
                                noise[2] = 0
                            residual_goal = self.env.convert_obs_to_dict(obs)['achieved_goal'] + noise
                        # presumably reset_goal also flips env.is_residual_goal; TODO confirm
                        self.env.reset_goal(residual_goal)
                        obs[-self.env.goal_dim:] = residual_goal.copy()
                # Force the episode to end at the horizon; report success as
                # "original goal reached at some point in this episode".
                if (episode_step) % self.max_episode_timesteps == 0:
                    done = True
                    info['is_success'] = self.env.original_goal_success


                    

@hydra.main(config_path='./config', config_name='config_D2C.yaml')
def main(cfg):
    """Hydra entry point: build a D2C training ``Workspace`` from ``cfg`` and run it.

    ``cfg`` is the config composed by Hydra from ``./config/config_D2C.yaml``.
    """
    # Ask Hydra to report complete tracebacks instead of its shortened summary.
    os.environ['HYDRA_FULL_ERROR'] = '1'
    # Imported lazily, after the env var is set (same order as before).
    from d2c_train import Workspace

    Workspace(cfg).run()


if __name__ == '__main__':
    main()

