import pprint

from set_path import append_sys_path
append_sys_path()

import torch
import tube
import minirts


def create_default_env(
        num_thread,
        batchsize,
        meta_seed,
        ai1_option,
        ai2_option,
        game_option,
        *,
        mode='rule',
        train_name='train',
        act_name='act',
        comm_name='comm',
        rule_name='rule'):
    """Create data channels and ``num_thread`` games, each with two bots.

    Every game gets a ``bot1`` whose class is selected by ``mode`` and a
    ``minirts.MediumAI`` ``bot2``; each game is pushed onto a fresh
    ``tube.Context`` as an env thread.

    Args:
        num_thread: number of games / env threads to create.
        batchsize: batch size passed to every ``tube.DataChannel``.
        meta_seed: base seed; game ``thread_idx`` uses
            ``meta_seed + thread_idx``.
        ai1_option: option object passed to ``bot1``.
        ai2_option: option object passed to the ``MediumAI`` ``bot2``.
        game_option: template game option, copied per game via
            ``minirts.RTSGameOption(game_option)``.
        mode: ``'rule'`` -> ``TrainableRuleAI``, ``'executor'`` ->
            ``CheatExecutorAI``, ``'mix'`` -> ``MixExecutorAI`` (mix mode
            also creates a ``comm_name`` channel).
        train_name, act_name, comm_name, rule_name: data channel names.
            The two act channels are named ``act_name + '1'`` and
            ``act_name + '2'``.

    Returns:
        ``mode == 'mix'``:
            ``(config, context, train_dc, act_dc2, comm_dc, rule_dc, games)``
        otherwise:
            ``(config, context, train_dc, act_dc1, act_dc2, rule_dc, games)``
        where ``config`` is a dict of this call's arguments and ``games`` is
        the list of created ``minirts.RTSGame`` objects.

    Raises:
        ValueError: if ``mode`` is not ``'rule'``, ``'executor'`` or ``'mix'``.
    """
    # Must remain the first statement so ``config`` holds only the arguments.
    config = locals()
    # Validate before any channel or game is created. An ``assert`` would be
    # stripped under ``python -O`` and would never run when num_thread == 0.
    if mode not in ('rule', 'executor', 'mix'):
        raise ValueError('unknown mode: %s' % mode)

    print('creating evns with config:')
    pprint.pprint(config)
    print('ai1 option:')
    print(ai1_option.info())
    print('ai2 option:')
    print(ai2_option.info())
    print('game option:')
    print(game_option.info())

    train_dc = tube.DataChannel(train_name, batchsize, -1)
    act_dc1 = tube.DataChannel(act_name + '1', batchsize, -1)
    act_dc2 = tube.DataChannel(act_name + '2', batchsize, -1)
    rule_dc = tube.DataChannel(rule_name, batchsize, 1)
    comm_dc = None
    if mode == 'mix':
        comm_dc = tube.DataChannel(comm_name, batchsize, 1)

    context = tube.Context()
    games = []

    for thread_idx in range(num_thread):
        # First half of the games act through act_dc1, the rest through act_dc2.
        act_dc = act_dc1 if thread_idx < num_thread // 2 else act_dc2

        g_option = minirts.RTSGameOption(game_option)
        g_option.seed = meta_seed + thread_idx
        # Clear the replay prefix for thread_idx > 10; the remaining games
        # get the template prefix with their thread index appended (only
        # when that prefix is non-empty).
        if thread_idx > 10:
            g_option.save_replay_prefix = ''
        if g_option.save_replay_prefix:
            g_option.save_replay_prefix = game_option.save_replay_prefix + str(thread_idx)

        g = minirts.RTSGame(g_option)
        if mode == 'rule':
            bot1 = minirts.TrainableRuleAI(ai1_option, thread_idx, train_dc, act_dc)
        elif mode == 'executor':
            bot1 = minirts.CheatExecutorAI(ai1_option, thread_idx, train_dc, act_dc)
        else:  # 'mix' -- mode was validated above
            bot1 = minirts.MixExecutorAI(ai1_option, thread_idx, train_dc, act_dc, comm_dc)

        bot2 = minirts.MediumAI(
            ai2_option, thread_idx, rule_dc, minirts.UnitType.INVALID_UNITTYPE, False)
        g.add_bot(bot1)
        g.add_bot(bot2)
        games.append(g)
        context.push_env_thread(g)

    if mode == 'mix':
        # Return act_dc2 explicitly instead of the leaked loop variable (which
        # was always act_dc2 for num_thread >= 1 and unbound for 0).
        # NOTE(review): games with thread_idx < num_thread // 2 are bound to
        # act_dc1, which is not returned in mix mode -- confirm callers do
        # not need it.
        return config, context, train_dc, act_dc2, comm_dc, rule_dc, games

    return config, context, train_dc, act_dc1, act_dc2, rule_dc, games


def create_eval_env(num_games, ai1_option, ai2_option, game_option, mode, *, act_name='act'):
    """Create ``num_games`` evaluation games sharing one act data channel.

    Each game gets a ``bot1`` whose class is selected by ``mode`` and a
    ``minirts.MediumAI`` ``bot2``; each game is pushed onto a fresh
    ``tube.Context`` as an env thread.

    Args:
        num_games: number of games / env threads to create.
        ai1_option: option object passed to ``bot1``.
        ai2_option: option object passed to the ``MediumAI`` ``bot2``.
        game_option: template game option, copied per game; game ``i`` uses
            seed ``game_option.seed + i``.
        mode: ``'rule'`` -> ``TrainableRuleAI``, ``'executor'`` ->
            ``CheatExecutorAI``.
        act_name: name of the act data channel.

    Returns:
        ``(context, act_dc)``.

    Raises:
        ValueError: if ``mode`` is not ``'rule'`` or ``'executor'``.
    """
    # Validate before anything is created. An ``assert`` would be stripped
    # under ``python -O`` and would never run when num_games == 0.
    if mode not in ('rule', 'executor'):
        raise ValueError('unknown mode: %s' % mode)

    print('ai1 option:')
    print(ai1_option.info())
    print('ai2 option:')
    print(ai2_option.info())
    print('game option:')
    print(game_option.info())

    # Half the number of games, clamped to the range [1, 32].
    batchsize = min(32, max(num_games // 2, 1))
    act_dc = tube.DataChannel(act_name, batchsize, 1)
    context = tube.Context()
    # Unit types passed to the MediumAI opponents, cycled across games.
    idx2utype = [
        minirts.UnitType.SPEARMAN,
        minirts.UnitType.SWORDMAN,
        minirts.UnitType.CAVALRY,
        minirts.UnitType.DRAGON,
        minirts.UnitType.ARCHER,
    ]

    for i in range(num_games):
        g_option = minirts.RTSGameOption(game_option)
        g_option.seed = game_option.seed + i
        # Clear the replay prefix for i > 20; the remaining games get the
        # template prefix with their index appended (only when non-empty).
        if i > 20:
            g_option.save_replay_prefix = ''
        if g_option.save_replay_prefix:
            g_option.save_replay_prefix = game_option.save_replay_prefix + str(i)

        g = minirts.RTSGame(g_option)

        if mode == 'rule':
            bot1 = minirts.TrainableRuleAI(ai1_option, i, None, act_dc)
        else:  # 'executor' -- mode was validated above
            bot1 = minirts.CheatExecutorAI(ai1_option, i, None, act_dc)

        utype = idx2utype[i % len(idx2utype)]
        bot2 = minirts.MediumAI(ai2_option, i, None, utype, False)
        g.add_bot(bot1)
        g.add_bot(bot2)
        context.push_env_thread(g)

    return context, act_dc
