import argparse
import os
import sys
from collections import defaultdict
import pprint
import copy

import torch
from torch import nn
from torch.nn.utils import weight_norm
import numpy as np

# path will be set in create_envs... or we can call set_path
from create_envs import create_default_env, create_eval_env
from executor_based_model import CoachExecutorModel

# import tube
import minirts
from pyxrl.data_channel_manager import DataChannelManager

from actor_critic import ActorCritic
from rule_ai_sampler import RuleAISampler
from coach_sampler import CoachSampler
import common_utils

import global_consts as gc
from behavior_clone.rnn_coach import ConvRnnCoach
from behavior_clone.executor import Executor
from behavior_clone.set_path import best_rnn_executor, best_rnn_coach, best_rnn_coach_nofow


def parse_args():
    """Parse the command-line options for RL coach training.

    Reads ``sys.argv`` through ``argparse`` and returns the resulting
    namespace.  After parsing, ``executor_path`` / ``coach_path`` are filled
    in with the built-in checkpoint for the selected ``executor_type`` /
    ``coach_type``, but only when the caller did not pass an explicit path:
    an explicit ``--executor_path`` / ``--coach_path`` always wins.

    Returns:
        argparse.Namespace: the parsed (and default-completed) arguments.
    """
    parser = argparse.ArgumentParser(description='rl coach')

    parser.add_argument('--save_dir', type=str, default='dev/coach_dev')
    parser.add_argument('--seed', type=int, default=1)

    # parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--num_thread', type=int, default=32)
    parser.add_argument('--batchsize', type=int, default=16)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--update_per_epoch', type=int, default=200)
    parser.add_argument('--num_epoch', type=int, default=200)

    # Default lua directory lives three levels above this file.
    root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    default_lua = os.path.join(root, 'game/game_MC/lua')
    parser.add_argument('--lua_files', type=str, default=default_lua)

    # optim
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--eps', type=float, default=1.5e-4)
    parser.add_argument('--grad_clip', type=float, default=0.25)

    # actor critic option
    parser.add_argument('--sync_each_step', type=int, default=0)
    parser.add_argument('--central_val', type=int, default=0)
    parser.add_argument('--ppo', type=int, default=0)
    parser.add_argument('--replay_buffer', type=int, default=2048)
    parser.add_argument('--ent_ratio', type=float, default=1e-2)
    parser.add_argument('--min_ent_ratio', type=float, default=1e-4)
    parser.add_argument('--min_prob', type=float, default=1e-6)
    parser.add_argument('--max_importance_ratio', type=float, default=1.2)
    parser.add_argument('--ratio_clamp', type=float, default=0.1)
    parser.add_argument('--gamma', type=float, default=0.997)

    # training method related
    parser.add_argument('--cheat', type=int, default=0)
    parser.add_argument('--use_xent', type=int, default=0)

    # instruction set
    parser.add_argument('--inst_mode', type=str, default='full') # can be full/good/better

    # ai1 option
    parser.add_argument('--frame_skip', type=int, default=50)
    parser.add_argument('--fow', type=int, default=1)
    parser.add_argument('--t_len', type=int, default=10)
    parser.add_argument('--use_moving_avg', type=int, default=1)
    parser.add_argument('--moving_avg_decay', type=float, default=0.98)
    parser.add_argument('--num_resource_bins', type=int, default=11)
    parser.add_argument('--resource_bin_size', type=int, default=50)
    parser.add_argument('--max_num_units', type=int, default=50)
    parser.add_argument('--num_prev_cmds', type=int, default=25)
    # TODO: add max instruction span
    parser.add_argument('--max_raw_chars', type=int, default=200)
    parser.add_argument('--verbose', action='store_true')

    # handicap
    parser.add_argument('--adversarial', type=int, default=0)
    parser.add_argument('--win_rate_decay', type=float, default=0.95)
    parser.add_argument('--min_resource_scale', type=float, default=1.0)
    parser.add_argument('--max_resource_scale', type=float, default=1.0)
    parser.add_argument('--num_extra_units', type=int, default=0)

    # game option
    parser.add_argument('--max_tick', type=int, default=int(5e4))
    parser.add_argument('--no_terrain', action='store_true')
    parser.add_argument('--resource', type=int, default=500)
    parser.add_argument('--resource_dist', type=int, default=4)
    parser.add_argument('--fair', type=int, default=0)
    parser.add_argument('--save_replay_freq', type=int, default=0)
    parser.add_argument('--save_replay_per_games', type=int, default=1)

    # model
    parser.add_argument('--coach_type', type=str, default='rnn')
    parser.add_argument('--coach_path', type=str, default='')
    parser.add_argument('--executor_type', type=str, default='rnn')
    parser.add_argument('--executor_path', type=str, default='')

    args = parser.parse_args()

    # Only fall back to the built-in checkpoints when no path was given, so an
    # explicit --executor_path / --coach_path is not silently discarded.
    if not args.executor_path:
        if args.executor_type == 'bow':
            # `best_bow_executor` is not imported at module level; look it up
            # lazily next to the rnn defaults.
            # NOTE(review): assumes it is defined in behavior_clone.set_path
            # like best_rnn_executor -- confirm.
            try:
                from behavior_clone.set_path import best_bow_executor
            except ImportError:
                parser.error(
                    "no default checkpoint is available for --executor_type "
                    "'bow'; pass --executor_path explicitly")
            args.executor_path = best_bow_executor
        elif args.executor_type == 'rnn':
            args.executor_path = best_rnn_executor

    if not args.coach_path:
        if args.coach_type == 'rnn_nofow':
            args.coach_path = best_rnn_coach_nofow
        elif args.coach_type == 'rnn':
            args.coach_path = best_rnn_coach
    return args


def get_game_option(args):
    """Build the ``minirts.RTSGameOption`` for training from parsed arguments.

    Copies the game-related fields of ``args`` onto a fresh option object and
    makes sure the replay directory ``<save_dir>/replay`` exists.

    Args:
        args: namespace returned by ``parse_args``.

    Returns:
        The populated ``minirts.RTSGameOption``; its ``save_replay_prefix``
        is the absolute replay directory with a trailing ``'/'``.
    """
    game_option = minirts.RTSGameOption()
    game_option.seed = args.seed
    game_option.max_tick = args.max_tick
    game_option.no_terrain = args.no_terrain
    game_option.resource = args.resource
    game_option.resource_dist = args.resource_dist
    game_option.fair = args.fair
    game_option.save_replay_freq = args.save_replay_freq
    game_option.save_replay_per_games = args.save_replay_per_games
    game_option.lua_files = args.lua_files
    # game_option.num_games_per_thread = 1
    # !!! this is important: the per-player unit cap comes from the same
    # argument that get_ai_options() uses for the AI's max_num_units.
    game_option.max_num_units_per_player = args.max_num_units
    game_option.num_extra_units = args.num_extra_units

    replay_dir = os.path.join(os.path.abspath(args.save_dir), 'replay')
    # exist_ok avoids the check-then-create race when several runs share a
    # save_dir and start at the same time.
    os.makedirs(replay_dir, exist_ok=True)
    game_option.save_replay_prefix = replay_dir + '/'

    return game_option


def get_ai_options(args, num_instructions):
    """Create the two ``minirts.AIOption`` objects used by the training env.

    The first option receives the full feature configuration taken from
    ``args`` plus ``num_instructions``; the second only receives the frame
    skip and fog-of-war settings.

    Args:
        args: namespace returned by ``parse_args``.
        num_instructions: value stored on the first option's
            ``num_instructions`` field.

    Returns:
        tuple: ``(ai1_option, ai2_option)``.
    """
    # Fields whose name is identical on the option object and on `args`.
    mirrored_fields = (
        't_len',
        'fow',
        'use_moving_avg',
        'moving_avg_decay',
        'num_resource_bins',
        'resource_bin_size',
        'max_num_units',
        'num_prev_cmds',
        'max_raw_chars',
        'verbose',
    )

    primary = minirts.AIOption()
    primary.coach_type = 'medium_ai'
    primary.fs = args.frame_skip
    primary.num_instructions = num_instructions
    for field in mirrored_fields:
        setattr(primary, field, getattr(args, field))

    secondary = minirts.AIOption()
    secondary.fs = args.frame_skip
    secondary.fow = args.fow

    return primary, secondary


def train(epoch,
          num_update,
          dc_manager,
          method,
          optim,
          model,
          critic_optim,
          critic_model,
          device,
          coach_sampler,
          rule_sampler,
          stat,
          result_stat,
          critic_only):
    """Serve the data channels until ``num_update`` training steps are done.

    Each iteration pulls the pending batches from ``dc_manager`` and
    dispatches them by channel key:

    * ``'act1'`` / ``'act2'``: run ``model.act`` without gradients, feed the
      batch to ``result_stat`` and reply with the model output.
    * ``'train'``: compute gradients through ``method`` and step the
      optimizer(s); this is the only branch that decrements ``num_update``.
    * ``'rule'``: reply with ``rule_sampler.feed(batch)``.

    Accumulated wall-clock time per phase is printed at the end.

    Args:
        epoch: unused in the body; kept for the caller's signature.
        num_update: number of ``'train'`` batches to process before returning.
        dc_manager: object providing ``get_input()``, ``set_reply()`` and
            ``in`` membership for channel keys.
        method: object providing ``compute_gradient(model, critic_model,
            batch, stat)``.
        optim: optimizer wrapper for the model.
        model: model providing ``act(batch)``.
        critic_optim: optional optimizer wrapper for the critic; must not be
            ``None`` when ``critic_only`` is true.
        critic_model: forwarded to ``method.compute_gradient``.
        device: device the batches are moved to.
        coach_sampler: unused in the body; kept for the caller's signature.
        rule_sampler: object providing ``feed(batch)``.
        stat: counter; ``stat.inc(key)`` is called for every batch.
        result_stat: receives every act batch through ``feed``.
        critic_only: when true, only ``critic_optim`` is stepped and the
            gradients held by ``optim`` are discarded.

    Raises:
        ValueError: if a batch arrives under a key that ``dc_manager`` does
            not contain or that this loop does not handle.
    """
    import time
    timer = defaultdict(float)

    while num_update > 0:
        t = time.time()
        batches = dc_manager.get_input()
        timer['get input'] += time.time() - t
        for key, batch in batches.items():
            # Explicit raises instead of asserts: under `python -O` an assert
            # would vanish and an unexpected key would be skipped unanswered.
            if key not in dc_manager:
                raise ValueError(
                    'batch key %r is not a registered data channel' % (key,))
            stat.inc(key)

            if key in ('act1', 'act2'):
                t = time.time()
                batch = common_utils.to_device(batch, device)
                with torch.no_grad():
                    reply = model.act(batch)
                result_stat.feed(batch)
                dc_manager.set_reply(key, reply)
                timer['act'] += time.time() - t
            elif key == 'train':
                t = time.time()
                batch = common_utils.to_device(batch, device)

                # Gradients must be clean before accumulating new ones.
                optim.assert_zero_grad()
                if critic_optim is not None:
                    critic_optim.assert_zero_grad()

                method.compute_gradient(model, critic_model, batch, stat)

                if critic_only:
                    # Update the critic only; drop the gradients that
                    # `optim` holds instead of applying them.
                    critic_optim.step(None)
                    optim.zero_grad()
                else:
                    optim.step(stat)
                    if critic_optim is not None:
                        critic_optim.step(None)

                dc_manager.set_reply(key, {})
                num_update -= 1
                timer['train'] += time.time() - t
            elif key == 'rule':
                reply = rule_sampler.feed(batch)
                dc_manager.set_reply(key, reply)
            else:
                raise ValueError(
                    'unexpected data channel key: %r' % (key,))

    print(timer)


def evaluate(mode, model, device, num_game, game_option, ai_option, epoch, result_stat, save_dir):
    """Run evaluation games with ``model`` and print the aggregated result.

    Creates an evaluation environment through ``create_eval_env``, answers
    every ``'act'`` batch with ``model.act`` (no gradients) until the context
    reports termination, then prints and resets ``result_stat``.

    Note that ``game_option.save_replay_prefix`` is overwritten in place so
    replays go to ``<save_dir>/eval_epoch<epoch>/``.

    Args:
        mode: forwarded to ``create_eval_env``.
        model: model providing ``act(batch)``.
        device: device the batches are moved to.
        num_game: forwarded to ``create_eval_env``.
        game_option: game option object; mutated as described above.
        ai_option: option for the first AI; its ``fs`` is copied to the
            second AI's option.
        epoch: used for the replay directory name and the result log.
        result_stat: receives every batch through ``feed``; logged and reset
            at the end.
        save_dir: base directory for the evaluation replays.
    """
    eval_dir = os.path.join(os.path.abspath(save_dir), 'eval_epoch%d' % epoch)
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(eval_dir, exist_ok=True)
    game_option.save_replay_prefix = eval_dir + '/'

    # The second AI always has fow set to 1 and shares the frame skip of the
    # first AI.
    ai2_option = minirts.AIOption()
    ai2_option.fow = 1
    ai2_option.fs = ai_option.fs

    context, act_dc = create_eval_env(
        num_game,
        ai_option,
        ai2_option,
        game_option,
        mode
    )
    context.start()
    dc = DataChannelManager([act_dc])

    while not context.terminated():
        data = dc.get_input(max_timeout_s=1)
        if not data:
            # Nothing arrived within the timeout; re-check for termination.
            continue
        data = common_utils.to_device(data['act'], device)
        with torch.no_grad():
            reply = model.act(data)
        result_stat.feed(data)
        dc.set_reply('act', reply)

    print(result_stat.log(epoch))
    result_stat.reset()
    dc.terminate()

def load_model(coach_path, max_raw_chars, executor_path, device, cheat, inst_mode):
    """Load a coach and an executor checkpoint and wrap them in one model.

    Args:
        coach_path: checkpoint passed to ``ConvRnnCoach.load``.
        max_raw_chars: value assigned to the loaded coach's ``max_raw_chars``.
        executor_path: checkpoint passed to ``Executor.load``.
        device: device both sub-models are moved to.
        cheat: forwarded to ``CoachExecutorModel``.
        inst_mode: forwarded to ``CoachExecutorModel``.

    Returns:
        CoachExecutorModel: wrapper around the loaded coach and executor.
    """
    loaded_coach = ConvRnnCoach.load(coach_path)
    loaded_coach.max_raw_chars = max_raw_chars
    return CoachExecutorModel(
        loaded_coach.to(device),
        Executor.load(executor_path).to(device),
        cheat,
        inst_mode,
    )


def turn_off_dropout(model):
    """Switch the coach's two dropout layers out of training mode.

    Calls ``train(False)`` on ``model.coach.glob_dropout`` and on
    ``model.coach.prev_inst_encoder.emb_dropout``; nothing else on the model
    is touched.
    """
    coach = model.coach
    dropout_layers = (
        coach.glob_dropout,
        coach.prev_inst_encoder.emb_dropout,
    )
    for layer in dropout_layers:
        layer.train(False)


if __name__ == '__main__':
    # Script entry point: load a pretrained coach/executor pair, run one
    # critic-only warm-up epoch, then train the coach with the actor-critic
    # method, saving the coach after every epoch.
    args = parse_args()
    print('args:')
    pprint.pprint(vars(args))

    # Export the lua script directory through the LUA_PATH environment variable.
    # NOTE(review): presumably consumed by the minirts game engine -- confirm.
    os.environ['LUA_PATH'] = os.path.join(args.lua_files, '?.lua')
    print('lua path:', os.environ['LUA_PATH'])

    torch.backends.cudnn.benchmark = True

    # From here on sys.stdout is a common_utils.Logger built from
    # <save_dir>/train.log.
    # NOTE(review): presumably it tees prints to that file -- confirm.
    if args.save_dir:
        logger_path = os.path.join(args.save_dir, 'train.log')
        sys.stdout = common_utils.Logger(logger_path)

    # Training always targets a CUDA device selected by --gpu.
    device = torch.device('cuda:%d' % args.gpu)
    model = load_model(
        args.coach_path, args.max_raw_chars, args.executor_path, device, args.cheat, args.inst_mode)
    turn_off_dropout(model)

    # critic = load_model(args.coach_path, args.max_raw_chars, args.executor_path, device, True)
    # critic.coach.glob_dropout.train(False)
    # critic.executor.train(False)

    # add ref net for xent: a second model loaded from the same checkpoints,
    # put in eval mode and registered with the model's sampler.
    ref_net = load_model(
        args.coach_path, args.max_raw_chars, args.executor_path, device, args.cheat, args.inst_mode)
    ref_net.train(False)
    model.sampler.add_ref_net(ref_net)

    # `optim` is built over the whole coach; the executor is not optimized here.
    optim = common_utils.Optim(
        model.coach, torch.optim.Adam, {'lr': args.lr, 'eps': args.eps}, args.grad_clip)
    # shared critic: `critic_optim` is built over model.coach.value only, a
    # sub-module of the coach that `optim` also covers.
    critic_optim = common_utils.Optim(
        model.coach.value, torch.optim.Adam, {'lr': args.lr, 'eps': args.eps}, args.grad_clip)

    method = ActorCritic(
        ent_ratio=args.ent_ratio,
        min_prob=args.min_prob,
        max_importance_ratio=args.max_importance_ratio,
        ratio_clamp=args.ratio_clamp,
        gamma=args.gamma,
        ppo=args.ppo,
        use_xent=args.use_xent)

    ai1_option, ai2_option = get_ai_options(args, model.coach.num_instructions)
    game_option = get_game_option(args)
    # eval_option = minirts.RTSGameOption(game_option)
    # eval_option.seed = 1 #999111
    # eval_option.num_games_per_thread = 1

    # Build the training environment; the four data channels are served by
    # train() through `dc_manager`.
    config, context, train_dc, act_dc1, act_dc2, rule_dc, games = create_default_env(
        args.num_thread,
        args.batchsize,
        args.seed,
        ai1_option,
        ai2_option,
        game_option,
        mode='executor'
    )
    dc_manager = DataChannelManager([train_dc, act_dc1, act_dc2, rule_dc])
    context.start()

    rule_sampler = RuleAISampler(args.adversarial,
                                 args.min_resource_scale,
                                 args.max_resource_scale,
                                 args.win_rate_decay)

    stat = common_utils.MultiCounter(args.save_dir)
    train_result = common_utils.ResultStat(
        'reward', os.path.join(args.save_dir, 'train_win'))
    eval_result = common_utils.ResultStat(
        'reward', os.path.join(args.save_dir, 'eval_win'))

    # Scale the number of updates per epoch inversely with t_len, relative to
    # a baseline of t_len == 10 (e.g. t_len=20 halves the update count).
    args.update_per_epoch = int(args.update_per_epoch / (args.t_len / 10))

    # evaluate('executor',
    #          model,
    #          device,
    #          100,
    #          eval_option,
    #          ai1_option,
    #          0,
    #          eval_result,
    #          args.save_dir)

    # # # warm up critic
    # Epoch 0 runs train() with critic_only=True: only `critic_optim` is
    # stepped and the gradients held by `optim` are zeroed after every update.
    print('warm up critic')
    for epoch in range(1):
        train(epoch,
              args.update_per_epoch,
              dc_manager,
              method,
              optim,
              model,
              critic_optim,
              None, #critic,
              device,
              None,
              rule_sampler,
              stat,
              train_result,
              True)

        model_file = os.path.join(args.save_dir, 'model%d.pt' % epoch)
        print('saving model to:', model_file)
        model.coach.save(model_file)

        stat.summary(epoch)
        print(train_result.log(epoch))
        rule_sampler.log()

        stat.reset()
        train_result.reset()
        rule_sampler.reset()

    # Clear any gradients left on the coach before regular training; train()
    # asserts that `optim` starts every update with zero gradients.
    optim.zero_grad()

    # Regular training: epochs 1 .. num_epoch-1. No critic optimizer is passed,
    # so every update is a single `optim.step` over the whole coach.
    for epoch in range(1, args.num_epoch):
        # method.ent_ratio = max(
        #     args.min_ent_ratio, args.ent_ratio - args.ent_ratio / 200 * epoch)

        # if epoch % 5 == 0:
        #     print('=============eval==============')
        #     evaluate('executor',
        #              model,
        #              device,
        #              500,
        #              eval_option,
        #              ai1_option,
        #              epoch,
        #              eval_result,
        #              args.save_dir)
        #     print('========= end of eval==========')

        train(epoch,
              args.update_per_epoch,
              dc_manager,
              method,
              optim,
              model,
              None, # critic_optim,
              None, # critic,
              device,
              None,
              rule_sampler,
              stat,
              train_result,
              False)

        # Checkpoint the coach and flush/reset the per-epoch statistics.
        model_file = os.path.join(args.save_dir, 'model%d.pt' % epoch)
        print('saving model to:', model_file)
        model.coach.save(model_file)

        stat.summary(epoch)
        print(train_result.log(epoch))
        rule_sampler.log()

        stat.reset()
        train_result.reset()
        rule_sampler.reset()
