import pickle
from rlpytorch.behavior_clone.coach_dataset import CoachDataset
import rlpytorch.behavior_clone.global_consts as gc
import numpy as np


def get_dictionary_stats(inst_dict):
    """Print word-level statistics for an instruction dictionary.

    Args:
        inst_dict: object exposing ``_idx2inst``, a sized iterable of
            instruction strings. Each instruction is split on whitespace
            to obtain its words.

    Prints the number of instructions, the total and mean number of words
    per instruction, and the number of distinct words.
    """
    instructions = inst_dict._idx2inst
    print('# unique inst:', len(instructions))

    # Tokenize every instruction once; both the length statistics and the
    # vocabulary are derived from the same token lists.
    tokenized = [sentence.split() for sentence in instructions]
    lengths = [len(tokens) for tokens in tokenized]

    vocabulary = set()
    for tokens in tokenized:
        vocabulary.update(tokens)

    print('num_words:', np.sum(lengths))
    print('avg_num_words:', np.mean(lengths))
    print('num_unique_words:', len(vocabulary))


def get_avg_num_instruction(data):
    """Print the mean number of instructions issued per game.

    Args:
        data: sequence of dict records, each with a ``'unique_id'`` string
            and an ``'instruction'`` value. The part of ``'unique_id'``
            before its last ``'-'`` is used as the game (replay) name.

    Consecutive records that share a replay name form one game. Within a
    game, the count starts at 1 and increases each time the instruction
    differs from the one on the preceding record.
    """
    unset = object()  # sentinel: no game has been started yet
    current_game = unset
    last_instruction = None
    per_game_counts = []

    for record in data:
        game = record['unique_id'].rsplit('-', 1)[0]
        instruction = record['instruction']
        if current_game is unset or not (current_game == game):
            # First record of a new game: its instruction is counted once.
            current_game = game
            last_instruction = instruction
            per_game_counts.append(1)
        elif instruction != last_instruction:
            last_instruction = instruction
            per_game_counts[-1] += 1

    print('avg_num_inst_per_game:', np.mean(per_game_counts))


def get_avg_num_cmd_per_inst(data):
    """Print the mean number of issued commands per instruction.

    Args:
        data: sequence of dict records, each with an ``'instruction'``
            value and a ``'my_units'`` iterable. Every unit is a dict whose
            ``'target_cmd'`` entry holds a ``'cmd_type'``.

    Consecutive records with an equal ``'instruction'`` form one run. For
    each run, every unit command whose type is neither
    ``gc.CmdTypes.IDLE`` nor ``gc.CmdTypes.CONT`` is counted.

    NOTE: runs are delimited by the instruction value only; this function
    does not look at ``'unique_id'``, so equal instructions on adjacent
    records are merged even if they come from different games.
    """
    unset = object()  # sentinel: no run has been started yet
    active_instruction = unset
    cmds_per_run = []

    for record in data:
        instruction = record['instruction']
        if active_instruction is unset or not (instruction == active_instruction):
            # A different instruction opens a new run with zero commands.
            active_instruction = instruction
            cmds_per_run.append(0)

        for unit in record['my_units']:
            kind = unit['target_cmd']['cmd_type']
            if kind == gc.CmdTypes.IDLE.value or kind == gc.CmdTypes.CONT.value:
                continue  # idle / continue entries are not counted
            cmds_per_run[-1] += 1

    print('avg_num_cmds_per_inst:', np.mean(cmds_per_run))


if __name__ == '__main__':
    # Load the pickled instruction dictionary. The file is opened in a
    # ``with`` block so the handle is closed as soon as loading finishes.
    # NOTE: unpickling can execute arbitrary code; only load dictionary
    # files that come from a trusted source.
    with open('./data3/train.json_min10_dict.pt', 'rb') as dict_file:
        inst_dict = pickle.load(dict_file)
    inst_dict.set_max_sentence_length(20)

    get_dictionary_stats(inst_dict)

    # Disabled: builds the combined train + valid record list. Presumably
    # this is the ``data`` argument for get_avg_num_instruction and
    # get_avg_num_cmd_per_inst -- confirm before re-enabling.
    # train = CoachDataset(
    #     './data3/train.json_min10',
    #     0.9,
    #     11,
    #     50,
    #     10,
    #     inst_dict,
    #     2000,
    # )
    # valid = CoachDataset(
    #     './data3/valid.json_min10',
    #     0.9,
    #     11,
    #     50,
    #     10,
    #     inst_dict,
    #     2000,
    # )
    # data = train.data + valid.data
