import os

import hydra
import jax
import omegaconf
from train import create_logger

from compass.trainers.slowrl_validation import slowrl_validate
from compass.trainers.slowrl_validation_jobshop import slowrl_validate as slowrl_validate_jssp
from compass.utils.logger import EnsembleLogger, NeptuneLogger, TerminalLogger


def create_logger(cfg) -> EnsembleLogger:
    """Build the ensemble of loggers requested under ``cfg.logger``.

    A ``TerminalLogger`` is added when ``cfg.logger`` contains a ``terminal``
    section; that section is forwarded verbatim as keyword arguments. A
    ``NeptuneLogger`` is added when it contains a ``neptune`` section; the run
    is tagged with the algorithm name, ``"slowrl"``, the environment name and
    ``"final-exp"``, and the whole ``cfg`` is passed as ``parameters``.

    Args:
        cfg: the experiment configuration.

    Returns:
        An ``EnsembleLogger`` wrapping every configured logger (possibly none).
    """
    logger_cfg = cfg.logger
    sinks = []

    if "terminal" in logger_cfg:
        sinks.append(TerminalLogger(**logger_cfg.terminal))

    if "neptune" in logger_cfg:
        neptune_cfg = logger_cfg.neptune
        run_tags = [
            f"{cfg.algo_name}",
            "slowrl",
            f"{cfg.env_name}",
            "final-exp",
        ]
        sinks.append(
            NeptuneLogger(
                name=neptune_cfg.name,
                project=neptune_cfg.project,
                tags=run_tags,
                parameters=cfg,
            )
        )

    return EnsembleLogger(sinks)


@hydra.main(
    config_path="config",
    version_base=None,
    config_name="config_exp",
)
def run(cfg: omegaconf.DictConfig) -> None:
    """Run slow-RL validation for one experiment configuration and print metrics.

    Hydra composes ``cfg`` from ``config/config_exp``. The config is mutated in
    place before validation:

    * the run subdirectory is appended to ``cfg.checkpointing.directory`` and
      ``cfg.slowrl.checkpointing.restore_path``;
    * a negative ``cfg.slowrl.num_devices`` is replaced by the number of local
      JAX devices;
    * when a Neptune logger is configured, its run name is prefixed with
      ``"slowrl-"``.

    Raises:
        ValueError: if more devices are requested than are locally available.
    """
    # Name of the run's directory, used for logging and checkpoints.
    # NOTE(review): presumably mirrors the layout written by the training
    # script so the checkpoint can be found — verify against train.py.
    run_subdirectory = (
        f"{cfg.env_name}/{cfg.algo_name}/"
        f"bs{cfg.batch_size}_tss{cfg.training_sample_size}"
        f"_ga{cfg.optimizer.num_gradient_accumulation_steps}"
        f"_bd{cfg.behavior_dim}"
        f"_ba{cfg.behavior_amplification}_seed{cfg.seed}/"
    )

    # os.path.join only inserts a separator when the base lacks a trailing "/",
    # so "ckpt" and "ckpt/" both become "ckpt/<run_subdirectory>".
    cfg.checkpointing.directory = os.path.join(
        cfg.checkpointing.directory, run_subdirectory
    )
    cfg.slowrl.checkpointing.restore_path = os.path.join(
        cfg.slowrl.checkpointing.restore_path, run_subdirectory
    )

    behavior_dim = cfg.behavior_dim
    slowrl_cfg = cfg.slowrl

    # A negative num_devices means "use every local device".
    available_devices = len(jax.local_devices())
    if slowrl_cfg.num_devices < 0:
        slowrl_cfg.num_devices = available_devices
        print(f"Using {available_devices} available device(s).")
    elif slowrl_cfg.num_devices > available_devices:
        # Explicit check instead of `assert`, which `python -O` strips.
        raise ValueError(
            f"{slowrl_cfg.num_devices} devices requested but only "
            f"{available_devices} available."
        )

    # create_logger only builds a Neptune logger when a `neptune` section
    # exists, so only rename the run in that case; otherwise a terminal-only
    # config would crash here.
    if "neptune" in cfg.logger:
        cfg.logger.neptune.name = "slowrl-" + cfg.logger.neptune.name
    logger = create_logger(cfg)

    key = jax.random.PRNGKey(slowrl_cfg.problem_seed)

    # The job-shop environment ("jssp") has its own validation routine; every
    # other environment shares the generic one. Both take the same arguments.
    validate_fn = slowrl_validate_jssp if cfg.env_name == "jssp" else slowrl_validate
    # NOTE(review): params=None presumably makes the validator restore
    # parameters from slowrl_cfg.checkpointing.restore_path — confirm.
    metrics = validate_fn(
        random_key=key,
        cfg=slowrl_cfg,
        params=None,
        behavior_dim=behavior_dim,
        logger=logger,
    )
    print(metrics)


if __name__ == "__main__":
    # `run` takes no explicit arguments here: the hydra.main decorator
    # composes the config (plus any command-line overrides) and passes it in.
    run()
