from typing import Sequence, List

from mmengine.evaluator import BaseMetric


class loss(BaseMetric):
    """Loss evaluator.

    Records the loss value reported by the model for each evaluated batch
    and reports the mean over all batches.

    Default prefix: loss

    Metrics:
        - loss (float): unweighted mean of the per-batch loss values (every
          batch contributes equally, regardless of its size); ``nan`` when
          no batch was processed.
    """

    # NOTE(review): the lowercase class name is kept unchanged because
    # configs presumably refer to this metric by this exact name — confirm
    # before renaming to PascalCase.

    default_prefix = 'loss'  # prefix for this metric's keys (see BaseMetric)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]) -> None:
        """Record the loss of one batch into ``self.results``.

        The stored entries are consumed by ``compute_metrics`` once all
        batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the
                dataloader. Not used by this metric.
            data_samples (Sequence[dict]): A batch of model outputs. The
                first sample must hold the batch loss under ``'loss'``,
                either as a scalar exposing ``.item()`` (e.g. a
                single-element tensor) or as a plain number.
        """
        if not data_samples:
            # Nothing to record; indexing [0] would raise IndexError.
            return

        # Only the first sample is read. NOTE(review): this assumes the
        # model attaches the same batch-level loss to every sample — confirm.
        loss_value = data_samples[0]['loss']
        if hasattr(loss_value, 'item'):
            # Convert a scalar tensor/array to a plain Python number so the
            # stored results hold no device memory.
            loss_value = loss_value.item()

        self.results.append({'loss': float(loss_value)})

    def compute_metrics(self, results: List[dict]) -> dict:
        """Average the per-batch losses collected by ``process``.

        Args:
            results (List[dict]): One ``{'loss': float}`` entry per
                processed batch.

        Returns:
            dict: ``{'loss': mean}``, where ``mean`` is the unweighted mean
            of the per-batch losses, or ``nan`` if ``results`` is empty.
        """
        if not results:
            # Avoid ZeroDivisionError when no batch was processed (e.g. an
            # empty dataloader); report nan instead of crashing evaluation.
            return {'loss': float('nan')}

        loss_sum = sum(res['loss'] for res in results)
        return {'loss': loss_sum / len(results)}
