import logging
import os
import re
from dataclasses import dataclass
from typing import Union, Sequence

import torch
import transformers
from accelerate import DistributedType, Accelerator

from rtfm.arguments import TrainingArguments, DataArguments, ModelArguments
from rtfm.train_utils import OPTIMIZER_STATE_PT, SCHEDULER_STATE_PT, SCALER_STATE_PT


def init_trackers_if_main_process(
    accelerator: Accelerator,
    training_arguments: TrainingArguments,
    model_arguments: ModelArguments,
    data_arguments: DataArguments,
    tags: Union[str, Sequence[str]],
):
    """Set up the Accelerator's trackers, but only on the main process.

    This is safe to call from every process: ``accelerator.init_trackers`` is
    invoked only when "wandb" appears in ``training_arguments.report_to`` and
    the current process is the main one. In every other case the call does
    nothing, which prevents duplicate logging across processes.

    Args:
        accelerator: the Accelerator whose trackers are initialized.
        training_arguments: provides ``report_to``, ``wandb_project_name`` and
            ``run_name``; its attributes are also included in the logged config.
        model_arguments: its attributes are included in the logged config.
        data_arguments: its attributes are included in the logged config.
        tags: tag (or tags) forwarded to the wandb tracker.
    """
    if "wandb" not in training_arguments.report_to:
        return
    if not accelerator.is_main_process:
        return

    # Later updates win on key collisions (training < data < model).
    run_config = dict(training_arguments.__dict__)
    run_config.update(data_arguments.__dict__)
    run_config.update(model_arguments.__dict__)

    # An empty/unset run name is passed on as None.
    wandb_kwargs = {
        "id": training_arguments.run_name or None,
        "tags": tags,
    }
    accelerator.init_trackers(
        project_name=training_arguments.wandb_project_name,
        config=run_config,
        init_kwargs={"wandb": wandb_kwargs},
    )
    return


def fetch_auth_token() -> Union[str, None]:
    """Return the auth token found in the environment, or None.

    ``HF_TOKEN`` is consulted first, then ``HUGGING_FACE_HUB_TOKEN``. A
    variable that is unset or set to an empty string is skipped; if neither
    holds a non-empty value the result is None.
    """
    env_values = (
        os.environ.get(var_name)
        for var_name in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN")
    )
    return next((value for value in env_values if value), None)


def prepare_for_potentially_distributed_training(
    accelerator, model, optimizer, lr_scheduler, train_dataloader
):
    """Pass the training objects through ``accelerator.prepare``.

    When the accelerator's distributed type is FSDP, only the model, optimizer
    and scheduler are prepared and the dataloader is returned as given;
    otherwise all four objects are prepared.

    Returns:
        The tuple ``(model, optimizer, lr_scheduler, train_dataloader)``.
    """
    is_fsdp = accelerator.distributed_type == DistributedType.FSDP
    accelerator.print(
        f"preparing model/optimizer/scheduler/dataloader for training; fsdp is {is_fsdp}"
    )

    if not is_fsdp:
        prepared = accelerator.prepare(
            model, optimizer, lr_scheduler, train_dataloader
        )
        model, optimizer, lr_scheduler, train_dataloader = prepared
        return model, optimizer, lr_scheduler, train_dataloader

    # FSDP: prepare model/optimizer/scheduler only; the DataLoader is
    # deliberately left unprepared;
    # see https://github.com/huggingface/transformers/issues/26548
    prepared = accelerator.prepare(model, optimizer, lr_scheduler)
    model, optimizer, lr_scheduler = prepared
    return model, optimizer, lr_scheduler, train_dataloader


def load_pretrained(
    model_name_or_path: str,
    training_args: TrainingArguments,
    model_cls=transformers.AutoModelForCausalLM,
):
    """Load a model via ``model_cls.from_pretrained`` for CPU or CUDA.

    The dtype is bfloat16 when ``training_args.bf16`` is set and float32
    otherwise. Without CUDA the model is loaded with ``low_cpu_mem_usage=True``.
    With CUDA it is loaded inside a ``torch.device("cuda")`` context, and
    flash attention 2 is requested when
    ``training_args.use_flash_attention_2`` is set. Gradient checkpointing is
    enabled afterwards if ``training_args.gradient_checkpointing`` is set.

    Args:
        model_name_or_path: identifier or path handed to ``from_pretrained``.
        training_args: supplies ``bf16``, ``use_flash_attention_2`` and
            ``gradient_checkpointing``.
        model_cls: class providing ``from_pretrained``.

    Returns:
        The loaded model.
    """
    cuda_available = torch.cuda.is_available()
    dtype = torch.bfloat16 if training_args.bf16 else torch.float32

    # Keyword arguments common to both the CPU and the CUDA code path.
    shared_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": dtype,
        "cache_dir": None,
    }

    if cuda_available:
        if training_args.use_flash_attention_2:
            attention_impl = "flash_attention_2"
        else:
            attention_impl = None
        # TODO(jpgard): do we need this context manager at all for accelerate?
        with torch.device("cuda"):
            model = model_cls.from_pretrained(
                model_name_or_path,
                attn_implementation=attention_impl,
                **shared_kwargs,
            )
    else:
        model = model_cls.from_pretrained(
            model_name_or_path,
            low_cpu_mem_usage=True,
            **shared_kwargs,
        )

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    return model


@dataclass
class ResumeCheckpointInfo:
    """Represents the info needed for resuming from a checkpoint.

    Instances are produced by parse_resume_from_checkpoint(), which fills the
    fields from a checkpoint directory whose path contains ``step_<N>``.
    """

    # Checkpoint directory itself.
    model_name_or_path: str  # path compatible with loading via .from_pretrained()
    # Path of the optimizer state file inside the checkpoint directory.
    optimizer_state: str
    # Path of the LR scheduler state file inside the checkpoint directory.
    scheduler_state: str
    # Path of the scaler state file, or None when the checkpoint has no such
    # file.
    scaler_state: Union[str, None]
    # Step number parsed from the ``step_<N>`` part of the checkpoint path.
    step: int


def parse_resume_from_checkpoint(
    training_arguments: TrainingArguments,
) -> ResumeCheckpointInfo:
    """Resolve the checkpoint to resume from and the state files inside it.

    ``training_arguments.resume_from_checkpoint`` selects the checkpoint:

    * a bool, ``"True"`` or ``""``: the most recently changed ``step_<N>``
      subdirectory of ``training_arguments.output_dir`` is used (ordering by
      ``os.path.getctime``). NOTE(review): any bool triggers this branch,
      including ``False``; presumably callers only call this function when
      resuming is requested -- confirm.
    * any other string: it is used as the checkpoint directory path.

    The step is parsed from ``step_<N>`` in the last path component, falling
    back to the first ``step_<N>`` anywhere in the path. This way a parent
    directory whose name happens to contain ``step_`` does not shadow the
    checkpoint's own step.

    Returns:
        A ResumeCheckpointInfo with the checkpoint directory, the optimizer
        and scheduler state paths, the scaler state path (None if that file
        does not exist; a warning is logged) and the parsed step.

    Raises:
        AssertionError: if ``output_dir`` or the requested checkpoint does not
            exist, no ``step_<N>`` directory is found, no step can be parsed
            from the path, or the optimizer/scheduler state file is missing.
        ValueError: if ``resume_from_checkpoint`` is neither a bool nor a str.
    """
    output_dir = training_arguments.output_dir
    assert os.path.exists(
        output_dir
    ), f"cannot resume from directory {output_dir}, does not exist"

    resume_from = training_arguments.resume_from_checkpoint
    step_pattern = re.compile(r"step_(\d+)")

    if isinstance(resume_from, bool) or resume_from == "True" or resume_from == "":
        # Find the most recent checkpoint in the directory.
        # Use DirEntry.path (output_dir joined with the entry name), not
        # DirEntry.name: a bare name would be resolved against the current
        # working directory by getctime() and by the joins below.
        # Non-checkpoint subdirectories are ignored.
        with os.scandir(output_dir) as entries:
            candidates = [
                entry.path
                for entry in entries
                if entry.is_dir() and step_pattern.search(entry.name)
            ]
        assert (
            candidates
        ), f"no step_<N> checkpoint directories found in {output_dir}"
        candidates.sort(key=os.path.getctime)
        path = candidates[-1]

    elif isinstance(resume_from, str):
        assert os.path.exists(
            resume_from
        ), f"cannot resume from checkpoint {resume_from}, does not exist"
        path = resume_from

    else:
        raise ValueError(f"unknown value {resume_from}")

    # the directory path; can use .from_pretrained() on this
    model_pretrained = path

    # Prefer the checkpoint directory's own name; fall back to the whole path
    # (e.g. when the path points below the step_<N> directory).
    leaf_name = os.path.basename(os.path.normpath(path))
    step_match = step_pattern.search(leaf_name) or step_pattern.search(path)
    assert (
        step_match is not None
    ), f"cannot parse step from checkpoint path {path}; expected it to contain step_<N>"
    step = int(step_match.group(1))

    optimizer_state = os.path.join(path, OPTIMIZER_STATE_PT)
    assert os.path.exists(optimizer_state), f"{optimizer_state} does not exist"
    scheduler_state = os.path.join(path, SCHEDULER_STATE_PT)
    assert os.path.exists(scheduler_state), f"{scheduler_state} does not exist"

    # The scaler state is optional: it is absent when no scaler state was
    # saved with the checkpoint.
    scaler_state = os.path.join(path, SCALER_STATE_PT)
    if not os.path.exists(scaler_state):
        logging.warning(
            f"scaler state file {scaler_state} does not exist; this is expected if the model at {path} "
            "was not trained with AMP but is not expected if the model was trained using AMP. "
            "Resuming an AMP training run without this can lead to errors or loss spikes."
        )
        scaler_state = None

    return ResumeCheckpointInfo(
        model_pretrained, optimizer_state, scheduler_state, scaler_state, step
    )
