import os
import sys
import pickle
import pprint
from collections import defaultdict
import time

import numpy as np
import torch
from torch import autograd
from torch.utils.data import DataLoader

from elf.options import PyOptionSpec

from rlpytorch.behavior_clone.coach_dataset import CoachDataset, merge_max_units
from rlpytorch.behavior_clone import utils
from rlpytorch.behavior_clone.one_hot_generator import OneHotGenerator
from rlpytorch.utils import set_all_seeds, EvalMode, Logger
from rlpytorch.behavior_clone.instruction_encoder import is_word_based


def parse_args():
    """Declare this script's options and return the parsed option map.

    The script-specific options are registered first, then the options
    exposed by ``OneHotGenerator.get_option_spec()`` are merged in.

    Returns:
        Whatever ``PyOptionSpec.parse()`` returns (main() calls
        ``getOptions()`` on it).
    """
    option_spec = PyOptionSpec()

    # train config: (registration method, name, help text, default value),
    # listed in the order they are registered.
    train_options = (
        (option_spec.addIntOption, 'gpu', '', 0),
        (option_spec.addIntOption, 'seed', '', 1),
        (option_spec.addFloatOption, 'temperature', '', 1),
        (option_spec.addStrOption,
         'val_dataset',
         'path to val dataset',
         'data/human-valid-build-stop.json'),
        (option_spec.addStrOption, 'model_file', 'saved model file',
         'model-test'),
        (option_spec.addBoolOption, 'dev', 'for debug', False),
    )
    for register, name, help_text, default in train_options:
        register(name, help_text, default)

    # merge with other class's options
    option_spec.merge(OneHotGenerator.get_option_spec())
    return option_spec.parse()


def _decode_instruction(inst_dict, token_ids, pad):
    """Turn one row of word indices into a space-separated string.

    Args:
        inst_dict: object exposing an ``idx2word`` index -> word lookup.
        token_ids: iterable of word indices for a single instruction.
        pad: padding index; entries equal to it are dropped.

    Returns:
        The words of all non-pad indices joined by single spaces.
    """
    return ' '.join(inst_dict.idx2word[w] for w in token_ids if w != pad)


def main():
    """Print previous, sampled and true instructions for validation batches.

    Loads the model named by ``--model_file``, iterates over the validation
    dataset in batches of 5 and, for the first sample of each batch, prints
    the previous instruction, an instruction sampled from the model at
    ``--temperature``, and the target instruction. Stops after batch
    index 300.
    """
    torch.backends.cudnn.benchmark = True

    option_map = parse_args()
    options = option_map.getOptions()

    print('Args:\n%s\n' % pprint.pformat(vars(options)))

    # A negative gpu index selects the CPU.
    if options.gpu < 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % options.gpu)

    set_all_seeds(options.seed)

    model = utils.load_model(options.model_file).to(device)

    val_dataset = CoachDataset(
        options.val_dataset,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        inst_dict=model.inst_dict,
        word_based=True,
        num_instructions=options.num_instructions)
    val_loader = DataLoader(
        val_dataset,
        5,
        shuffle=False,
        num_workers=0,
        pin_memory=(options.gpu >= 0))

    model.eval()
    inst_dict = model.inst_dict
    pad = inst_dict.get_words_pad()

    # Pure inference: no gradients are needed, so do not let autograd
    # record the sampling pass.
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i > 300:
                break
            # NOTE(review): the return value is not used, so this presumably
            # moves the batch's tensors in place -- confirm against
            # utils.to_device.
            utils.to_device(batch, device)

            # Only the first sample of each batch (row 0) is decoded.
            prev_inst = _decode_instruction(
                inst_dict,
                batch['current']['prev_instruction'].cpu().numpy()[0],
                pad)
            true_inst = _decode_instruction(
                inst_dict,
                batch['current']['instruction_target'].cpu().numpy()[0],
                pad)

            # Separator marks samples whose target differs from the
            # previous instruction.
            if prev_inst != true_inst:
                print('=====================')
            print(i)

            # The second return value is not needed here.
            sampled, _ = model.sample(batch, options.temperature)
            sampled_inst = _decode_instruction(
                inst_dict, sampled.cpu().numpy()[0], pad)

            print('PREV: %s' % prev_inst)
            print('SAMPLED: %s' % sampled_inst)
            print('TRUE: %s' % true_inst)
            print('')

# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
