import argparse
import datetime, time, subprocess
import os, signal, psutil

# NOTE(review): SLEEPTIME is not referenced anywhere in this file; the
# scheduling loop in main() uses SLEEP_TIME below. Probably a leftover —
# confirm no other module imports it before removing.
SLEEPTIME = 60
# Models run on every endpoint. main() keeps one worker slot per
# (endpoint, model) pair, so each entry here multiplies the number of
# concurrent workers. Disabled alternatives are kept in the trailing comment.
MODEL_NAME = ("gpt-4-1106-base",)  # "gpt-4-0613-base", "gpt-4-32k-0613-base")
# Azure OpenAI base URLs. Each URL string is passed verbatim to the worker
# as its --endpoint argument. Commented-out entries are currently disabled.
API_BASE = (
    "https://3-openai-australiaeast.openai.azure.com/",
    "https://3-openai-canadaeast.openai.azure.com/",
    "https://3-openai-francecentral.openai.azure.com/",
    "https://3-openai-swedencentral.openai.azure.com/",
    # "https://3-openai-switzerlandnorth.openai.azure.com/",
    "https://5-openai-australiaeast.openai.azure.com/",
    "https://5-openai-canadaeast.openai.azure.com/",
    "https://5-openai-francecentral.openai.azure.com/",
    "https://5-openai-swedencentral.openai.azure.com/",
    # "https://5-openai-switzerlandnorth.openai.azure.com/",
)

# Seconds main() sleeps between scheduling rounds; it also pauses
# SLEEP_TIME / 20 seconds between consecutive worker launches.
SLEEP_TIME = 20


def submit_job_to_endpoint(dataset_name, chunk_index, endpoint_index, model_name):
    """Start ``./gpt_ranking/worker.py`` for one chunk as a background process.

    The worker path is relative to the current working directory, so the
    dispatcher must be launched from the directory that contains
    ``gpt_ranking/``.

    Args:
        dataset_name: Forwarded as ``--dataset``.
        chunk_index: Forwarded (stringified) as ``--chunk_id``.
        endpoint_index: Forwarded (stringified) as ``--endpoint``.
            NOTE(review): main() passes an API base URL here despite the
            parameter name; confirm which form worker.py expects.
        model_name: Forwarded (stringified) as ``--model``.

    Returns:
        The ``subprocess.Popen`` handle of the worker. This call does not
        wait for the worker to finish.
    """
    import sys

    # Run the worker under the same interpreter as this dispatcher. A bare
    # "python" is resolved via PATH and may be a different interpreter or
    # virtualenv (or may not exist at all).
    command = [
        sys.executable,
        "./gpt_ranking/worker.py",
        "--dataset",
        dataset_name,
        "--chunk_id",
        str(chunk_index),
        "--endpoint",
        str(endpoint_index),
        "--model",
        str(model_name),
    ]
    return subprocess.Popen(command)


def wait_child(signum, frame):
    """SIGCHLD handler that reaps every child process that has already exited.

    It repeatedly calls non-blocking ``os.waitpid(-1, os.WNOHANG)`` so that
    exited workers do not linger as zombies. Their exit statuses are
    discarded.

    Args:
        signum: Signal number (unused; required by the handler signature).
        frame: Interrupted stack frame (unused; required by the handler
            signature).
    """
    while True:
        try:
            cpid, _ = os.waitpid(-1, os.WNOHANG)
        except ChildProcessError:
            # ECHILD: this process has no children left to wait for.
            return
        if cpid == 0:
            # Children exist, but none has exited yet.
            return


def _timestamp():
    """Return the current local time formatted as ``YYYY-mm-dd HH:MM:SS``."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def _log(message):
    """Print ``message`` prefixed with a bracketed local timestamp."""
    print(f"[{_timestamp()}] {message}")


def _release_finished_slots(endpoint_status):
    """Reset to ``None`` every (endpoint, model) slot whose worker has exited."""
    for endpoint, endpoint_models in endpoint_status.items():
        for model, task in endpoint_models.items():
            # Popen.poll() asks about this specific child. A recycled PID that
            # now belongs to an unrelated process cannot make a finished
            # worker look alive, which psutil.pid_exists() could. If the
            # SIGCHLD handler already reaped the child, poll() still reports
            # it as exited.
            if task is None or task["proc"].poll() is None:
                continue
            _log(
                f'Info: Chunk {task["chunk_id"]} on ({endpoint}, {model}) finished.'
            )
            # Re-binding an existing key does not resize the dict, so this
            # is safe while iterating over items().
            endpoint_models[model] = None


def main(dataset_name, start_chunk, end_chunk):
    """Process chunks ``start_chunk`` .. ``end_chunk - 1`` with worker processes.

    Every (endpoint in API_BASE, model in MODEL_NAME) pair is a slot that
    runs at most one worker at a time. Each round, the loop:

    1. frees slots whose workers have exited;
    2. hands pending chunks to idle slots in ascending chunk order;
    3. sleeps.

    It returns once no chunks are pending and every slot is idle. As a side
    effect it installs ``wait_child`` as the process-wide SIGCHLD handler.

    Args:
        dataset_name: Forwarded to each worker.
        start_chunk: First chunk index to process (inclusive).
        end_chunk: Chunk index to stop at (exclusive).
    """
    signal.signal(signal.SIGCHLD, wait_child)

    endpoint_status = {n: {m: None for m in MODEL_NAME} for n in API_BASE}
    remaining_tasks = list(range(start_chunk, end_chunk))
    while True:
        _release_finished_slots(endpoint_status)

        idle_endpoint_model = [
            (endpoint, model)
            for endpoint, endpoint_models in endpoint_status.items()
            for model, task in endpoint_models.items()
            if task is None
        ]

        if not remaining_tasks:
            all_idle = all(
                task is None
                for endpoint_models in endpoint_status.values()
                for task in endpoint_models.values()
            )
            if all_idle:
                _log("Info: No remaining tasks existing.")
                break
            # Nothing left to submit, but workers are still running. Without
            # this sleep the loop would re-poll immediately and spin at 100%
            # CPU until the last worker exits.
            time.sleep(SLEEP_TIME)
            continue

        _log(f"Info: Found {len(remaining_tasks)} remaining tasks.")
        if not idle_endpoint_model:
            _log(f"Info: Too many jobs running. Wait for {SLEEP_TIME} seconds.")
            time.sleep(SLEEP_TIME)
            continue

        counter = 0
        # The slice is a copy, so removing entries from remaining_tasks inside
        # the loop does not disturb the iteration.
        for chunk_index, (endpoint, model) in zip(
            remaining_tasks[: len(idle_endpoint_model)], idle_endpoint_model
        ):
            _log(f"Info: Submitting Chunk {chunk_index} to {endpoint}, {model}.")
            proc = submit_job_to_endpoint(dataset_name, chunk_index, endpoint, model)
            endpoint_status[endpoint][model] = {
                "chunk_id": chunk_index,
                "proc": proc,
            }
            remaining_tasks.remove(chunk_index)
            counter += 1
            # Short pause (SLEEP_TIME / 20 seconds) between consecutive launches.
            time.sleep(SLEEP_TIME / 20)
        if counter > 0:
            # NOTE(review): these are local subprocesses, not Slurm jobs; the
            # message wording is kept unchanged for log compatibility.
            _log(f"Info: Submitted {counter} slurm jobs.")
        time.sleep(SLEEP_TIME)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "--dataset",
        type=str,
        default="GSM8K",
        help="Name of the Dataset",
    )
    argparser.add_argument(
        "--start_chunk",
        type=int,
        default=0,
        help="Name of the Dataset",
    )
    argparser.add_argument(
        "--end_chunk",
        type=int,
        default=20,
        help="Name of the Dataset",
    )
    args = argparser.parse_args()

    main(args.dataset, args.start_chunk, args.end_chunk)
