from __future__ import annotations

from itertools import islice
import logging
import math
import time
from typing import Any, Dict, Optional, Tuple
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import wandb

from .data import DictMemmapWriter
from .train import Trainer, SpeedMonitor
from .eval import Evaluator, generative_tasks
from .tokenizer import Tokenizer
from .torch_util import get_world_size, move_to_device, barrier

from hf_olmo.modeling_olmo import OLMoForCausalLM

log = logging.getLogger(__name__)


class DataSelector(Trainer):

    def add_reference_model(self, reference_trainer: Trainer):
        # Allows us to add reference models to trainer
        # Can load from checkpoint in train script as separate trainer
        # Then can just call model_forward where needed in train_batch
        # Default is to load ref scores from disk, but this is another option
        if not hasattr(self, "reference_models"):
            self.reference_models = []
        self.reference_models.append(reference_trainer)

    def score(self, reference_loss: torch.Tensor, learner_loss: torch.Tensor):
        # Select tokens/seqs
        if self.cfg.score is None or self.cfg.score == "rho":
            score = (learner_loss - reference_loss).detach()
        elif self.cfg.score == "ref":
            score = -reference_loss.detach()
        return score

    def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Split into micro-batches.
        micro_batches = self.split_batch(batch)

        # In case this helps with memory utilization.
        del batch

        ce_batch_loss = torch.tensor(0.0, device=self.device)
        full_batch_loss = torch.tensor(0.0, device=self.device)
        reference_batch_loss = torch.tensor(0.0, device=self.device)

        batch_learner_losses = []
        multi_reference_freq = []

        z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
        for micro_batch in micro_batches:
            with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
                if self.cfg.granularity == "token":
                    tokens = True
                else:
                    assert self.cfg.granularity == "sequence"
                    tokens = False

                if not (self.cfg.fix_learner and not tokens):
                    # Run forward pass.
                    ce_loss, z_loss = self.model_forward(
                        micro_batch,
                        compute_z_loss=self.cfg.softmax_auxiliary_loss,
                        loss_reduction="none",
                        return_logits=False,
                    )
                    if tokens:
                        ce_loss = ce_loss.flatten()
                        if self.cfg.softmax_auxiliary_loss:
                            z_loss = z_loss.flatten()
                    else:
                        ce_loss = ce_loss.mean(dim=-1)
                        if self.cfg.softmax_auxiliary_loss:
                            z_loss = z_loss.mean(dim=-1)

                    # Log full batch loss
                    full_batch_loss += ce_loss.mean().detach() / len(micro_batches)

                # Add reference model scores
                if hasattr(self, "reference_models"):
                    # Run forward pass on reference model with no grad
                    with torch.no_grad() if not self.cfg.update_reference else torch.enable_grad():
                        ref_scores = []
                        for model in self.reference_models:
                            ref_score, ref_z_loss = model.model_forward(
                                micro_batch,
                                compute_z_loss=True,
                                loss_reduction="none",
                                return_logits=False,
                            )
                            num_ref_keys = len([key for key in micro_batch if "ref_score" in key])
                            micro_batch[f"ref_score_{num_ref_keys}"] = ref_score
                            if self.cfg.update_reference:
                                if tokens:
                                    ref_score = ref_score.flatten() + ref_z_loss.flatten()
                                else:
                                    ref_score = ref_score.mean(dim=-1) + ref_z_loss.mean(dim=-1)
                                ref_scores.append(ref_score)

                # Load reference scores to select data
                if self.cfg.method == "load_score":
                    ref_keys = [key for key in micro_batch if "ref_score" in key]

                    if self.cfg.fix_learner:  # Load reference score as learner
                        learner_loss = micro_batch["ref_score_0"]
                        if tokens:
                            learner_loss = learner_loss.flatten()
                        else:
                            learner_loss = learner_loss.mean(dim=-1)
                        ref_keys = [key for key in ref_keys if key != "ref_score_0"]
                    else:
                        learner_loss = ce_loss

                    scores = []
                    for key in ref_keys:
                        reference_loss = micro_batch[key]  # Score is loaded with no reduction
                        if tokens:
                            reference_loss = reference_loss.flatten()
                        else:
                            reference_loss = reference_loss.mean(dim=-1)
                        score = self.score(reference_loss, learner_loss)
                        scores.append(score)
                    k = int(self.cfg.select_frac * len(reference_loss))

                    # Combine scores if multiple scores exist
                    if len(scores) == 1:
                        score = scores[0]
                        select_idx = torch.topk(score, k, largest=True).indices
                    elif self.cfg.score_combination == "max":
                        score = torch.stack(scores).max(dim=0).values
                        select_idx = torch.topk(score, k, largest=True).indices

                        # log index of selected reference
                        multi_reference_select_idx = torch.argmax(torch.stack(scores), dim=0)
                        multi_reference_freq_batch = torch.bincount(
                            multi_reference_select_idx, minlength=len(scores)
                        )
                        multi_reference_freq_batch = (
                            multi_reference_freq_batch.float() / multi_reference_freq_batch.sum()
                        )
                        multi_reference_freq.append(multi_reference_freq_batch.detach().cpu())
                    elif self.cfg.score_combination == "sum":
                        score = torch.stack(scores).sum(dim=0)
                        select_idx = torch.topk(score, k, largest=True).indices
                    elif self.cfg.score_combination == "mix":
                        chunk_size = learner_loss.shape[0] // len(scores)
                        idxs = []
                        for i in range(len(scores)):
                            chunk_score = scores[i][i * chunk_size : (i + 1) * chunk_size]
                            select_chunk = torch.topk(chunk_score, k // len(scores), largest=True).indices
                            idxs.append(select_chunk + i * chunk_size)
                        select_idx = torch.cat(idxs, dim=0)

                    # Log reference loss (note: will be incorrect with multiple references)
                    reference_batch_loss += reference_loss.mean().detach() / len(micro_batches)

                elif self.cfg.method == "full":
                    reference_loss = None
                    score = torch.zeros_like(ce_loss)
                    select_idx = torch.arange(len(score), device=self.device)
                    k = len(score)

                else:
                    raise ValueError(f"Unknown method: {self.cfg.method}")

                if self.cfg.select_random:
                    select_idx = torch.randperm(len(score), device=self.device)[:k]

                if not (self.cfg.fix_learner and not tokens):
                    # Select by select_idx from the loss and z_loss
                    ce_loss = ce_loss[select_idx].mean() / len(micro_batches)
                    z_loss = z_loss[select_idx].mean() / len(micro_batches)
                else:
                    full_batch_loss += learner_loss.mean().detach() / len(micro_batches)
                    # Efficient forward pass on selected seqs
                    micro_batch = {k: v[select_idx] for k, v in micro_batch.items()}
                    ce_loss, z_loss = self.model_forward(
                        micro_batch,
                        compute_z_loss=self.cfg.softmax_auxiliary_loss,
                        loss_reduction="none",
                        return_logits=False,
                    )
                    ce_loss = ce_loss.mean() / len(micro_batches)
                    if self.cfg.softmax_auxiliary_loss:
                        z_loss = z_loss.mean() / len(micro_batches)

                # Update overall CE batch loss.
                ce_batch_loss += ce_loss.detach()

                # Get final loss to optimize for.
                if self.cfg.softmax_auxiliary_loss:
                    assert z_loss is not None
                    assert z_batch_loss is not None
                    loss = ce_loss + z_loss
                    # Update overall Z batch loss.
                    z_batch_loss += z_loss.detach()
                else:
                    loss = ce_loss

                # optionally train reference
                # TODO: add logging
                if self.cfg.update_reference:
                    ref_losses = []
                    for ref_score in ref_scores:
                        ref_loss = ref_score[select_idx].mean()
                        ref_losses.append(ref_loss / len(micro_batches))

                # In case this helps with memory utilization.
                del micro_batch

            # Run backward pass.
            loss.backward()
            if self.cfg.update_reference:
                for ref_loss in ref_losses:
                    ref_loss.backward()

        # Cat for logging purposes
        if batch_learner_losses:
            batch_learner_losses = torch.concatenate(batch_learner_losses, dim=0)
        if multi_reference_freq:
            multi_reference_freq = torch.stack(multi_reference_freq, dim=0).mean(dim=0)
        return (
            ce_batch_loss,
            z_batch_loss,
            reference_batch_loss,
            full_batch_loss,
            batch_learner_losses,
            multi_reference_freq,
        )

    def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        batch_data: Dict[str, torch.Tensor] = {}

        # Write data-indices to file.
        if self.indices_file is not None and "index" in batch:
            indices = "\t".join(str(int(i)) for i in batch["index"])
            self.indices_file.write(f"{self.global_step}\t{indices}\n")

        # Zero-gradients.
        self.optim.zero_grad(set_to_none=True)
        if self.cfg.update_reference:
            for ref_model in self.reference_models:
                ref_model.optim.zero_grad(set_to_none=True)

        # Move tensors to the right device.
        batch = move_to_device(batch, self.device)

        # Run forward-backward pass.
        (
            ce_batch_loss,
            z_batch_loss,
            reference_batch_loss,
            full_batch_loss,
            batch_learner_losses,
            multi_reference_freq,
        ) = self.train_batch(batch)
        if not self.cfg.sft:
            batch_data["learner_score"] = batch_learner_losses
            batch_data["index"] = batch["index"].detach().cpu()

        # Collect loss, potentially reducing over all ranks.
        if reduce_global_loss:
            dist.reduce(ce_batch_loss, 0)
            ce_batch_loss.div_(get_world_size())
            dist.reduce(full_batch_loss, 0)
            full_batch_loss.div_(get_world_size())
            dist.reduce(reference_batch_loss, 0)
            reference_batch_loss.div_(get_world_size())
            if z_batch_loss is not None:
                dist.reduce(z_batch_loss, 0)
                z_batch_loss.div_(get_world_size())

        # Clip gradient norms and collect param/gradient/optim metrics.
        should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
        optim_metrics = self.optim_step(should_log_metrics=should_log_optim_metrics_this_step)

        # Optim step for references
        # TODO: log these metrics?
        if self.cfg.update_reference:
            for ref_model in self.reference_models:
                ref_optim_metrics = ref_model.optim_step(should_log_metrics=should_log_optim_metrics_this_step)
                ref_model.global_step += 1  # For scheduler

        # Collect metrics and check for NaN loss.
        # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
        if torch.isnan(ce_batch_loss):
            raise ValueError("nan loss encountered")
        if z_batch_loss is not None and torch.isnan(z_batch_loss):
            raise ValueError("nan loss encountered")
        for key, value in optim_metrics.items():
            metrics[f"optim/{key}"] = value.item()
        self.cur_train_loss = ce_batch_loss.item()
        self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
        metrics["train/CrossEntropyLoss"] = self.cur_train_loss
        metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
        metrics["train/ReferenceLoss"] = reference_batch_loss.item()
        metrics["train/FullLoss"] = full_batch_loss.item()
        if z_batch_loss is not None:
            metrics["train/ZLoss"] = z_batch_loss.item()
        if multi_reference_freq != []:
            for i, freq in enumerate(multi_reference_freq):
                metrics[f"train/MultiReferenceFreq_{i}"] = freq.item()

        # Maybe collect post-step optimizer-specific metrics.
        if should_log_optim_metrics_this_step:
            optim_metrics = self.optim.get_post_step_metrics(
                self.fsdp_model, process_group=self.fsdp_model.process_group
            )
            for key, value in optim_metrics.items():
                metrics[f"optim/{key}"] = value.item()

        # Write learner scores to disk using dict writer
        if self.cfg.collect_learner_score:
            assert not self.cfg.sft
            self.learner_score_writer.write(batch_data["index"], batch_data["learner_score"])
        return metrics

    def fit(self):

        if self.cfg.collect_learner_score:
            self.learner_score_writer = DictMemmapWriter(
                Path(self.cfg.save_folder) / "learn_score",
                memmap_dtype=np.float32,
                seq_len=self.cfg.model.max_sequence_length - 1,  # Losses are one token shorter that seq_len
            )
        if self.cfg.collect_reference_score:
            self.reference_score_writer = DictMemmapWriter(
                Path(self.cfg.save_folder) / "ref_score",
                memmap_dtype=np.float32,
                seq_len=self.cfg.model.max_sequence_length - 1,  # Losses are one token shorter that seq_len
            )

        super().fit()

        if self.cfg.collect_learner_score:
            self.learner_score_writer.close()
        if self.cfg.collect_reference_score:
            self.reference_score_writer.close()

    def score_step(self, batch: Dict[str, Any]) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        batch_data: Dict[str, torch.Tensor] = {}

        # Write data-indices to file.
        if self.indices_file is not None and "index" in batch:
            indices = "\t".join(str(int(i)) for i in batch["index"])
            self.indices_file.write(f"{self.global_step}\t{indices}\n")

        # Move tensors to the right device.
        batch = move_to_device(batch, self.device)
        micro_batches = self.split_batch(batch)
        batch_scores = []
        reference_batch_loss = torch.tensor(0.0, device=self.device)
        for micro_batch in micro_batches:
            with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
                assert self.cfg.granularity == "token"
                reference_loss, _ = self.model_forward(micro_batch, loss_reduction="none", return_logits=False)
                batch_scores.append(reference_loss)
                reference_batch_loss += reference_loss.mean().detach() / len(micro_batches)
        batch_scores = torch.concatenate(batch_scores, dim=0)
        batch_data["ref_score"] = batch_scores.detach().cpu()
        batch_data["index"] = batch["index"].detach().cpu()
        metrics["train/ReferenceLoss"] = reference_batch_loss.item()
        return metrics, batch_data

    def score_reference(self):
        self._start_time = time.time()

        self.fsdp_model.eval()

        # Initializer dataset writer
        data_writer = DictMemmapWriter(
            Path(self.cfg.save_folder) / "ref_score",
            memmap_dtype=np.float32,
            seq_len=self.cfg.model.max_sequence_length - 1,  # Losses are one token shorter that seq_len
        )

        # Initialize monitors.
        assert self.cfg.device_train_batch_size is not None
        speed_monitor = SpeedMonitor(self.cfg.speed_monitor)

        # Log system metrics at the start of training.
        sys_metrics = self.system_metrics()
        if sys_metrics:
            self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
            if wandb.run is not None:
                wandb.log(sys_metrics, step=0)

        # Eval
        first_batch: bool = True
        cancel_initiated: bool = False
        stop_at: Optional[int] = self.cfg.stop_at

        if stop_at is None and self.max_epochs == 1:
            stop_at = self.max_steps

        # Maybe fast forward data for parallelism
        if self.cfg.data_start_step is not None:
            self.dataset.start_index = int(self.cfg.data_start_step) * self.cfg.global_train_batch_size

        for batch in self.train_loader:
            # Bookkeeping.
            # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
            # batches see the same number of tokens, which should be the case for language model pre-training
            # (at least when drop_last=True).
            # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
            # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
            # fail loudly.
            batch_size, seq_len = batch["input_ids"].shape
            assert seq_len == self.cfg.model.max_sequence_length
            assert batch_size == self.cfg.device_train_batch_size
            global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
            self.global_step += 1
            self.global_train_examples_seen_this_epoch += global_batch_size
            self.global_train_tokens_seen += global_batch_size * seq_len
            speed_monitor.batch_start(
                self.global_train_tokens_seen,
                batch_size * seq_len,  # num tokens in batch for this device
                # We start monitoring speed after the first batch since the first
                # batch might be an outlier due to compiling and other initialization overhead.
                record=not first_batch,
            )

            should_log_this_step = self.should_log_this_step()

            # Run on batch
            with torch.no_grad():
                metrics, batch_data = self.score_step(batch)

            # Write outputs
            idx = batch_data["index"]
            ref_score = batch_data["ref_score"]
            data_writer.write(idx, ref_score)

            # Maybe collect other metrics.
            if should_log_this_step:
                # Speed metrics.
                metrics.update(speed_monitor.check())
                # System metrics.
                metrics.update(self.system_metrics())

            # Log metrics to console.
            if self.global_step % self.cfg.console_log_interval == 0:
                self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

            # Log metrics to W&B.
            if (
                wandb.run is not None
                and self.cfg.wandb is not None
                and self.global_step % self.cfg.wandb.log_interval == 0
            ):
                wandb.log(metrics, step=self.global_step)

            # Check if/when run should be canceled.
            if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
                cancel_initiated, extra_steps = self.check_if_cancelled()
                if cancel_initiated:
                    stop_at = (
                        self.global_step + extra_steps
                        if stop_at is None
                        else min(self.global_step + extra_steps, stop_at)
                    )

            # End of batch.
            first_batch = False
            if stop_at is not None and self.global_step >= stop_at:
                break

        # Close writer
        data_writer.close()

    def eval(self, eval_gen: bool = False) -> Dict[str, Any]:
        # Zero gradients and set model to 'eval' mode.
        self.optim.zero_grad(set_to_none=True)
        self.fsdp_model.eval()

        eval_metrics = {}
        if eval_gen:
            log.info("Running evaluation for generative tasks and skip ICL multi-choice tasks...")
            evaluators = [evaluator for evaluator in self.evaluators if evaluator.label in generative_tasks]
        else:
            log.info("Running evaluation for ICL multi-choice tasks and skip generative tasks...")
            evaluators = [evaluator for evaluator in self.evaluators if evaluator.label not in generative_tasks]

        for evaluator in evaluators:
            log.info(f"Running evaluation for '{evaluator.label}'...")

            # Reset metrics.
            evaluator.reset_metrics()

            # Initialize data loader iterator.
            eval_batches = iter(evaluator.eval_loader)

            # Adjust how many batches to evaluate on.
            num_eval_batches = (
                evaluator.subset_num_batches
                if evaluator.subset_num_batches is not None
                else self.cfg.eval_subset_num_batches
            )
            if num_eval_batches > 0:
                num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
                eval_batches = islice(eval_batches, num_eval_batches)

            # Run model over batches.
            for eval_step, eval_batch in enumerate(eval_batches):
                if eval_gen:
                    self.eval_gen(eval_batch, evaluator)
                else:
                    self.eval_step(eval_batch, evaluator)

                # Log to console.
                if eval_step + 1 == num_eval_batches or (eval_step + 1) % self.cfg.console_log_interval == 0:
                    log.info(f"[eval_step={eval_step + 1}/{num_eval_batches}]")

            # Get final metrics.
            metrics = evaluator.compute_gen_metrics() if eval_gen else evaluator.compute_metrics()
            eval_metrics.update(metrics)
            self.log_metrics_to_console(f"{evaluator.label}", metrics)

            del eval_batches

        return eval_metrics

    # NOTE: only called at the end of training by loading from ckpt, and I have NOT ported this change yet
    def eval_gen(self, batch: Dict[str, Any], evaluator: Evaluator) -> None:
        """
        batch - ['input_ids'] contain natural language instructions, problem statement etc, ['solution'] contain executable code with canonical solutions
                ['code_setup'] contain code setup like packages imported for the problem, ['test_cases'] contain test cases for the problem,
        """

        # Move tensors to the right device.
        batch = move_to_device(batch, self.device)
        with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
            # tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer.identifier)
            tokenizer = Tokenizer.from_train_config(self.cfg)
            input_ids = batch["input_ids"]
            # prompts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)

            hf_model_path = self.unsharded_checkpoint_path
            eval_model = OLMoForCausalLM.from_pretrained(hf_model_path)
            eval_model.to(self.device)

            # use huggingface generate function # TODO pass@k - num_return_sequences is not working yet # num_return_sequences=max(evaluator.eval_metric.pass_at_ks)
            generated_tokens = eval_model.generate(
                input_ids=input_ids,
                attention_mask=batch.get("attention_mask"),
                max_new_tokens=400,
                do_sample=True,
                top_k=50,
                top_p=0.95,
            )

        # Update metrics
        code_str_to_be_evaluated = tokenizer.batch_decode(
            generated_tokens[:, input_ids.shape[1] :], skip_special_tokens=True
        )
        evaluator.update_gen_metrics(batch["solution_str"], code_str_to_be_evaluated)

        barrier()

    def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
        # Labels are just input IDs shifted to the left (first item is ignored).
        labels, label_mask, attention_mask = (
            batch["input_ids"].clone(),
            batch.get("label_mask"),
            batch.get("attention_mask"),
        )
        if label_mask is not None:
            labels.masked_fill_(~label_mask, -100)
        if attention_mask is not None:
            labels.masked_fill_(attention_mask == 0.0, -100)

        # Mask out padding tokens
        pad_token_id = self.cfg.model.pad_token_id
        labels.masked_fill_(labels == pad_token_id, -100)
        return labels[..., 1:].contiguous()
