from rsa.algos.awac import AWAC
import rsa.utils as utils
import rsa.utils.pytorch_utils as ptu
from rsa.utils.arg_parser import parse_args
from rsa.utils.logx import EpochLogger
import rsa.utils.spb_utils as spbu

import numpy as np
import gym
from gym.wrappers import Monitor
import os
import json
from tqdm import trange

def transition_mask(info, done, t, horizon):
    """Return the bootstrap mask for the transition taken at step index ``t``.

    Args:
        info: The info mapping returned by ``env.step``. If it contains a
            ``'mask'`` entry, that value is returned unchanged.
        done: The ``done`` flag returned by ``env.step``.
        t: Zero-based index of the step within the episode, i.e. the value
            of the step counter *before* it is incremented for this step.
        horizon: Maximum number of steps collected per episode.

    Returns:
        ``info['mask']`` when present; otherwise ``1`` if this is the last
        step allowed by ``horizon``, else ``float(not done)``.
    """
    if 'mask' in info:
        return info['mask']
    # `t` has not been incremented yet for this step, so the final step that
    # the collection loop allows is the one where t + 1 == horizon. Comparing
    # against `t == horizon` could never be true inside that loop.
    if t + 1 == horizon:
        return 1
    return float(not done)


def label_episode(ep_buf, discount):
    """Annotate an episode's transitions with return-to-go and success.

    Walks the episode backwards, writing ``'drtg'`` and ``'succ'`` into each
    transition dict and removing its ``'goal'`` key. The dicts are modified
    in place.

    Args:
        ep_buf: List of transition dicts in chronological order. Each must
            contain ``'rew'``, ``'mask'`` and ``'goal'``.
        discount: Discount factor applied to the return-to-go.

    Returns:
        The same transition dicts, ordered from the last step to the first.
    """
    labeled = []
    drtg, succ = 0, 0
    for j, transition in enumerate(reversed(ep_buf)):
        # TODO We need to come up with a good way to estimate this for general environments.
        #   For the goal conditioned method it's easy to say the rest of the rewards will
        #   always be -1 or 0. However, for general environments this is not the case.
        #   Possible options I've considered are assuming it will always be minimum reward,
        #   mean reward or median reward, or last reward.
        #   -
        #   For now I'm implementing last reward
        if j == 0:
            # Only the final transition's goal flag determines success, and
            # that value is then written to every transition of the episode.
            succ = succ or transition['goal']
            if not transition['mask']:
                # Masked-out final step: nothing follows it.
                drtg = transition['rew']
            else:
                # Treat the final reward as repeating forever and use the
                # resulting infinite discounted sum.
                reward_estimate = transition['rew']
                if discount < 1:
                    drtg = reward_estimate / (1 - discount)
                elif reward_estimate == 0:
                    # 0 * inf is NaN; an endless run of zero rewards sums to 0.
                    drtg = 0.0
                else:
                    drtg = reward_estimate * float('inf')
        else:
            drtg = transition['rew'] + transition['mask'] * discount * drtg
        transition['drtg'] = drtg
        transition['succ'] = succ
        del transition['goal']
        labeled.append(transition)
    return labeled


def evaluate_policy(agent, test_env, num_episodes, close_after_episode=False):
    """Roll out ``agent`` on ``test_env`` and return per-episode statistics.

    All rollout state lives in this function's local scope so that it cannot
    overwrite the observation of a training episode that is in progress.

    Args:
        agent: Object whose ``select_action(obs, evaluate=True)`` returns the
            action to take.
        test_env: Environment with ``reset()`` and a ``step(act)`` that
            returns ``(next_obs, rew, done, info)``.
        num_episodes: Number of evaluation episodes to run.
        close_after_episode: If true, call ``test_env.close()`` after each
            episode.

    Returns:
        A list of ``(episode_return, episode_length)`` tuples.
    """
    results = []
    for _ in range(num_episodes):
        test_obs, test_done, ep_ret, ep_len = test_env.reset(), False, 0, 0
        while not test_done:
            act = agent.select_action(test_obs, evaluate=True)
            test_obs, rew, test_done, _ = test_env.step(act)
            ep_ret += rew
            ep_len += 1
        if close_after_episode:
            test_env.close()
        results.append((ep_ret, ep_len))
    return results


if __name__ == "__main__":
    params = parse_args(awac_args=True)

    # Seed, create the output directories and record the hyperparameters.
    # os.makedirs raises if `logdir` already exists, so an existing run
    # directory is never reused.
    utils.seed(params['seed'])
    logdir = params['logdir']
    os.makedirs(logdir)
    os.makedirs(os.path.join(logdir, 'misc'))
    ptu.setup(params['device'])
    with open(os.path.join(logdir, 'hparams.json'), 'w') as f:
        json.dump(params, f)

    env, test_env = utils.make_env(params)
    is_pointbot_env = params['env'] in ('spb', 'rpb', 'lpb', 'hpb', 'mpb', 'lpb_easy')

    logger = EpochLogger(output_dir=logdir, exp_name=params['exper_name'])
    loss_plotter = utils.LossPlotter(os.path.join(logdir, 'loss_plots'))

    awac = AWAC(params)

    if params['env'] in utils.d4rl_envs:
        replay_buffer = utils.load_d4rl_replay_buffer(env, params, add_drtg=True)
    else:
        replay_buffer = utils.load_replay_buffer(params, add_drtg=True)

    # Either restore a checkpoint or pretrain on the loaded replay buffer.
    if params['checkpoint'] is not None:
        awac.load(params['checkpoint'])
    else:
        print('Pretraining Policy')
        os.makedirs(os.path.join(logdir, 'pretrain_plots'))
        for _ in trange(params['init_iters']):
            pretrain_info = awac.update(replay_buffer)
            loss_plotter.add_data(**pretrain_info)

        if params['init_iters'] > 0:
            awac.save(os.path.join(logdir, 'pretrain'))
            loss_plotter.plot()

    # Run training loop
    # Prepare for interaction with environment
    i = 0  # total environment steps taken so far
    n_episodes = 0
    epoch = 0
    metrics = {
        'Timesteps': 0,
    }
    robosuite = params['env'] in ('Lift', 'Door', 'NutAssembly', 'TwoArmPegInHole')

    total_timesteps = params['total_timesteps']

    while i < total_timesteps:
        # Collect one trajectory
        obs, done, t = env.reset(), False, 0
        ep_buf, rets = [], []
        while not done and t < params['horizon']:
            ################################################################################
            # Every params['eval_freq'] timesteps, run the evaluation loop and output logs #
            ################################################################################
            if i % params['eval_freq'] == 0:

                print('Testing Agent')
                eval_results = evaluate_policy(awac, test_env,
                                               params['num_eval_episodes'],
                                               close_after_episode=robosuite)
                for test_ret, test_len in eval_results:
                    logger.store(TestEpRet=test_ret, TestEpLen=test_len)

                # Log info about epoch
                logger.log_tabular('Epoch', epoch)
                logger.log_tabular('TotalEnvInteracts', i)
                logger.log_tabular('TestEpRet')
                logger.log_tabular('TestEpLen', average_only=True)
                if epoch == 0:
                    # No training statistics have been stored yet, so log
                    # explicit zeros for those columns.
                    logger.log_tabular('AverageTrainEpRet', 0)
                    logger.log_tabular('StdTrainEpRet', 0)
                    logger.log_tabular('TrainEpLen', 0)
                    logger.log_tabular('Q1', 0)
                    logger.log_tabular('Q2', 0)
                else:
                    logger.log_tabular('TrainEpRet')
                    logger.log_tabular('TrainEpLen', average_only=True)
                    logger.log_tabular('Q1', average_only=True)
                    logger.log_tabular('Q2', average_only=True)
                for metric, value in metrics.items():
                    logger.log_tabular(metric, value)
                logger.dump_tabular()

                epoch += 1
                loss_plotter.plot()
                awac.save(os.path.join(logdir, 'models'))

                # Disabled diagnostics, kept for reference:
                # if is_pointbot_env:
                #     spbu.plot_Q(awac, env,
                #                 os.path.join(logdir, 'misc', 'q_%d.pdf' % i),
                #                 skip=2)
                # if params['plot_drtg_maxes']:
                #     spbu.plot_maxes(awac, env,
                #                     os.path.join(logdir, 'misc', 'q_maxes_%d.pdf' % n_episodes))
                #     awac.drtg_buffer = set()
                #     awac.bellman_buffer = set()

            ########################
            # Begin policy updates #
            ########################

            # Sample actions from the action space until start_timesteps
            # steps have been taken, then query the agent.
            if i < params['start_timesteps']:
                act = env.action_space.sample()
            else:
                act = awac.select_action(obs)

            next_obs, rew, done, info = env.step(act)
            ep_buf.append({
                'obs': obs,
                'next_obs': next_obs,
                'act': act,
                'rew': utils.shift_reward(rew, params),
                'done': done,
                'expert': 0,
                'goal': info['goal'] if 'goal' in info else 0,
                'mask': transition_mask(info, done, t, params['horizon']),
            })
            obs = next_obs

            i += 1
            t += 1
            rets.append(rew)
            metrics['Timesteps'] += 1

            # grad steps
            if i >= params['start_timesteps'] and i % params['update_freq'] == 0:
                for _ in range(params['update_n_steps']):
                    if len(replay_buffer) == 0:
                        break
                    update_info = awac.update(replay_buffer)
                    logger.store(**update_info)
                    loss_plotter.add_data(**update_info)

        # Label the finished episode and add it to the replay buffer, last
        # transition first.
        for transition in label_episode(ep_buf, params['discount']):
            replay_buffer.store_transition(transition)

        if robosuite:
            env.close()

        # TrainEpRet sums the raw rewards, not the shifted ones stored in
        # the transitions.
        logger.store(TrainEpRet=sum(rets), TrainEpLen=len(rets))
        n_episodes += 1
