import numpy as np
import torch
import gym
import argparse
import os
import sys

from utils import ReplayBuffer
import TD3_BC
from Because.agent.TD3_BC.eval import evaluate_on_env
from robosuite import make
from robosuite import load_controller_config
from controller.stack_policy import StackPolicy
from controller.lift_policy import LiftPolicy, LiftCausalPolicy
from collector.gym_wrapper import GymStackWrapper, GymLiftWrapper, GymLiftCausalWrapper
from stable_baselines3.common.vec_env import SubprocVecEnv

# Task name (as passed to robosuite's ``make``) -> Gym wrapper class that
# ``make_envs`` applies on top of the raw robosuite environment.
WRAPPER = dict(
    LiftCausal=GymLiftWrapper,
    StackCausal=GymStackWrapper,
    CausalPick=GymLiftCausalWrapper,
)

def make_envs(task='LiftCausal', horizon=30, control_freq=5, spurious_type='xnr', seed=100):
    """Create one wrapped, seeded robosuite environment for ``task``.

    Builds a Kinova3 robosuite environment with an ``OSC_POSITION`` controller,
    object observations enabled, and no rendering or camera observations. It then
    wraps the environment with the Gym wrapper registered for ``task`` in
    ``WRAPPER`` and seeds it.

    Args:
        task: robosuite environment name; must be a key of ``WRAPPER``.
        horizon: Episode horizon forwarded to robosuite's ``make``.
        control_freq: Control frequency forwarded to robosuite's ``make``.
        spurious_type: Forwarded unchanged to robosuite's ``make``; its meaning
            is defined by the (custom) environment class.
        seed: Seed passed to the wrapper's ``seed()``. Numpy integer seeds are
            converted to a plain ``int``; ``None`` is passed through.

    Returns:
        The wrapped, seeded environment.

    Raises:
        KeyError: If ``task`` has no registered wrapper. This is checked before
            the simulator is built, so an unsupported task fails fast.
    """
    # Resolve the wrapper first: building the robosuite sim is expensive and
    # would be wasted if the task cannot be wrapped anyway.
    try:
        wrapper_cls = WRAPPER[task]
    except KeyError:
        raise KeyError(
            f"No Gym wrapper registered for task {task!r}; expected one of {sorted(WRAPPER)}"
        ) from None

    env = make(
        task,
        'Kinova3',
        horizon=horizon,
        control_freq=control_freq,
        has_renderer=False,
        has_offscreen_renderer=False,
        ignore_done=False,
        use_camera_obs=False,
        use_object_obs=True,
        controller_configs=load_controller_config(default_controller='OSC_POSITION'),
        spurious_type=spurious_type,
    )
    env = wrapper_cls(env)
    # Callers typically draw seeds with numpy (np.int64, not a Python int).
    # Normalise in case the wrapper's seed() insists on a built-in int.
    env.seed(seed if seed is None else int(seed))
    return env
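# Legacy D4RL evaluation helper, kept commented out for reference: it ran the
# policy for `eval_episodes` episodes and returned the D4RL-normalized score.
# A fixed seed (seed + seed_offset) was used for the eval environment.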
# def eval_policy(policy, env_name, seed, mean, std, seed_offset=100, eval_episodes=10):
# 	eval_env = gym.make(env_name)
# 	eval_env.seed(seed + seed_offset)

# 	avg_reward = 0.
# 	for _ in range(eval_episodes):
# 		state, done = eval_env.reset(), False
# 		while not done:
# 			state = (np.array(state).reshape(1,-1) - mean)/std
# 			action = policy.select_action(state)
# 			state, reward, done, _ = eval_env.step(action)
# 			avg_reward += reward

# 	avg_reward /= eval_episodes
# 	d4rl_score = eval_env.get_normalized_score(avg_reward) * 100

# 	print("---------------------------------------")
# 	print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}, D4RL score: {d4rl_score:.3f}")
# 	print("---------------------------------------")
# 	return d4rl_score




def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` has no built-in bool conversion. Without a ``type``,
    ``--normalize False`` produces the non-empty, and therefore truthy, string
    ``"False"``. This converter makes the flag mean what it says.

    Args:
        value: Raw flag text, or an actual ``bool`` (returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the TD3+BC CLI parser and parse ``argv`` (``None`` -> ``sys.argv[1:]``).

    Every numeric option declares its ``type``, so command-line values arrive as
    numbers rather than strings. Numeric defaults are written with their
    declared type, because argparse does not convert non-string defaults.

    Raises:
        SystemExit: Via ``parser.error`` on invalid values, e.g. a non-positive
            ``--eval_freq`` (which would otherwise divide by zero) or ``--num_envs``.
    """
    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--policy", default="TD3_BC")               # Policy name
    parser.add_argument("--env", default="lift")                    # Label used only in file names / logging
    parser.add_argument("--seed", default=0, type=int)              # Seeds torch, numpy and the per-env seed draw
    parser.add_argument("--eval_freq", default=5000, type=int)      # Evaluate every N training steps
    parser.add_argument("--max_timesteps", default=1_000_000, type=int)  # Number of training steps
    parser.add_argument("--save_model", action="store_true")        # Kept for CLI compatibility; the best model is always saved
    parser.add_argument("--load_model", default="")                 # Model load file name, "" doesn't load, "default" uses file_name
    # TD3
    parser.add_argument("--expl_noise", default=0.1, type=float)    # Std of Gaussian exploration noise (not used by this script)
    parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)     # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)         # Target network update rate
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)    # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)       # Frequency of delayed policy updates
    # TD3 + BC
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--normalize", type=str2bool, default=True)
    parser.add_argument("--num_eval_ep", type=int, default=20, help='number of evaluation episodes')
    parser.add_argument("--num_envs", type=int, default=16, help='number of parallel evaluation environments')
    parser.add_argument("--type", type=str, default='expert', help='expert/medium/random')
    # NOTE(review): --task is parsed but not read anywhere below; evaluation
    # envs always use make_envs()'s default task. Confirm intent.
    parser.add_argument("--task", type=str, default='lift', help='lift/pick')
    args = parser.parse_args(argv)
    if args.eval_freq <= 0:
        parser.error("--eval_freq must be a positive integer")
    if args.num_envs <= 0:
        parser.error("--num_envs must be a positive integer")
    return args


def _run_training(args, env, file_name):
    """Load the offline dataset, train TD3+BC, and periodically evaluate and checkpoint."""
    # Dimensions are hard-coded to match the recorded dataset / policy networks.
    state_dim = 33
    action_dim = 4
    replay_buffer = ReplayBuffer(state_dim, action_dim)
    max_action = float(replay_buffer.convert_D4RL(file_path="dataset/height01", type=args.type))

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        "discount": args.discount,
        "tau": args.tau,
        # TD3
        "policy_noise": args.policy_noise * max_action,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        # TD3 + BC
        "alpha": args.alpha,
    }
    policy = TD3_BC.TD3_BC(**kwargs)

    if args.load_model != "":
        policy_file = file_name if args.load_model == "default" else args.load_model
        policy.load(f"./models/{policy_file}")

    if args.normalize:
        mean, std = replay_buffer.normalize_states()
    else:
        mean, std = 0, 1

    # Start below any attainable score so the first evaluation always yields a
    # checkpoint; value estimates can be negative, so -1 was not a safe floor.
    best_succ_rate = float("-inf")
    best_value = float("-inf")
    # One row per evaluation: [avg_reward, success_rate, avg_value, avg_ep_len].
    evaluations = []

    for t in range(args.max_timesteps):
        policy.train(replay_buffer, args.batch_size)
        if (t + 1) % args.eval_freq != 0:
            continue

        print(f"Time steps: {t+1}")
        results = evaluate_on_env(model=policy, env=env, mean=mean, std=std,
                                  num_eval_ep=args.num_eval_ep, num_envs=args.num_envs)
        eval_avg_reward = results['eval/avg_reward']
        eval_avg_ep_len = results['eval/avg_ep_len']
        eval_avg_succ = results['eval/success_rate']
        eval_avg_value = results['eval/avg_value']
        log_str = (
            f"{'=' * 60}\n"
            f"eval avg reward: {eval_avg_reward:.5f}\n"
            f"eval avg ep len: {eval_avg_ep_len:.5f}\n"
            f"eval avg succ: {eval_avg_succ:.5f}\n"
            f"eval avg value: {eval_avg_value:.5f}\n"
        )
        print(log_str)
        print(best_succ_rate, best_value)

        # Previously nothing was ever appended, so an empty array was saved.
        evaluations.append([float(eval_avg_reward), float(eval_avg_succ),
                            float(eval_avg_value), float(eval_avg_ep_len)])
        np.save(f"./results/{file_name}", np.asarray(evaluations))

        if eval_avg_succ >= best_succ_rate and eval_avg_value >= best_value:
            print('best model find!')
            best_succ_rate = eval_avg_succ
            best_value = eval_avg_value
            policy.save(f"./models/{file_name}")
            print('model saved!')


def main():
    """Train TD3+BC offline and checkpoint the best policy.

    Evaluates every ``--eval_freq`` steps on ``--num_envs`` parallel robosuite
    environments. Metrics go to ``./results/<file_name>.npy``. The policy is
    saved to ``./models/<file_name>`` whenever both success rate and average
    value are at least as good as the best seen so far.
    """
    args = parse_args()

    file_name = f"{args.policy}_{args.env}_{args.type}_{args.seed}"
    print("---------------------------------------")
    print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    os.makedirs("./results", exist_ok=True)
    # The best checkpoint is saved unconditionally, so ./models must always
    # exist. Previously it was created only with --save_model, and saving crashed.
    os.makedirs("./models", exist_ok=True)

    # Seed global RNGs *before* drawing per-env seeds so runs are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Distinct per-worker seeds. Generator.choice samples without replacement
    # without materialising the whole 10M-element population.
    env_seeds = np.random.default_rng(args.seed).choice(
        10_000_000, size=args.num_envs, replace=False
    )
    # One worker per requested env so evaluate_on_env's num_envs matches reality.
    env = SubprocVecEnv(
        [lambda seed=int(seed): make_envs(horizon=30, seed=seed) for seed in env_seeds],
        start_method="spawn",
    )
    try:
        _run_training(args, env, file_name)
    finally:
        # Always reap the worker processes, even if training raises.
        env.close()


if __name__ == "__main__":
    main()
