import argparse
import json
import re
import os
import glob
from tqdm import tqdm
from tabulate import tabulate
from collections import defaultdict

from rlpytorch.behavior_clone.replay import Replay
from rlpytorch.behavior_clone.utils import parse_dataset


def print_table(name, header, rows):
    """Print a titled table surrounded by blank lines.

    Args:
        name: title printed on its own line above the table.
        header: list of column titles, used as the first table row.
        rows: list of data rows appended after the header.
    """
    rendered = tabulate([header] + rows, headers='firstrow')
    # A single print with a newline separator emits: blank line, title,
    # table, blank line.
    print('', name, rendered, '', sep='\n')


def eval_game_completeness(replays):
    """Print how many games are finished, unfinished, or empty.

    Also reports the number of games and the total number of `cmd2act`
    entries across all replays.

    Args:
        replays: sized collection of replay objects exposing `is_finished()`,
            `is_unfinished()`, `is_empty()` and a `cmd2act` sequence.

    An empty collection prints 'n/a' shares instead of raising
    ZeroDivisionError.
    """
    num_games = len(replays)
    num_finished = sum(int(replay.is_finished()) for replay in replays)
    num_unfinished = sum(int(replay.is_unfinished()) for replay in replays)
    num_empty = sum(int(replay.is_empty()) for replay in replays)
    num_commands = sum(len(replay.cmd2act) for replay in replays)

    def share(count):
        # Percentage of all games; guarded so zero games does not divide by zero.
        if num_games == 0:
            return 'n/a (%d)' % count
        return '%.2f%% (%d)' % (100. * count / num_games, count)

    rows = [
        ['games', str(num_games)],
        ['commands', str(num_commands)],
        ['finished', share(num_finished)],
        ['unfinished', share(num_unfinished)],
        ['empty', share(num_empty)],
    ]
    print_table(
        name='game completeness stats',
        header=['attribute', 'value'],
        rows=rows)


def eval_global_win_rate(replays):
    """Print the overall win and loss rates relative to finished games.

    Args:
        replays: re-iterable collection of replay objects exposing
            `is_finished()`, `is_win()` and `is_lose()`.

    When no game is finished the rates are printed as 'n/a' instead of
    raising ZeroDivisionError.
    """
    num_games = sum(int(replay.is_finished()) for replay in replays)
    num_wins = sum(int(replay.is_win()) for replay in replays)
    num_loses = sum(int(replay.is_lose()) for replay in replays)

    def share(count):
        # Percentage of finished games; guarded against a zero denominator.
        if num_games == 0:
            return 'n/a (%d)' % count
        return '%.2f%% (%d)' % (100. * count / num_games, count)

    rows = [
        ['games', str(num_games)],
        ['wins', share(num_wins)],
        ['loses', share(num_loses)],
    ]
    print_table(
        name='global win rate',
        header=['attribute', 'value'],
        rows=rows)


def eval_win_rate(replays):
    """Print the win rate against each opponent, over finished games only.

    Args:
        replays: iterable of replay objects exposing `is_finished()`,
            `is_win()` and an `opponent` attribute.

    Raises:
        AssertionError: if a finished replay has `opponent` set to None.
    """
    # Both tallies are keyed by opponent. Every opponent present in
    # `num_games` has at least one game, so the division below is safe.
    num_wins = defaultdict(int)
    num_games = defaultdict(int)
    for replay in replays:
        if not replay.is_finished():
            continue
        assert replay.opponent is not None
        num_games[replay.opponent] += 1
        if replay.is_win():
            num_wins[replay.opponent] += 1

    rows = []
    for opponent, games in num_games.items():
        wins = num_wins[opponent]
        rows.append([opponent, '%.2f%% (%d/%d)' % (100. * wins / games, wins, games)])

    print_table(
        name='per opponent win rate',
        header=['opponent', 'value'],
        rows=rows)


def eval_avg_unit_presence(replays):
    """Print, per built unit type, the share of non-empty games containing it.

    A unit type counts once per game if any action in that game's `cmd2act`
    entries has `act_type == 'build'` with that value in `extra`.

    Args:
        replays: iterable of replay objects exposing `is_empty()` and
            `cmd2act`, whose entries are 3-tuples with a list of actions in
            the middle position.
    """
    games_seen = 0
    games_with_unit = defaultdict(int)
    for replay in replays:
        if replay.is_empty():
            continue
        games_seen += 1
        # Distinct unit types built at least once in this game.
        built = {
            act.extra
            for (_, acts, _) in replay.cmd2act
            for act in acts
            if act.act_type == 'build'
        }
        for unit_type in built:
            games_with_unit[unit_type] += 1

    rows = [
        [unit_type, '%.2f%% (%d/%d)' % (100. * hits / games_seen, hits, games_seen)]
        for unit_type, hits in games_with_unit.items()
    ]

    print_table(
        name='build unit appearances',
        header=['unit type', 'value'],
        rows=rows)


def eval_avg_unit_count(replays):
    """Print the average number of 'build' actions per non-empty game, by unit type.

    Args:
        replays: iterable of replay objects exposing `is_empty()` and
            `cmd2act`, whose entries are 3-tuples with a list of actions in
            the middle position.
    """
    games_seen = 0
    builds_by_unit = defaultdict(int)
    for replay in replays:
        if replay.is_empty():
            continue
        games_seen += 1
        for (_, acts, _) in replay.cmd2act:
            for act in acts:
                if act.act_type != 'build':
                    continue
                builds_by_unit[act.extra] += 1

    rows = [
        [unit_type, '%.2f' % (builds / games_seen)]
        for unit_type, builds in builds_by_unit.items()
    ]

    print_table(
        name='average build per game',
        header=['unit type', 'value'],
        rows=rows)


def eval_command_execution(replays):
    """Print aggregate action, instruction-coverage and warning statistics.

    Sums the per-replay `num_correct`, `num_incorrect`, `num_actions`,
    `num_covered` and `num_issued` counters over all replays. Warning
    statistics are computed over non-empty replays only.

    Args:
        replays: re-iterable collection of replay objects exposing the
            counters above plus `is_empty()` and a numeric `warnings`
            attribute.

    Any ratio whose denominator is zero is printed as 'n/a' instead of
    raising ZeroDivisionError.
    """
    num_correct = sum(replay.num_correct for replay in replays)
    num_incorrect = sum(replay.num_incorrect for replay in replays)
    num_actions = sum(replay.num_actions for replay in replays)

    num_covered = sum(replay.num_covered for replay in replays)
    num_issued = sum(replay.num_issued for replay in replays)

    num_games = 0
    games_with_warnings = 0
    total_num_warnings = 0
    for replay in replays:
        if replay.is_empty():
            continue
        num_games += 1
        if replay.warnings > 0:
            games_with_warnings += 1
        total_num_warnings += replay.warnings

    def percent(part, whole):
        # Format part/whole as a percentage, guarded against a zero denominator.
        if whole == 0:
            return 'n/a (%d/%d)' % (part, whole)
        return '%.2f%% (%d/%d)' % (100. * part / whole, part, whole)

    if num_games == 0:
        avg_warnings = 'n/a (%d/%d)' % (total_num_warnings, num_games)
    else:
        avg_warnings = '%.2f (%d/%d)' % (
            total_num_warnings / num_games, total_num_warnings, num_games)

    rows = [
        ['correct actions', percent(num_correct, num_actions)],
        ['incorrect actions', percent(num_incorrect, num_actions)],
        ['covered instructions', percent(num_covered, num_issued)],
        ['games with warnings', percent(games_with_warnings, num_games)],
        ['avg num warnings per game', avg_warnings],
    ]
    print_table(
        name='instructions/commands stats',
        header=['statistic', 'value'],
        rows=rows)


def compute_ngrams(s, n):
    """Return every run of `n` consecutive words in `s`, in order.

    Words are separated by any run of whitespace, so leading, trailing or
    repeated spaces never produce empty-string tokens.

    Args:
        s: text to tokenize.
        n: number of words per n-gram.

    Returns:
        List of n-grams, each with its words joined by a single space. The
        list is empty when `s` has fewer than `n` words.
    """
    words = s.split()
    return [' '.join(words[i:i + n]) for i in range(len(words) - n + 1)]


def eval_command_ngrams(replays, n, top_k=20):
    """Print the most frequent word n-grams across all command strings.

    Args:
        replays: iterable of replay objects whose `cmd2act` entries are
            3-tuples with the command text in the last position.
        n: number of words per n-gram.
        top_k: maximum number of n-grams to list; defaults to 20.
    """
    counts = defaultdict(int)
    for replay in replays:
        for (_, _, cmd) in replay.cmd2act:
            for ngram in compute_ngrams(cmd, n):
                counts[ngram] += 1

    # Sorting on the negated count is stable, so n-grams with equal counts
    # keep the order in which they were first seen.
    ranked = sorted(counts.items(), key=lambda item: -item[1])[:top_k]
    rows = [[ngram, count] for ngram, count in ranked]

    print_table(
        name='most common %dgrams' % n,
        header=['%dgram' % n, 'count'],
        rows=rows)


def eval_most_common_commands(replays, n=100):
    """Print the number of distinct commands and the `n` most frequent ones.

    Commands are compared after stripping surrounding whitespace.

    Args:
        replays: iterable of replay objects whose `cmd2act` entries are
            3-tuples with the command text in the last position.
        n: maximum number of commands to list.
    """
    tally = defaultdict(int)
    for replay in replays:
        for (_, _, cmd) in replay.cmd2act:
            tally[cmd.strip()] += 1

    # Stable sort: commands with equal counts stay in first-seen order.
    top = sorted(tally, key=lambda cmd: -tally[cmd])[:n]

    # Every distinct stripped command is a key of the tally, so its size is
    # the number of unique commands.
    rows = [['total unique', len(tally)]]
    rows.extend([cmd, tally[cmd]] for cmd in top)

    print_table(
        name='most common commands',
        header=['command', 'count'],
        rows=rows)


def eval_most_common_commands_by_coach(replays, n=100):
    """Print the `n` most frequent (coach, command) pairs.

    Commands are compared after stripping surrounding whitespace.

    Args:
        replays: iterable of replay objects exposing `coach_id` and a
            `cmd2act` sequence of 3-tuples with the command text in the last
            position.
        n: maximum number of pairs to list.
    """
    # Tuple keys keep the coach id and command separate. Joining them with
    # '_' and splitting later truncates commands that contain an underscore
    # and can merge distinct pairs into one key.
    counts = defaultdict(int)
    for replay in replays:
        coach_id = str(replay.coach_id)
        for (_, _, cmd) in replay.cmd2act:
            counts[(coach_id, cmd.strip())] += 1

    # Stable sort: pairs with equal counts stay in first-seen order.
    top = sorted(counts, key=lambda pair: -counts[pair])[:n]
    rows = [[coach_id, cmd, counts[(coach_id, cmd)]] for coach_id, cmd in top]

    print_table(
        name='most common commands by coach',
        header=['coach id', 'command', 'count'],
        rows=rows)


def filter_coach(replay_paths, target_coach_id):
    """Keep the paths whose third '_'-separated field equals `target_coach_id`.

    The whole path string is split on '_', not just the file name.

    Args:
        replay_paths: iterable of replay path strings.
        target_coach_id: coach id string to match.

    Returns:
        List of matching paths, in their original order.

    Raises:
        IndexError: if a path has fewer than three '_'-separated fields.
    """
    return [
        path for path in replay_paths
        if path.split('_')[2] == target_coach_id
    ]

def filter_player(replay_paths, target_player_id):
    """Keep the paths whose second '_'-separated field equals `target_player_id`.

    The whole path string is split on '_', not just the file name.

    Args:
        replay_paths: iterable of replay path strings.
        target_player_id: player id string to match.

    Returns:
        List of matching paths, in their original order.

    Raises:
        IndexError: if a path contains no '_' at all.
    """
    return [
        path for path in replay_paths
        if path.split('_')[1] == target_player_id
    ]




def main():
    """Parse command-line options, load the replays and print the reports.

    Replay files (`*.rep`) are discovered recursively under `--dataset-root`.
    They can be restricted to one coach (`--coach-id`) or one player
    (`--player-id`), but not both. The n-gram and most-common-command reports
    are skipped when filtering by player.
    """
    parser = argparse.ArgumentParser(description='dataset evaluator')
    parser.add_argument('--dataset-root', type=str, default='train.json')
    parser.add_argument('--coach-id', default=None, dest='coach_id', type=str)
    parser.add_argument('--player-id', default=None, dest='player_id', type=str)
    args = parser.parse_args()

    # Reported through argparse rather than `assert`, so the check survives
    # `python -O` and the user gets a usage message instead of a traceback.
    if args.coach_id is not None and args.player_id is not None:
        parser.error('--coach-id and --player-id cannot be used together')

    replay_paths = glob.glob(os.path.join(args.dataset_root, '**', '*.rep'), recursive=True)
    if args.coach_id is not None:
        replay_paths = filter_coach(replay_paths, args.coach_id)
    if args.player_id is not None:
        replay_paths = filter_player(replay_paths, args.player_id)

    if len(replay_paths) == 0:
        print('no recorded games')
        return

    replays = parse_dataset(replay_paths)

    # Sanity check: the ids parsed from each replay must agree with the ids
    # used to select its path.
    if args.coach_id is not None:
        for replay in replays:
            assert replay.coach_id == args.coach_id
    if args.player_id is not None:
        for replay in replays:
            assert replay.player_id == args.player_id

    # Reports printed for every selection.
    eval_game_completeness(replays)
    eval_global_win_rate(replays)
    eval_win_rate(replays)
    eval_avg_unit_presence(replays)
    eval_avg_unit_count(replays)
    eval_command_execution(replays)

    # Command-text reports are printed for the full dataset and for a single
    # coach, but not for a single player.
    if args.player_id is None:
        for i in [1, 2, 3]:
            eval_command_ngrams(replays, i)
        eval_most_common_commands(replays)
        eval_most_common_commands_by_coach(replays)


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()

