import os
import json
from collections import defaultdict
import pprint
from tqdm import tqdm

import matplotlib.pyplot
matplotlib.pyplot.switch_backend('agg')
import matplotlib.pyplot as plt

import utils


def get_num_targets(frames):
    """Return the total number of targets over all frames but the last entry.

    ``utils.remove_duplicated_targets`` is run on ``frames`` first
    (presumably de-duplicating in place -- confirm against ``utils``).
    Entries that are ``None`` or lack a ``'targets'`` key contribute nothing.
    The final element of ``frames`` is never inspected; other functions in
    this file read it as the game reward.
    """
    utils.remove_duplicated_targets(frames)
    return sum(
        len(frame['targets'])
        for frame in frames[:-1]
        if frame is not None and 'targets' in frame
    )


def num_targets_to_winrate(files, max_target=200):
    """Tabulate and plot game outcomes bucketed by number of targets.

    Args:
        files: iterable of paths to replay JSON files. Each file holds a list
            whose last element is the game reward.
        max_target: replays with more targets than this are skipped.

    Side effects:
        Pretty-prints the ``{num_targets: {reward: count}}`` table and hands
        it to ``plot`` with the file name 'num_target_to_winrate.png'.
    """
    stat = defaultdict(lambda: defaultdict(int))

    for f in tqdm(files):
        # Close each replay as soon as it is parsed; a bare open() leaked one
        # file handle per replay.
        with open(f, 'r') as fin:
            frames = json.load(fin)

        num_targets = get_num_targets(frames)
        if num_targets > max_target:
            continue

        reward = int(frames[-1])
        stat[num_targets][reward] += 1

    pprint.pprint(stat)
    plot(stat, '#targets', 'num_target_to_winrate.png')


def get_instructions(frames):
    """Return the sequence of instructions with consecutive repeats collapsed.

    Every element of ``frames`` except the last is examined; ``None`` entries
    are skipped. An instruction is recorded only when it differs from the one
    recorded immediately before it, so the same instruction may appear again
    later in the result if something else came in between.
    """
    distinct = []
    for frame in frames[:-1]:
        if frame is not None:
            current = frame['instruction']
            # Record only on change (or for the very first instruction).
            if not distinct or distinct[-1] != current:
                distinct.append(current)

    return distinct


def num_inst_to_winrate(files, max_inst=35):
    """Tabulate and plot game outcomes bucketed by number of instructions.

    Args:
        files: iterable of paths to replay JSON files. Each file holds a list
            whose last element is the game reward.
        max_inst: instruction counts above this value are clamped into the
            ``max_inst`` bucket.

    Side effects:
        Pretty-prints the ``{num_inst: {reward: count}}`` table and hands it
        to ``plot`` with the file name 'num_inst_to_winrate.png'.
    """
    stat = defaultdict(lambda: defaultdict(int))
    for f in tqdm(files):
        # Close each replay as soon as it is parsed; a bare open() leaked one
        # file handle per replay.
        with open(f, 'r') as fin:
            frames = json.load(fin)

        insts = get_instructions(frames)
        reward = int(frames[-1])
        num_inst = min(max_inst, len(insts))
        stat[num_inst][reward] += 1

    pprint.pprint(stat)
    plot(stat, '#instructions', 'num_inst_to_winrate.png')


def plot(stat, x_label, figname):
    """Plot win rate and number of games against the keys of ``stat``.

    Args:
        stat: mapping from an x value to a mapping of reward -> game count.
            Only rewards ``1`` (win) and ``-1`` (loss) are counted.
        x_label: label for the shared x-axis.
        figname: path the figure is saved to.

    The win rate is drawn in red on the left y-axis and the game count in
    blue on a second y-axis that shares the same x-axis.
    """
    x = sorted(stat.keys())
    win_rate = []
    num_games = []
    for key in x:
        outcomes = stat[key]
        # .get() avoids inserting empty entries into a caller's defaultdict.
        win = outcomes.get(1, 0)
        loss = outcomes.get(-1, 0)
        total = win + loss
        # A bucket with no wins or losses has no defined win rate; NaN leaves
        # a gap in the curve instead of raising ZeroDivisionError.
        win_rate.append(win / total if total else float('nan'))
        num_games.append(total)

    fig, ax1 = plt.subplots()
    fig.set_size_inches(10, 10)
    color = 'tab:red'
    ax1.set_xlabel(x_label)
    ax1.set_ylabel('win rate', color=color)
    ax1.plot(x, win_rate, color=color)
    ax1.tick_params(axis='y', labelcolor=color)

    ax2 = ax1.twinx()  # second y-axis sharing the same x-axis

    color = 'tab:blue'
    ax2.set_ylabel('num games', color=color)  # x-label already set on ax1
    ax2.plot(x, num_games, color=color)
    ax2.tick_params(axis='y', labelcolor=color)

    fig.tight_layout()
    fig.savefig(figname)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def filter_bad_replays(files, min_insts, min_targets):
    """Split replay files into those that meet both minimums and the rest.

    Args:
        files: iterable of paths to replay JSON files.
        min_insts: minimum number of instructions, as counted by
            ``get_instructions``, for a replay to be kept.
        min_targets: minimum number of targets, as counted by
            ``get_num_targets``, for a replay to be kept.

    Returns:
        ``(filtered_files, removed_files)``: paths that satisfy both
        thresholds, and paths that fail at least one, each in input order.
    """
    filtered_files = []
    removed_files = []
    for f in tqdm(files):
        # Close each replay as soon as it is parsed; a bare open() leaked one
        # file handle per replay.
        with open(f, 'r') as fin:
            frames = json.load(fin)

        insts = get_instructions(frames)
        num_targets = get_num_targets(frames)
        if len(insts) >= min_insts and num_targets >= min_targets:
            filtered_files.append(f)
        else:
            removed_files.append(f)

    return filtered_files, removed_files


def get_bad_replays(files):
    """Report replays that contain an empty instruction or a frame without units.

    Args:
        files: iterable of paths to replay JSON files.

    Returns:
        List of the paths that had at least one offending frame, in input
        order. A diagnostic line is also printed for every offending frame.
    """
    bad_files = []
    for f in tqdm(files):
        with open(f, 'r') as fin:
            frames = json.load(fin)

        is_bad = False
        for e in frames:
            # Skip missing frames and the numeric reward entry (JSON may
            # decode it as either int or float).
            if e is None or isinstance(e, (int, float)):
                continue

            # Strip the instruction text itself. The original stripped the
            # key literal, so whitespace-only instructions were not caught.
            if len(e['instruction'].strip()) == 0:
                print('empty instruction:', f)
                is_bad = True

            if len(e['my_units']) == 0:
                print('no units:', f)
                is_bad = True

        if is_bad:
            bad_files.append(f)

    return bad_files


if __name__ == '__main__':
    root = '/private/home/hengyuan/scratch/rts_data/data1'
    files = utils.get_all_files(root, '.json')

    # Alternative analyses; uncomment to run.
    # get_bad_replays(files)

    # num_inst_to_winrate(files)
    # num_targets_to_winrate(files)

    filtered_files, removed_files = filter_bad_replays(files, 5, 30)
    print('before filtering:', len(files))
    print('after filtering:', len(filtered_files))

    # Map each removed data file back to the path of its source replay.
    removed_replays = []
    for f in removed_files:
        f = f.replace(root, '/private/home/hengyuan/rts-replays/replays')
        f = f.replace('.p0.json', '')
        removed_replays.append(f)

    # Use a context manager so the output is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open('removed_replays.txt', 'w') as fout:
        json.dump(removed_replays, fout, indent=4)
