from mmengine.evaluator import BaseMetric
from evaluate import load

import torch

class bert(BaseMetric):
    """Evaluation metric reporting mean BERTScore precision, recall and F1.

    Each call to :meth:`process` records one prediction string together with
    the references found in the batch. :meth:`compute_metrics` then scores
    every recorded pair in a single ``bertscore.compute`` call (``lang="en"``)
    and averages the per-pair scores.
    """

    default_prefix = 'bert'  # set default_prefix
    # Loaded once, when the class body executes, and shared by all instances.
    bertscore = load("bertscore")

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or 0.0 when it is empty."""
        count = len(values)
        return sum(values) / count if count else 0.0

    def process(self, data_batch, data_samples):
        """Record the prediction and references of one batch.

        Args:
            data_batch: Mapping in which
                ``data_batch['data_samples']['references']`` holds the
                reference text(s) for this batch.
            data_samples: Iterable of strings; the pieces are concatenated
                without a separator into a single prediction string.
        """
        prediction = ''.join(data_samples)
        references = data_batch['data_samples']['references']
        # A bare string is treated as one reference rather than as a sequence
        # of characters, so it pairs with the single prediction built above.
        if isinstance(references, str):
            references = [references]
        # NOTE(review): only one prediction is stored per call, so the
        # references are presumably expected to hold exactly one entry for
        # the counts to line up at scoring time -- confirm against the caller.
        self.results.append({
            'outputs': [prediction],
            'labels': references,
        })

    def compute_metrics(self, results):
        """Average BERTScore precision, recall and F1 over all results.

        Args:
            results: List of dicts as stored by :meth:`process`, each with an
                ``'outputs'`` list of predictions and a ``'labels'`` sequence
                of references.

        Returns:
            dict: ``precision``, ``recall`` and ``f1`` as mean scores. All
            three are 0.0 when there is nothing to score.
        """
        predictions = []
        references = []
        for entry in results:
            # extend() appends in place; rebuilding the list with ``+`` on
            # every iteration would be quadratic in the number of results.
            predictions.extend(entry['outputs'])
            references.extend(entry['labels'])

        # Nothing was collected: skip the model call and avoid dividing by
        # zero when averaging.
        if not predictions:
            return dict(precision=0.0, recall=0.0, f1=0.0)

        scores = self.bertscore.compute(
            predictions=predictions, references=references, lang="en")
        return dict(
            precision=self._mean(scores['precision']),
            recall=self._mean(scores['recall']),
            f1=self._mean(scores['f1']),
        )
