#!/usr/bin/env python3

import json
import os
from pathlib import Path
import re
import resource
import subprocess
import sys
import time

# ANSI colour escape sequences for highlighting terminal output.
green = "\033[92m"
red = "\033[91m"
reset = "\033[0m"

# Benchmark names; each one is looked up as benchmarks/<name>.sgcl.
benchmarks = [
    "population2000",
    "population_modified2000",
    "two_populations2000",
    "cont_switchpoint",
    "mixture",
    "hmm",
]

# Per-process limits applied to every benchmarked tool.
ram_limit = 12 * 1024**3  # address-space cap in bytes (12 GiB)
stack_size = 64 * 1024**2  # soft stack limit in bytes (64 MiB)
timeout = 60 * 60  # seconds before a single tool run is killed
# Each benchmark is repeated this many times; earlier runs act as warm-up
# and only the result of the last run is kept.
num_runs = 5

# Finds the tool's self-reported time on stdout, e.g. "Total inference time: 1.25s".
# `+` rather than `*`: an empty capture would make float("") raise ValueError.
inference_time_re = re.compile(r"Total inference time: ([0-9.]+)s")
# Finds a "flags: ..." line in a benchmark source; group 1 is the rest of that line.
flags_re = re.compile(r"flags: (.*)")


def clamp_soft_limit(value, hard):
    """Return ``value`` capped at the hard limit ``hard``.

    setrlimit() rejects a soft limit above the hard limit, so a requested
    soft limit must not exceed the inherited hard limit. RLIM_INFINITY means
    "no hard limit", in which case ``value`` is returned unchanged.
    """
    if hard == resource.RLIM_INFINITY:
        return value
    return min(value, hard)


def set_limits():
    """Apply the address-space and stack limits to the current process.

    run_tool() passes this as subprocess ``preexec_fn``, so it runs in the
    child right before the benchmarked tool is executed.

    The soft limits are set to ``ram_limit`` (RLIMIT_AS) and ``stack_size``
    (RLIMIT_STACK). The inherited hard limits are left untouched, because an
    unprivileged process may not raise them. Each soft limit is clamped to its
    hard limit, so setrlimit() cannot fail with ValueError.
    """
    _, as_hard = resource.getrlimit(resource.RLIMIT_AS)
    resource.setrlimit(
        resource.RLIMIT_AS, (clamp_soft_limit(ram_limit, as_hard), as_hard)
    )
    _, stack_hard = resource.getrlimit(resource.RLIMIT_STACK)
    resource.setrlimit(
        resource.RLIMIT_STACK, (clamp_soft_limit(stack_size, stack_hard), stack_hard)
    )


def run_tool(tool, tool_command, path, flags):
    """Run one tool on one benchmark file and report how long it took.

    Args:
        tool: display name used in log messages.
        tool_command: executable (a single item) or argv prefix (a list).
            The command executed is ``tool_command + flags + [path]``.
        path: benchmark file passed as the last argument.
        flags: list of extra command-line arguments.

    Returns:
        The inference time in seconds as a float, or one of two strings.
        The float is the time the tool printed as "Total inference time: <x>s";
        if that line is missing or unparseable, the measured running time is
        used instead. The string "crashed" is returned when the tool exits with
        a non-zero code. The string "timeout" is returned when it runs longer
        than ``timeout`` seconds.

    The child process is started with ``set_limits`` as ``preexec_fn``.
    """
    if not isinstance(tool_command, list):
        tool_command = [tool_command]
    command = tool_command + flags + [path]
    print(f"Running {command}...")
    start = time.perf_counter()
    try:
        completed = subprocess.run(
            command, timeout=timeout, capture_output=True, preexec_fn=set_limits
        )
    except subprocess.TimeoutExpired:
        print(f"Timeout of {timeout}s {red}expired{reset}.")
        return "timeout"
    elapsed = time.perf_counter() - start
    # errors="replace": stray non-UTF-8 bytes from a tool must not crash the harness.
    output = (completed.stdout or b"").decode("utf-8", errors="replace")
    stderr = (completed.stderr or b"").decode("utf-8", errors="replace")
    exitcode = completed.returncode
    if exitcode != 0:
        print(
            f"Tool {tool} {red}FAILED{reset} (exit code {exitcode}) in {elapsed:.3f}s.\nStdout:\n{output}\nStderr:\n{stderr}"
        )
        return "crashed"

    inference_time = None
    m = inference_time_re.search(output)
    if m:
        try:
            inference_time = float(m.group(1))
        except ValueError:
            # The captured text is not a valid number (e.g. "1.2.3" or ".").
            inference_time = None
    if inference_time is None:
        print(f"Tool {tool} {red}did not output its total inference time{reset}. Using the total running time instead...")
        inference_time = elapsed
    print(
        f"Tool {tool} inferred {path} in {inference_time:.4f}s"
    )
    return inference_time


def genfer(benchmark):
    """Benchmark Genfer on ``benchmarks/<benchmark>.sgcl``.

    Extra command-line flags are taken from the first "flags: ..." match in
    the benchmark file, if there is one. Genfer is also passed
    ``--json exact_output/<benchmark>_genfer.json``.

    The tool runs ``num_runs`` times (at least once). Earlier runs serve as
    warm-up and only the last result is kept. The first non-float result
    ("crashed" / "timeout") stops the loop and is returned immediately.

    Returns:
        "n/a" if the benchmark file does not exist. Otherwise, the result of
        the last run_tool() call.
    """
    path = Path(f"benchmarks/{benchmark}.sgcl")
    command = "../genfer/target/release/genfer"
    if not path.is_file():
        return "n/a"
    m = flags_re.search(path.read_text(encoding="utf-8"))
    flags = m.group(1).split() if m else []
    json_output = Path(f"exact_output/{benchmark}_genfer.json")
    # Make sure the --json target directory exists.
    # NOTE(review): presumably genfer does not create it itself — confirm.
    json_output.parent.mkdir(parents=True, exist_ok=True)
    json_args = ["--json", str(json_output)]
    # At least one run, otherwise `result` would be unbound below.
    runs = max(num_runs, 1)
    result = None
    for _ in range(runs):
        result = run_tool("Genfer", command, path, flags + json_args)
        print(result)
        if not isinstance(result, float):
            return result
    print(f"Last of {runs} runs of Genfer on {path} with flags {flags} was: {result}")
    print()
    return result


if __name__ == "__main__":
    # perf_counter() is monotonic, so the total is immune to wall-clock
    # adjustments; run_tool() times individual runs the same way.
    start = time.perf_counter()
    # genfer() uses relative paths (benchmarks/, ../genfer/, exact_output/),
    # so work from the directory that contains this script.
    own_path = Path(sys.argv[0]).parent
    os.chdir(own_path)
    # benchmark name -> float time, or "n/a" / "crashed" / "timeout"
    all_results = {}
    for benchmark in benchmarks:
        print(f"Benchmarking {benchmark}")
        print("============")
        all_results[benchmark] = genfer(benchmark)
        print()
        print()
    elapsed = time.perf_counter() - start
    # Summarise every benchmark so crashes and timeouts are not lost in the log.
    for benchmark, result in all_results.items():
        if isinstance(result, float):
            print(f"{benchmark}: {result:.4f}s")
        else:
            print(f"{benchmark}: {result}")
    print(f"{green}Benchmarking finished successfully in {elapsed:.1f}s.{reset}")

