import os
import json
import tempfile
from collections import defaultdict
from copy import deepcopy
import multiprocessing
import time
from tqdm import tqdm

import global_consts as gc
from process_game import process_game
from format_json import format_json
from utils import get_all_files, remove_duplicated_targets


def process_json(game, prefix):
    """Run the post-processing pipeline on one already-decoded game.

    The caller is expected to have loaded the game itself, e.g. with
    ``json.load`` on the game file.

    Args:
        game: decoded JSON content of a single game file.
        prefix: forwarded unchanged to ``process_game``; ``process_all``
            passes the source file path here.

    Returns:
        Whatever ``process_game`` returns for the de-duplicated game.
    """
    deduped = remove_duplicated_targets(game)
    # Further passes (label_all_cont, add_prev_instruction, add_prev_cmd,
    # add_base_frame) used to be chained after process_game and are
    # currently disabled.
    return process_game(deduped, prefix)


def _mirror_path(src_file, root, dest_root):
    """Return *src_file*'s path relative to *root*, re-rooted under *dest_root*.

    The destination directory is created if it does not exist yet.

    Raises:
        ValueError: if *src_file* is not located under *root* (writing there
            would escape *dest_root*).
    """
    rel = os.path.relpath(src_file, root)
    if rel == os.pardir or rel.startswith(os.pardir + os.sep):
        raise ValueError('%r is not under %r' % (src_file, root))
    dest_file = os.path.join(dest_root, rel)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(os.path.dirname(dest_file), exist_ok=True)
    return dest_file


def process_all(root, output, formatted, compact):
    """Process every ``.json`` game file under *root* and write the results.

    Output mirrors the directory layout under *root*, inside
    ``<output>/compact`` and/or ``<output>/formatted``.

    Args:
        root: directory passed to ``get_all_files`` to collect ``.json`` files.
            Assumes the returned paths are located under *root*; a file
            outside it raises ``ValueError`` instead of being written there.
        output: destination directory for the processed games.
        formatted: if true, write an indent=4 dump post-processed by
            ``format_json`` under ``<output>/formatted``.
        compact: if true, write a plain ``json.dump`` of the processed game
            under ``<output>/compact``.

    Raises:
        ValueError: if neither *formatted* nor *compact* is requested.
    """
    if not (formatted or compact):
        raise ValueError('at least one of formatted/compact must be True')

    compact_root = os.path.join(output, 'compact')
    formatted_root = os.path.join(output, 'formatted')

    for src_file in tqdm(get_all_files(root, '.json')):
        with open(src_file, 'r') as fin:
            game = json.load(fin)
        game = process_json(game, src_file)

        if compact:
            with open(_mirror_path(src_file, root, compact_root), 'w') as fout:
                json.dump(game, fout)

        if formatted:
            # format_json consumes a file object, so round-trip the indented
            # dump through a temporary file (closed/deleted by the `with`).
            # Use a distinct name so the `formatted` flag is never clobbered.
            with tempfile.TemporaryFile('r+') as temp_json:
                json.dump(game, temp_json, indent=4)
                temp_json.seek(0)
                formatted_text = format_json(temp_json)
            dest_file = _mirror_path(src_file, root, formatted_root)
            with open(dest_file, 'w') as fout:
                fout.write(formatted_text)


if __name__ == '__main__':
    # Local import: only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description='Post-process game JSON files into compact and/or '
                    'formatted copies.')
    parser.add_argument(
        '--src',
        default='/private/home/hengyuan/scratch/rts_data/data3',
        help='root directory containing the .json game files')
    parser.add_argument(
        '--dest',
        default='/private/home/hengyuan/scratch/rts_data/data3_processed',
        help='output directory; compact/ and formatted/ are created inside')
    parser.add_argument(
        '--no-formatted', dest='formatted', action='store_false',
        help='skip writing the formatted copies')
    parser.add_argument(
        '--no-compact', dest='compact', action='store_false',
        help='skip writing the compact copies')
    args = parser.parse_args()

    # With no arguments this matches the previous hard-coded call:
    # process_all(src, dest, True, True).
    process_all(args.src, args.dest, args.formatted, args.compact)
