from typing import Sequence, List

from mmengine.evaluator import BaseMetric

from xtuner.registry import BUILDER


class WER(BaseMetric):
    """Evaluator metric registered under the ``WER`` prefix.

    Default prefix: WER

    Metrics:
        - accuracy (int): placeholder value, always ``0``. The real
          word-error-rate computation is not implemented yet; see
          :meth:`compute_metrics`.
    """

    default_prefix = 'WER'  # set default_prefix

    def __init__(self, tokenizer, *args, **kwargs):
        """Build the tokenizer and forward remaining args to ``BaseMetric``.

        Args:
            tokenizer: Config passed to ``BUILDER.build`` to construct the
                tokenizer. The built object is stored on ``self.tokenizer``;
                no method visible in this class reads it yet.
            *args: Positional arguments forwarded to ``BaseMetric``.
            **kwargs: Keyword arguments forwarded to ``BaseMetric``.
        """
        super().__init__(*args, **kwargs)
        self.tokenizer = BUILDER.build(tokenizer)

    def process(self, data_batch: dict, data_samples: Sequence) -> None:
        """Process one batch of data and predictions.

        The processed result is appended to ``self.results``, which is used
        to compute the metrics once all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader. Must
                provide ``data_batch['data']['split']``.
            data_samples (Sequence): A batch of outputs from the model.
                The first four and last four elements are dropped and the
                remainder is joined into a single string, so the elements
                must be strings.
        """
        # NOTE(review): the [4:-4] slice presumably strips a fixed-size
        # prefix/suffix (e.g. special tokens) from the prediction -- confirm
        # against the model's output format.
        prediction = ''.join(data_samples[4:-4])

        result = {
            'pred': prediction,
            'split': data_batch['data']['split'],
        }

        # store the results of the current batch into self.results
        self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch, i.e. the
                dicts with ``'pred'`` and ``'split'`` keys stored by
                :meth:`process`.

        Returns:
            dict: The computed metrics. The keys are the names of the
            metrics, and the values are the corresponding results.
        """
        # TODO: implement the actual WER computation. `process` does not
        # store any reference transcript yet, so there is nothing to compare
        # the predictions against; the constant below is a placeholder and
        # `results` is intentionally unused.
        return {'accuracy': 0}
