from typing import Any, Callable, Optional

from datasets import load_dataset

class ForgetPretrain:
    """Wrap a Hugging Face dataset and render each row as text.

    A row is rendered in one of two modes:

    * ``forget_key`` set: the text in that column is wrapped as
      ``"[INST] " + text + " [/INST]"``; ``template_func`` is not used.
    * ``forget_key`` is None: ``template_func(question=..., answer=...)`` is
      applied to the ``question_key`` / ``answer_key`` columns (see
      :meth:`apply_template`).
    """

    def __init__(
        self,
        name,
        split,
        template_func: Callable,
        forget_key: Optional[str] = None,
        question_key: Optional[str] = None,
        answer_key: Optional[str] = None,
    ):
        """Load the dataset and remember how to render its rows.

        Args:
            name: Passed as the first positional argument (``path``) of
                ``load_dataset``.
            split: Passed as the *second positional* argument of
                ``load_dataset`` -- i.e. the config name, not the ``split=``
                keyword. The ``"train"`` split of the result is then kept.
            template_func: Called as ``template_func(question=..., answer=...)``
                when ``forget_key`` is None.
            forget_key: Column holding raw text. When set, it takes precedence
                over the question/answer template.
            question_key: Column holding the question (template mode only).
            answer_key: Column holding either one answer (``str``) or a
                sequence of answers (template mode only).
        """
        # NOTE(review): ``split`` is used as the dataset config name and the
        # "train" split is always selected -- confirm this matches the dataset
        # layout callers expect.
        self.data = load_dataset(
            name, split
        )['train']
        self.split = split  # not read in this class; presumably for callers
        self.template_func = template_func
        self.forget_key = forget_key
        # In template mode both of these must be set, otherwise row lookup
        # with a None key fails when an item is accessed.
        self.question_key = question_key
        self.answer_key = answer_key

    def __len__(self):
        """Return the number of rows in the loaded split."""
        return len(self.data)

    def __iter__(self):
        """Yield each rendered row in index order.

        Rows are rendered lazily, one at a time, instead of materializing the
        whole rendered dataset up front.
        """
        return (self[i] for i in range(len(self)))

    def __getitem__(self, index):
        """Render the row at ``index`` according to the configured mode."""
        if self.forget_key is not None:
            # Raw-text mode: wrap the column value in [INST] ... [/INST]
            # markers. The template function is deliberately bypassed here.
            return "[INST] " + self.data[index][self.forget_key] + " [/INST]"
        return self.apply_template(self.data[index], self.template_func)

    def apply_template(self, item, template_func: Callable):
        """Render one row with ``template_func``.

        Args:
            item: A mapping row containing ``question_key`` and ``answer_key``.
            template_func: Called as ``template_func(question=..., answer=...)``.

        Returns:
            The single rendered value when the answer column holds a ``str``;
            otherwise a list with one rendered value per answer, all sharing
            the same question.
        """
        answers = item[self.answer_key]
        question = item[self.question_key]
        # A str is itself iterable, so it must be checked before treating the
        # column as a collection of alternative answers.
        if isinstance(answers, str):
            return template_func(question=question, answer=answers)
        return [template_func(question=question, answer=ans) for ans in answers]