import json
import pprint


# Reference (expected) dataset and the dataset under test.
ref_path = '/private/home/hengyuan/elf2-bc-dev/src_py/rlpytorch/behavior_clone/data3/valid.json_min10'
path = '/scratch/hengyuan/rts_data/dataset_ref2/val.json'

# Alternative pair of inputs:
# ref_path = '/scratch/hengyuan/minirts/data/valid.json_min10'
# path = '/scratch/hengyuan/rts_data/dataset_ref/val.json'


# Context managers close the file handles as soon as parsing finishes
# (the previous json.load(open(...)) left them open). JSON text is UTF-8.
with open(ref_path, encoding='utf-8') as ref_file:
    ref_dataset = json.load(ref_file)
with open(path, encoding='utf-8') as data_file:
    dataset = json.load(data_file)
# dataset = dataset[:1000]  # uncomment to compare only a prefix while debugging

# The comparison loop zips the two datasets, which silently stops at the
# shorter one, so report a length mismatch up front.
if len(ref_dataset) != len(dataset):
    print('len mismatch:', len(ref_dataset), len(dataset))


def compare_my_unit(ref_u, u):
    """Compare one of our units against its reference counterpart.

    Iterates over the keys of ``ref_u`` and looks up the matching value in
    ``u``. Keys renamed in the new format are translated
    (``pre_ins_prev_cmd`` -> ``instructor_prev_cmd``,
    ``prev_cmd`` -> ``executor_prev_cmd``). Coordinates (``x``, ``y`` and the
    ``target_x``/``target_y`` of ``current_cmd``/``target_cmd``) are truncated
    to ``int`` before comparing. ``current_cmd_cont`` is skipped.

    Args:
        ref_u: unit dict from the reference dataset.
        u: unit dict from the dataset under test. It is not modified.

    Returns:
        ``(key, ref_val, val)`` for the first mismatching reference key, where
        ``val`` is the normalized value from ``u``, or ``(None, None, None)``
        if every compared key matches.

    Raises:
        KeyError: if ``u`` lacks a field that ``ref_u`` requires.
    """
    for key, ref_val in ref_u.items():
        if key == 'current_cmd_cont':
            # Deliberately excluded from the comparison.
            continue

        if key in ('x', 'y'):
            val = int(u[key])
        elif key == 'pre_ins_prev_cmd':
            val = u['instructor_prev_cmd']
        elif key == 'prev_cmd':
            val = u['executor_prev_cmd']
        elif key in ('current_cmd', 'target_cmd'):
            # Normalize a shallow copy so the caller's command dict keeps its
            # original (possibly float) coordinates.
            val = dict(u[key])
            val['target_x'] = int(val['target_x'])
            val['target_y'] = int(val['target_y'])
        else:
            val = u[key]

        if ref_val != val:
            return key, ref_val, val

    return None, None, None


def compare_datapoint(d_ref, d, i):
    """Compare one datapoint of the dataset under test against the reference.

    Rules, applied per key of ``d_ref``:
      * plain fields (``cons_count``, ``instruction``, ``map``, ``resource``,
        ``tick``, ``glob_cont``, ``prev_instruction``) must be equal; a
        mismatch is printed but does not stop the run;
      * ``pre_ins_base_frame_idx`` / ``base_frame_idx`` are compared against
        the renamed ``coach_base_frame_idx`` / ``executor_base_frame_idx``; a
        mismatch is printed;
      * ``enemy_units`` / ``resource_units`` are compared after truncating the
        ``x``/``y`` of the new units to ``int``; a mismatch is fatal;
      * ``my_units`` are compared unit by unit via ``compare_my_unit``; a
        length mismatch is printed, a per-unit mismatch is fatal.
    Reference keys not listed above are not compared.

    Args:
        d_ref: datapoint dict from the reference dataset.
        d: datapoint dict from the dataset under test. It is not modified.
        i: index of the datapoint, used only in messages.

    Raises:
        AssertionError: if the datapoints have a different number of keys, or
            on a fatal unit mismatch. Raised explicitly (not via ``assert``)
            so the check still fires under ``python -O``.
    """
    # Several keys are renamed between the formats, so the key sets cannot be
    # compared directly; only the key count is checked.
    if len(d_ref) != len(d):
        raise AssertionError('datapoint %d: key count mismatch: %d vs %d'
                             % (i, len(d_ref), len(d)))

    plain_keys = {'cons_count',
                  'instruction',
                  'map',
                  'resource',
                  'tick',
                  'glob_cont',
                  'prev_instruction'}

    for key in d_ref:
        if key in plain_keys:
            # TODO: the original author noted that 'prev_instruction' should
            # be removed in the filter_beginning function; until then its
            # mismatches are reported here.
            if d_ref[key] != d[key]:
                print('error:', i, key, 'val: ', d_ref[key], d[key])
        elif key == 'pre_ins_base_frame_idx':
            if d_ref[key] != d['coach_base_frame_idx']:
                print('error:', i, key, d_ref[key], d['coach_base_frame_idx'])
        elif key == 'base_frame_idx':
            if d_ref[key] != d['executor_base_frame_idx']:
                print('error:', i, key, d_ref[key], d['executor_base_frame_idx'])
        elif key in ('enemy_units', 'resource_units'):
            v_ref = d_ref[key]
            # Build int-coordinate copies instead of rewriting d's units in place.
            v = [
                {k: int(value) if k in ('x', 'y') else value
                 for k, value in u.items()}
                for u in d[key]
            ]
            if v_ref != v:
                print('error:', i, key)
                pprint.pprint(v_ref)
                pprint.pprint(v)
                print('---')
                raise AssertionError('datapoint %d: %s mismatch' % (i, key))
        elif key == 'my_units':
            ref_units = d_ref[key]
            units = d[key]
            if len(ref_units) != len(units):
                print('error:', i, key, 'len mismatch', len(ref_units), len(units))
                continue
            for uid, (ref_u, u) in enumerate(zip(ref_units, units)):
                # Separate name so the outer `key` loop variable is not rebound.
                bad_key, ref_val, val = compare_my_unit(ref_u, u)
                if bad_key is not None:
                    print('error:', i, ', my unit[%d]:' % uid, bad_key)
                    pprint.pprint(ref_val)
                    pprint.pprint(val)
                    print('---')
                    raise AssertionError('datapoint %d: my unit[%d] mismatch on %r'
                                         % (i, uid, bad_key))


# Walk both datasets in lockstep (zip stops at the shorter one). Each time the
# replay name changes, log the index where the new replay starts so reported
# errors can be located; `replays` keeps the names in order of appearance.
# NOTE(review): plain-field mismatches in compare_datapoint are only printed,
# so the final 'pass' means no fatal mismatch, not "no mismatch at all".
replays = []
for idx, (ref_point, point) in enumerate(zip(ref_dataset, dataset)):
    # Drop the text after the last '-' of unique_id to get the replay name.
    replay_name = ref_point['unique_id'].rsplit('-', 1)[0]

    starts_new_replay = not replays or replays[-1] != replay_name
    if starts_new_replay:
        print('new replay', idx)
        replays.append(replay_name)

    compare_datapoint(ref_point, point, idx)

print('pass')
