"""
Filter units sampled via sample_units.py to only include units for which the stimuli
generation was successful. It also re-orders the units to match the order used by an
earlier version of the create_task_structure_json.py script.
"""

import argparse
import glob
import json
import os


def _unit_name_from_images_dir(images_dir):
    """Return the ``<layer>__<channel>`` unit name for one stimuli directory.

    Args:
        images_dir: Path shaped like
            ``.../<layer>/channel_<channel>/<condition>_images``, i.e. one
            match of the glob pattern built in ``main``.

    Returns:
        The unit name, e.g. ``"layer1__12"`` for
        ``.../layer1/channel_12/natural_images``.
    """
    channel_dir = os.path.dirname(images_dir)
    layer_name = os.path.basename(os.path.dirname(channel_dir))
    # "channel_<id>" -> "<id>"; the "channel_*" glob guarantees the underscore.
    channel_id = os.path.basename(channel_dir).split("_")[1]
    return f"{layer_name}__{channel_id}"


def main():
    """Filter the sampled units down to those whose stimuli exist on disk.

    Reads the ``"units"`` entry of the ``--units`` json file, globs the stimuli
    folder for ``<model>[_<extra_name>]/*/channel_*/<condition>_images``
    directories, and writes ``{"units": [...]}`` to ``--output`` containing the
    names of the units that are both on disk and in the sampled list.

    The output follows the order in which ``glob`` returned the directories,
    not the order of the sampled list.

    Exits with an argparse error (status 2) when ``--output`` and ``--units``
    point to the same file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--units", required=True, help="Path to units json file.")
    parser.add_argument(
        "-s",
        "--source-folder",
        required=True,
        help="Path to source stimuli (/path/to/stimuli/).",
    )
    parser.add_argument(
        "-o",
        "--output",
        required=True,
        help="Where to save cleaned up units json file.",
    )
    parser.add_argument(
        "-c",
        "--condition",
        required=True,
        type=str,
        choices=["natural", "optimized"],
        help="Which condition to use for trials.",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=True,
        type=str,
        help="Which model to select from the stimuli folder.",
    )
    parser.add_argument(
        "--extra_name",
        type=str,
        default="",
        help="Extra name to be added to the model when finding stimuli.",
    )
    args = parser.parse_args()

    # Explicit check instead of ``assert`` so the guard survives ``python -O``.
    # Absolute paths are compared so "./a.json" and "a.json" count as the same.
    if os.path.abspath(args.output) == os.path.abspath(args.units):
        parser.error("Output and input file must be different.")

    model_folder = args.model
    if args.extra_name:
        model_folder = f"{args.model}_{args.extra_name}"

    stimuli_dirs = glob.glob(
        os.path.join(
            args.source_folder,
            model_folder,
            "*",
            "channel_*",
            f"{args.condition}_images",
        )
    )

    # Load json file
    with open(args.units, "r", encoding="utf-8") as f:
        sampled_units = json.load(f)["units"]
    n_original_units = len(sampled_units)
    # Build the lookup once; unit names are assumed to be strings.
    sampled_lookup = frozenset(sampled_units)

    # Filter units. We deliberately walk the on-disk directories (in glob
    # order) and keep the sampled ones, rather than walking the sampled list,
    # to stay compatible with the ordering of an older version of the
    # create_task_structure_json.py script.
    available_units = []
    for stimuli_dir in stimuli_dirs:
        unit_name = _unit_name_from_images_dir(stimuli_dir)
        if unit_name in sampled_lookup:
            available_units.append(unit_name)
        else:
            print("deleted")

    n_available_units = len(available_units)

    print(f"Reduced {n_original_units} units to {n_available_units} units.")

    units = {"units": available_units}

    # dump the filtered units to json-file
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(units, f)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
