import numpy as np
import os
import torch
from tqdm import trange, tqdm
from utils.utils import CUDA
from Because.agent.icil.icil_state import ICIL as ICIL_state
from stable_baselines3.common.vec_env import SubprocVecEnv
import sys


from robosuite import make
from robosuite import load_controller_config
from collector.gym_wrapper import GymStackWrapper, GymLiftWrapper, GymLiftCausalWrapper
# Task name (the first argument given to robosuite ``make`` below) -> gym
# wrapper class that is applied on top of the raw robosuite environment.
# NOTE(review): "LiftCausal" uses GymLiftWrapper rather than
# GymLiftCausalWrapper — confirm this pairing is intended.
WRAPPER = dict(
    LiftCausal=GymLiftWrapper,
    StackCausal=GymStackWrapper,
    CausalPick=GymLiftCausalWrapper,
)

def evaluate_on_env(model, device, env, num_eval_ep=20, num_envs=16,
                    discount=0.99, action_dim=4):
    """Roll out ``model`` on a vectorized env until ``num_eval_ep`` episodes end.

    Args:
        model: Policy exposing ``causal_feature_encoder`` and
            ``policy_network``; the first ``action_dim`` output columns are
            used as the action.
        device: Unused. Tensors are placed with the project's ``CUDA`` helper
            instead; the parameter is kept for interface compatibility.
        env: Vectorized env with ``reset()`` and ``step(actions)`` returning
            ``(obs, reward, done, info)``; ``info[i]['success']`` is read for
            every finished sub-environment ``i``.
        num_eval_ep: Number of finished episodes to aggregate. Must be > 0.
        num_envs: Number of parallel sub-environments in ``env``.
        discount: Per-step discount used for ``'eval/avg_value'``.
        action_dim: Number of leading policy outputs sent to ``env.step``.

    Returns:
        dict with keys ``'eval/avg_reward'``, ``'eval/avg_ep_len'``,
        ``'eval/success_rate'`` and ``'eval/avg_value'``. Reward is counted
        as 1 per successful episode, so ``avg_reward`` equals the success
        rate. ``avg_value`` averages ``discount ** episode_length`` over
        successful episodes and 0 over failed ones.

        NOTE(review): episodes are counted in the order they finish across
        parallel envs, and collection stops at ``num_eval_ep``. Long episodes
        still in progress at that point are dropped, which may bias the
        statistics toward short episodes.

    Raises:
        ValueError: if ``num_eval_ep`` is not positive (the averages would
            otherwise divide by zero).
    """
    if num_eval_ep <= 0:
        raise ValueError(f"num_eval_ep must be positive, got {num_eval_ep}")

    # Steps taken so far in each sub-environment's current episode.
    envs_timestep = np.zeros(num_envs, dtype=np.int64)
    total_timesteps = 0
    total_succ = 0
    total_value = 0.0
    count_done = 0

    model.eval()
    # ``with`` guarantees the progress bar is closed even if a rollout raises.
    with torch.no_grad(), tqdm(total=num_eval_ep) as pbar:
        running_state = env.reset()
        while count_done < num_eval_ep:
            envs_timestep += 1

            obs = CUDA(torch.from_numpy(running_state).float())
            act_logits = model.policy_network(model.causal_feature_encoder(obs))
            act = act_logits[:, :action_dim].cpu().numpy()

            running_state, _, done, info = env.step(act)
            for i, finished in enumerate(done):
                if not finished:
                    continue
                ep_len = envs_timestep[i]
                if info[i]['success']:
                    total_succ += 1
                    # Discounted value of reaching success after ep_len steps.
                    total_value += discount ** ep_len
                count_done += 1
                total_timesteps += ep_len
                # Restart this slot's step counter for its next episode
                # (assumes the vec env auto-resets finished sub-envs — TODO confirm).
                envs_timestep[i] = 0
                pbar.update(1)
                if count_done >= num_eval_ep:
                    break

    results = {
        'eval/avg_reward': total_succ / count_done,
        'eval/avg_ep_len': total_timesteps / count_done,
        'eval/success_rate': total_succ / count_done,
        'eval/avg_value': total_value / count_done,
    }
    print(results['eval/success_rate'])
    print(results['eval/avg_value'])

    return results

def make_envs(seed, task='LiftCausal', horizon=70, control_freq=5, spurious_type='xpr'):
    """Build one wrapped robosuite Kinova3 environment for ``task``.

    Args:
        seed: Seed passed to the wrapped env's ``seed`` method.
        task: robosuite environment name; must be a key of ``WRAPPER``.
        horizon: Episode horizon passed to ``robosuite.make``.
        control_freq: Control frequency passed to ``robosuite.make``.
        spurious_type: Forwarded to ``robosuite.make`` as the
            ``spurious_type`` keyword (presumably consumed by the custom task
            class — TODO confirm).

    Returns:
        The ``WRAPPER[task]`` instance wrapping the robosuite environment.

    Raises:
        KeyError: if ``task`` has no registered wrapper. This is checked
            before the simulator is built, so a bad task name does not leave
            an orphaned environment behind.
    """
    if task not in WRAPPER:
        raise KeyError(
            f"No gym wrapper registered for task {task!r}; "
            f"expected one of {sorted(WRAPPER)}"
        )
    env = make(
        task,
        'Kinova3',
        horizon=horizon,
        control_freq=control_freq,
        has_renderer=False,
        has_offscreen_renderer=False,
        ignore_done=False,
        use_camera_obs=False,
        use_object_obs=True,
        controller_configs=load_controller_config(default_controller='OSC_POSITION'),
        spurious_type=spurious_type,
    )
    env = WRAPPER[task](env)
    env.seed(seed)
    return env

def make_pickenvs(task='CausalPick', horizon=70, control_freq=5, seed=100):
    """Build one wrapped robosuite Kinova3 pick environment in ``"train"`` mode.

    The env is created with 1 unmovable object, 0 random objects and
    3 markers, and is wrapped with ``WRAPPER[task](env, mode="train")``.

    Args:
        task: robosuite environment name; must be a key of ``WRAPPER``.
        horizon: Episode horizon passed to ``robosuite.make``.
        control_freq: Control frequency passed to ``robosuite.make``.
        seed: Seed passed to the wrapped env's ``seed`` method.

    Returns:
        The ``WRAPPER[task]`` instance wrapping the robosuite environment.

    Raises:
        KeyError: if ``task`` has no registered wrapper. This is checked
            before the simulator is built, so a bad task name does not leave
            an orphaned environment behind.
    """
    if task not in WRAPPER:
        raise KeyError(
            f"No gym wrapper registered for task {task!r}; "
            f"expected one of {sorted(WRAPPER)}"
        )
    env = make(
        task,
        'Kinova3',
        horizon=horizon,
        control_freq=control_freq,
        has_renderer=False,
        has_offscreen_renderer=False,
        ignore_done=False,
        use_camera_obs=False,
        use_object_obs=True,
        controller_configs=load_controller_config(default_controller='OSC_POSITION'),
        num_unmovable_objects=1,
        num_random_objects=0,
        num_markers=3,
    )
    env = WRAPPER[task](env, mode="train")
    env.seed(seed)
    return env

if __name__ == '__main__':
    success_rate=np.zeros(5)
    torch.manual_seed(0)
    np.random.seed(0)
    for i in range(5):
        print(f"itr_{i}")
        model=CUDA(ICIL_state(state_dim=33, action_dim=4, hidden_dim_input=256, hidden_dim=256))
        checkpoint_path=os.path.join("checkpoint",f"icil_expert_xnr_40_60_{i+1}",'hest.pt')
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint)
        seed_list = np.random.choice(list(range(0, 10000000)), 16, replace=False)
        env = SubprocVecEnv([lambda i=i: make_envs(seed=seed_list[i],horizon=30,spurious_type='xpr') for i in range(16)], start_method="spawn")
        results=evaluate_on_env(model, torch.device('cuda:0'), env, num_eval_ep=100,num_envs=16)
        success_rate[i]=results['eval/success_rate']
    mean=np.mean(success_rate)
    std=np.std(success_rate)
    print(mean,std)
    data_without_extremes = np.delete(success_rate, [np.argmax(success_rate), np.argmin(success_rate)])
    mean_we=np.mean(data_without_extremes)
    std_we=np.std(data_without_extremes)
    print(mean_we,std_we)

