from typing import Tuple

import torch as t
from acdc.greaterthan.utils import get_all_greaterthan_things, get_greaterthan_true_edges
from acdc.TLACDCEdge import TorchIndex

from hypo_interp.tasks.mech_interp_task import MechInterpTask
from hypo_interp.types_ import Circuit


class GreaterThanTask(MechInterpTask):
    """
    Greater-than task from:
    "How does GPT-2 compute greater-than?: Interpreting mathematical abilities
    in a pre-trained language model"
    https://openreview.net/pdf?id=p4PckNQR8k
    """

    def __init__(
        self,
        zero_ablation: bool = False,
        device: str = "cuda",
        num_examples: int = 100,
    ):
        """
        Load the greater-than data, metric and model, and build the experiment.

        zero_ablation:
            Forwarded to ``MechInterpTask.__init__`` and (via
            ``self._zero_ablation``) to ``_make_experiment``. Presumably it
            selects zero-ablation instead of patching from the ablation
            dataset; confirm in ``MechInterpTask``.
        device:
            Device passed to both the base task and
            ``get_all_greaterthan_things``.
        num_examples:
            Number of examples to use in the dataset.
        """
        super().__init__(zero_ablation=zero_ablation, device=device)

        all_greaterthan_things = get_all_greaterthan_things(
            num_examples=num_examples,
            metric_name="greaterthan",  # The metric of interest
            device=device,
        )

        # Metric used by `score` to evaluate model logits.
        self._validation_metric = all_greaterthan_things.validation_metric
        # Inputs the model is run on in `score`.
        self._base_dataset = all_greaterthan_things.validation_data
        # Patch data handed to the experiment as the ablation source
        # (presumably the corrupted prompts; confirm in acdc).
        self._ablate_dataset = all_greaterthan_things.validation_patch_data
        self._experiment = self._make_experiment(
            base_dataset=self._base_dataset,
            ablate_dataset=self._ablate_dataset,
            model=all_greaterthan_things.tl_model,
            validation_metric=self._validation_metric,
            zero_ablation=self._zero_ablation,
            use_pos_embed=self.use_pos_embed,
        )
        self._validate_attributes()

    def score(self, per_prompt: bool = False) -> Tuple[t.Tensor, t.Tensor]:
        """
        Score the current circuit on the base dataset.

        The experiment's model is run on the base dataset, and the validation
        metric is evaluated on the resulting logits.

        per_prompt: bool
            If True, the metric is returned per prompt. Otherwise it is reduced
            to a single aggregated score (``return_one_element=True``).

        Returns
        -------
        scores, logits:
            ``scores`` is the output of the validation metric. ``logits`` are
            the raw logits the model produced on the base dataset.
        """
        logits = self._experiment.model(self._base_dataset, return_type="logits")

        scores = self._validation_metric(
            logits,
            return_one_element=not per_prompt,
        )
        return scores, logits

    @property
    def _canonical_circuit(self) -> Circuit:
        """
        The reference greater-than circuit as a list of ``(edge, value)`` pairs.

        Each edge key from ``get_greaterthan_true_edges`` is a 4-tuple. Its
        2nd and 4th entries are raw index values, which are wrapped in
        ``TorchIndex`` here. Values are passed through unchanged.

        The key layout is presumably
        ``(receiver_name, receiver_index, sender_name, sender_index)``;
        confirm against acdc.
        """
        true_edges = get_greaterthan_true_edges(self._experiment.model)
        circuit: Circuit = [
            ((name_a, TorchIndex(index_a), name_b, TorchIndex(index_b)), value)
            for (name_a, index_a, name_b, index_b), value in true_edges.items()
        ]
        return circuit
