
from typing import List, Optional

import torch
from transformers import LogitsProcessor


class EarlyStopLogitsProcessor(LogitsProcessor):
    """Force an EOS token once a sequence ends with one of the given stop strings.

    Each stop string is tokenized once, at construction time. On every generation
    step, the tail of each row of ``input_ids`` is compared against those token ids.
    For every row that matches, the scores are rewritten so that
    ``forced_eos_token_id`` is the only selectable token.

    Note: the stop strings are tokenized in ``__init__``. Reassigning
    ``early_stop_token_string_list`` or ``tokenizer`` afterwards does not change
    which sequences are matched.

    Args:
        early_stop_token_string_list: Stop strings to watch for. ``None`` or an
            empty list disables the processor.
        tokenizer: Used to encode the stop strings. It is called as
            ``tokenizer(text, return_tensors="pt", add_special_tokens=False)``
            and its ``input_ids`` are read. It is required when stop strings
            are given.
        forced_eos_token_id: Token id forced once a stop string is matched.
            NOTE(review): the default of 2 presumably matches a LLaMA-style
            vocabulary; confirm this for the model in use.
        skip_first_token: Drop the first token id of each encoded stop string
            (the historical behaviour). This presumably discards a leading
            prefix-space/word-boundary token that SentencePiece-style
            tokenizers emit when a string is encoded in isolation. Set it to
            ``False`` for tokenizers that do not do this; otherwise
            single-token stop strings can never match.

    Raises:
        ValueError: If stop strings are given without a tokenizer.
    """

    def __init__(
        self,
        early_stop_token_string_list: Optional[List[str]] = None,
        tokenizer=None,
        forced_eos_token_id: int = 2,
        skip_first_token: bool = True,
    ):
        self.early_stop_token_string_list = early_stop_token_string_list
        self.tokenizer = tokenizer
        self.forced_eos_token_id = forced_eos_token_id
        self.skip_first_token = skip_first_token
        super().__init__()

        # Tokenize once here rather than on every decoding step.
        self._early_stop_token_ids: List[List[int]] = []
        if early_stop_token_string_list:
            if tokenizer is None:
                raise ValueError(
                    "A tokenizer is required when early_stop_token_string_list is given."
                )
            start = 1 if skip_first_token else 0
            for stop_string in early_stop_token_string_list:
                token_ids = tokenizer(
                    stop_string, return_tensors="pt", add_special_tokens=False
                ).input_ids.tolist()[0][start:]
                # An empty id sequence cannot be matched: `seq[-0:]` would return
                # the whole sequence, not an empty suffix. So skip it.
                if token_ids:
                    self._early_stop_token_ids.append(token_ids)
        self._max_stop_len = max(
            (len(ids) for ids in self._early_stop_token_ids), default=0
        )

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Force EOS for every batch row whose generated tail matches a stop sequence.

        Args:
            input_ids: Token ids generated so far, indexed as ``[row, position]``.
            scores: Next-token scores, indexed as ``[row, vocab]``. Modified in place.

        Returns:
            ``scores``. In each matching row, every entry is ``-inf`` except
            ``forced_eos_token_id``, which is ``0``. Non-matching rows are untouched.
        """
        if not self._early_stop_token_ids:
            return scores

        # Only the last `_max_stop_len` positions can take part in a match.
        # Copying just that tail to Python keeps the per-step host transfer small.
        tails = input_ids[:, -self._max_stop_len:].tolist()
        for row, tail in enumerate(tails):
            if any(tail[-len(stop_ids):] == stop_ids for stop_ids in self._early_stop_token_ids):
                # A single +inf logit turns softmax into NaN (inf - inf), which
                # breaks sampling and beam search. Mask every other token instead.
                scores[row, :] = -float("inf")
                scores[row, self.forced_eos_token_id] = 0
        return scores
