import re
import copy
import random
from typing import Dict, Sequence

import torch
import transformers

from src.common.templates import (
    PROMPT_DICT,
    INFERENCE_PATTERN,
    DATA_TYPE_DICT,
    CHAT_PROMPT_DICT,
)
from src.utils.file_utils import load_jsonl
from src.utils.data_utils import find_nth
from src.evol.data_utils import load_few_shot_demos

# Sentinel written into label positions that must not contribute to the loss
# (see the `label[...] = IGNORE_INDEX` assignments below).
# NOTE(review): presumably chosen to match the default `ignore_index` of
# torch's cross-entropy loss -- confirm against the training code.
IGNORE_INDEX = -100


def _tokenize_fn(
    strings: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    return_offsets_mapping=False,
    add_eos_token=True,
    add_special_tokens=True,
) -> Dict:
    """Tokenize each string on its own and collect the per-example tensors.

    Args:
        strings: Texts to tokenize; every text is encoded in a separate call.
        tokenizer: Tokenizer used for encoding. Its ``model_max_length`` is the
            truncation limit and its ``eos_token_id`` is appended on request.
        return_offsets_mapping: When true, the result also carries an
            ``offset_mapping`` entry (one list of character spans per text).
        add_eos_token: When true, ``eos_token_id`` is appended to every
            ``input_ids`` tensor and a matching ``1`` to the attention mask.
        add_special_tokens: Forwarded to the tokenizer call.

    Returns:
        A dict with ``input_ids``, ``labels`` (the very same list object as
        ``input_ids``), ``input_ids_lens`` / ``labels_lens`` (number of non-pad
        tokens *before* the optional EOS is appended), ``attention_mask`` and,
        optionally, ``offset_mapping``.
    """
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
            add_special_tokens=add_special_tokens,
            return_offsets_mapping=return_offsets_mapping,
        )
        for text in strings
    ]
    if add_eos_token:
        input_ids = [
            torch.cat([tokenized.input_ids[0], torch.tensor([tokenizer.eos_token_id])])
            for tokenized in tokenized_list
        ]
        attention_mask = [
            torch.cat([tokenized.attention_mask[0], torch.tensor([1])])
            for tokenized in tokenized_list
        ]
    else:
        # No EOS is appended, so the mask must not be extended either;
        # otherwise it would be one element longer than input_ids.
        input_ids = [tokenized.input_ids[0] for tokenized in tokenized_list]
        attention_mask = [tokenized.attention_mask[0] for tokenized in tokenized_list]
    # Callers that need independent labels deep-copy input_ids themselves.
    labels = input_ids
    # Lengths are measured on the raw encoding, i.e. they exclude the EOS
    # token appended above.
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    ret = dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
        attention_mask=attention_mask,
    )
    if return_offsets_mapping:
        ret["offset_mapping"] = [
            tokenized["offset_mapping"][0].tolist() for tokenized in tokenized_list
        ]
    return ret


def extract_solution(target):
    """Return the part of ``target`` that is treated as the actual solution.

    * Targets containing ``solution()`` yield the text after the last
      triple-double-quote delimiter.
    * Targets containing ``The answer is`` have every line starting with
      ``>`` blanked out.
    * Any other target is returned unchanged.
    """
    is_program = "solution()" in target
    if is_program:
        # Everything after the final '"""' (or the whole text if absent).
        _, _, tail = target.rpartition('"""')
        return tail
    if "The answer is" not in target:
        return target
    quoted_line = re.compile(r"^>.*?$", flags=re.MULTILINE)
    return quoted_line.sub("", target)


def find_sub_list(l, sl):
    """Return the first index at which ``sl`` occurs as a contiguous run in ``l``.

    Returns ``None`` when ``sl`` does not occur. ``sl`` is expected to be
    non-empty: its first element is used as the search anchor.
    """
    width = len(sl)
    for start, item in enumerate(l):
        if item == sl[0] and l[start : start + width] == sl:
            return start
    return None


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    model_args,
) -> Dict:
    """Tokenize ``source + target`` pairs and build loss-masked labels.

    Args:
        sources: Prompt strings.
        targets: Completion strings, aligned with ``sources``.
        tokenizer: Tokenizer passed through to ``_tokenize_fn``.
        data_args: Must expose ``apply_partial_tgt_loss``, ``apply_src_loss``
            and ``copy_question``.
        model_args: Unused here; kept for a uniform signature.

    Returns:
        ``dict(input_ids=..., labels=...)`` with one tensor per example.

    Masking modes (positions set to ``IGNORE_INDEX``):
        * ``apply_partial_tgt_loss``: everything before the extracted solution
          (see ``extract_solution``) is masked. If the solution tokens cannot
          be located in the example, everything but the last token is masked.
        * otherwise, unless ``apply_src_loss``: the source prefix is masked.
    """
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized = _tokenize_fn(examples, tokenizer)
    sources_tokenized = _tokenize_fn(sources, tokenizer, add_eos_token=False)
    source_lens = sources_tokenized["input_ids_lens"]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    if data_args.apply_partial_tgt_loss:
        # Solutions are only needed for this mode, so tokenize them lazily.
        solutions = [extract_solution(t) for t in targets]
        solution_tokenized = _tokenize_fn(
            solutions, tokenizer, add_special_tokens=False
        )
        for label, input_id, target_id in zip(
            labels, input_ids, solution_tokenized["input_ids"]
        ):
            # The first solution token is skipped when searching; it may be
            # tokenized differently in isolation than inside the full example.
            needle = target_id[1:].tolist()
            ts_idx = find_sub_list(input_id.tolist(), needle) if needle else None
            if ts_idx is None:
                # Not found (or nothing to search for): fall back to training
                # on the last token only instead of silently masking
                # everything via ``label[:None]``.
                print("Warning", label)
                ts_idx = label.shape[-1] - 1
            label[:ts_idx] = IGNORE_INDEX
    elif not data_args.apply_src_loss:
        for label, source_len in zip(labels, source_lens):
            # A source of at most one token is treated as "no source".
            if source_len <= 1:
                source_len = 0
            label[:source_len] = IGNORE_INDEX
    if data_args.copy_question:
        assert data_args.apply_src_loss
        for i in range(len(input_ids)):
            # Drop the label at the last source position and the final input
            # token so that both sequences keep the same length.
            labels[i] = torch.cat(
                [labels[i][: source_lens[i] - 1], labels[i][source_lens[i] :]], dim=0
            )
            input_ids[i] = input_ids[i][:-1]
    return dict(input_ids=input_ids, labels=labels)


def process_dataset(examples, tokenizer, data_args, model_args):
    """Render prompts for a batch and hand them to ``preprocess``.

    The prompt template is ``PROMPT_DICT[data_args.inst_type]["prompt_input"]``.
    If the batch has an ``instruction`` column, the template is filled with
    both ``instruction`` and ``input``; otherwise ``input`` is used as the
    instruction and an empty ``input`` produces an empty prompt.
    """
    template = PROMPT_DICT[data_args.inst_type]["prompt_input"]
    prompts = []
    if "instruction" in examples:
        for instruction, context in zip(examples["instruction"], examples["input"]):
            prompts.append(
                template.format_map({"instruction": instruction, "input": context})
            )
    else:
        for question in examples["input"]:
            if len(question) > 0:
                prompts.append(template.format_map({"instruction": question}))
            else:
                prompts.append("")
    completions = [f"{output}" for output in examples["output"]]
    return preprocess(prompts, completions, tokenizer, data_args, model_args)


def process_dataset_mc(examples):
    """Flatten a batch with several outputs per row into one row per output.

    Each ``(input, question)`` pair is repeated once for every entry of the
    corresponding ``output`` list.
    """
    flattened = [
        (context, question, candidate)
        for context, question, candidates in zip(
            examples["input"], examples["question"], examples["output"]
        )
        for candidate in candidates
    ]
    return {
        "input": [context for context, _, _ in flattened],
        "question": [question for _, question, _ in flattened],
        "output": [candidate for _, _, candidate in flattened],
    }


def process_dataset_cringe(examples, tokenizer, data_args, model_args):
    """Run ``process_dataset`` and pass the batch's ``classifier_labels`` through."""
    features = process_dataset(examples, tokenizer, data_args, model_args)
    features.update(classifier_labels=examples["classifier_labels"])
    return features


def process_dataset_cringe_gsm(examples, tokenizer, data_args, model_args):
    """Tokenize a batch and attach per-token "math" classifier labels.

    Every target character is labelled 0 if it is a digit or arithmetic
    operator (or a ``.``/``,`` directly followed by one) and 1 otherwise.
    A token gets classifier label 0 when the first or last target character
    it covers is labelled 0; all other tokens (source tokens, special tokens,
    the appended EOS) keep label 1.

    Args:
        examples: Batch with ``input`` and ``output`` columns.
        tokenizer: Tokenizer; must support ``return_offsets_mapping``.
        data_args: Must expose ``inst_type`` and ``apply_src_loss``.
        model_args: Unused here; kept for a uniform signature.

    Returns:
        ``dict(input_ids=..., labels=..., classifier_labels=...)``. Source
        tokens are masked with ``IGNORE_INDEX`` in ``labels`` unless
        ``data_args.apply_src_loss`` is set, matching ``preprocess``.
    """
    # Single characters that count as "math".
    pattern = re.compile(r"\d+|[\+\-\*\/\%\^=]")

    def _tokenize_fn_gsm(
        strings: Sequence[str],
        sources,
        string_labels,
        tokenizer: transformers.PreTrainedTokenizer,
    ) -> Dict:
        """Tokenize ``strings`` and project character labels onto tokens."""
        tokenized_list = [
            tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_offsets_mapping=True,
            )
            for text in strings
        ]
        classifier_labels = []
        input_ids = []
        attention_mask = []
        for tokenized, source, string_label in zip(
            tokenized_list, sources, string_labels
        ):
            cur_input_ids = torch.cat(
                [tokenized.input_ids[0], torch.tensor([tokenizer.eos_token_id])], dim=0
            )
            cur_attention_mask = torch.cat(
                [tokenized.attention_mask[0], torch.tensor([1])], dim=0
            )
            input_ids.append(cur_input_ids)
            attention_mask.append(cur_attention_mask)
            # Default label 1 for every token, including the appended EOS.
            classifier_label = [1 for _ in range(cur_input_ids.shape[0])]
            source_chars = len(source)
            for token_id, (char_start, char_end) in enumerate(
                tokenized["offset_mapping"][0].tolist()
            ):
                # ``tolist()`` yields lists, so the span is compared
                # element-wise; an empty (0, 0) span marks a special token.
                is_special = char_start == 0 and char_end == 0
                if is_special or char_start < source_chars:
                    continue
                first_char = string_label[char_start - source_chars]
                last_char = string_label[char_end - source_chars - 1]
                if first_char == 0 or last_char == 0:
                    classifier_label[token_id] = 0
            classifier_labels.append(classifier_label)

        labels = list(input_ids)
        input_ids_lens = labels_lens = [
            cur_input_ids.ne(tokenizer.pad_token_id).sum().item()
            for cur_input_ids in input_ids
        ]

        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
            attention_mask=attention_mask,
            classifier_labels=classifier_labels,
        )

    def get_math_labels(targets):
        """Return one 0/1 label per character of every target (0 = math)."""
        string_labels = []
        for target in targets:
            string_label = []
            for i, c in enumerate(target):
                followed_by_math = i + 1 < len(target) and pattern.match(
                    target[i + 1]
                )
                if pattern.match(c) or ((c == "." or c == ",") and followed_by_math):
                    string_label.append(0)
                else:
                    string_label.append(1)
            string_labels.append(string_label)
        return string_labels

    def preprocess(
        sources: Sequence[str],
        targets: Sequence[str],
        tokenizer: transformers.PreTrainedTokenizer,
        apply_src_loss,
    ) -> Dict:
        """Tokenize the pairs and mask the source unless it should be trained on."""
        math_labels = get_math_labels(targets)
        examples = [s + t for s, t in zip(sources, targets)]
        examples_tokenized = _tokenize_fn_gsm(examples, sources, math_labels, tokenizer)
        input_ids = examples_tokenized["input_ids"]
        labels = copy.deepcopy(input_ids)
        # The source prefix is excluded from the loss only when source loss
        # is NOT requested (same convention as the module-level preprocess).
        if not apply_src_loss:
            sources_tokenized = _tokenize_fn(sources, tokenizer, add_eos_token=False)
            for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
                if source_len <= 1:
                    source_len = 0
                label[:source_len] = IGNORE_INDEX
        return dict(
            input_ids=input_ids,
            labels=labels,
            classifier_labels=examples_tokenized["classifier_labels"],
        )

    prompt_input = PROMPT_DICT[data_args.inst_type]["prompt_input"]
    sources = [
        prompt_input.format_map(dict(instruction=input)) if len(input) > 0 else ""
        for input in examples["input"]
    ]
    targets = [f"{output}" for output in examples["output"]]
    data_dict = preprocess(sources, targets, tokenizer, data_args.apply_src_loss)
    return data_dict


def process_dataset_mle_aug(examples, tokenizer, data_args, model_args):
    """Encode every question once per data type and regroup by question.

    The batch is first expanded data-type-major (all questions for the first
    type, then all for the second, ...), encoded with ``process_dataset``,
    and finally regrouped so that each feature holds, per question, a list
    with one entry per data type.
    """
    data_types = DATA_TYPE_DICT[model_args.data_type]
    questions = examples["question"]
    num_examples = len(questions)

    flat = {"input": [], "output": []}
    for data_type in data_types:
        for row in range(num_examples):
            prompt = INFERENCE_PATTERN[data_type][0].replace(
                "{question}", questions[row]
            )
            answer = examples[data_type][row]
            flat["input"].append(prompt)
            flat["output"].append(answer)

    encoded = process_dataset(flat, tokenizer, data_args, model_args)

    # Entry for question `row` and type `k` sits at row + num_examples * k.
    num_types = len(data_types)
    return {
        key: [
            [column[row + num_examples * k] for k in range(num_types)]
            for row in range(num_examples)
        ]
        for key, column in encoded.items()
    }


def format_sample(question, data_type):
    """Fill ``question`` into the first inference pattern of ``data_type``."""
    template = INFERENCE_PATTERN[data_type][0]
    return template.replace("{question}", question)


def process_dataset_cl(examples, tokenizer, data_args, model_args):
    """Build grouped positive / negative features for contrastive training.

    For every question and data type a positive sample and a negative sample
    are assembled. A negative is the row's own ``{type}_negative`` output
    (reward 0) or, if that is missing, a usable output of a *different*
    randomly chosen row (reward 1).

    Two layouts exist:
        * cross-modal (several data types and not ``model_args.single_modal``):
          each group holds one sample per data type;
        * single-modal: each group holds ``num_pos`` positives (2 unless
          ``model_args.no_add_pos``) and a single negative.

    Args:
        examples: Batch with ``question``, one column per data type and the
            matching ``{type}_positive`` / ``{type}_negative`` columns.
        tokenizer: Passed through to ``process_dataset``.
        data_args: Must expose ``remove_pal_question`` plus whatever
            ``process_dataset`` reads.
        model_args: Must expose ``data_type``, ``single_modal``,
            ``no_hard_neg`` and ``no_add_pos``.

    Returns:
        Dict of grouped features; ``negative_*`` and ``reward`` entries are
        present unless ``model_args.no_hard_neg`` is set.

    Raises:
        ValueError: If a random negative is required but no other row has a
            usable output (previously this looped forever).
    """
    data_types = DATA_TYPE_DICT[model_args.data_type]
    num_questions = len(examples["question"])
    # column name -> row indices whose value is not None (filled lazily).
    usable_rows = {}

    def _strip_pal_question(output, data_type):
        """Remove the indented docstring from "pal" outputs when requested."""
        if data_args.remove_pal_question and data_type == "pal":
            return re.sub(r"    \"\"\".*?\"\"\"\n", "", output)
        return output

    def _sample_other_row(i, column):
        """Pick a random row ``j != i`` whose ``column`` value is not None."""
        rows = usable_rows.get(column)
        if rows is None:
            values = examples[column]
            rows = [j for j in range(num_questions) if values[j] is not None]
            usable_rows[column] = rows
        if not rows or rows == [i]:
            raise ValueError(
                f"cannot sample a negative for row {i}: no other row has a "
                f"value in column {column!r}"
            )
        # At least one candidate differs from i, so this terminates.
        while True:
            j = random.choice(rows)
            if j != i:
                return j

    def _group(values, width, count):
        """Split a flat list into ``count`` consecutive groups of ``width``."""
        return [
            [values[row * width + col] for col in range(width)]
            for row in range(count)
        ]

    def _preview(title, samples):
        """Print the first 50 assembled samples for manual inspection."""
        print(title)
        for source, target in zip(samples["input"][:50], samples["output"][:50]):
            print(source + target)
            print()

    def process_cross_modal():
        positives = {"input": [], "output": []}
        negatives = {"input": [], "output": []}
        neg_reward = []
        for i in range(num_questions):
            for t in data_types:
                # Positive: the row's own output, falling back to its "cot"
                # output when this data type has none.
                pos_type = t if examples[t][i] is not None else "cot"
                question = format_sample(examples["question"][i], pos_type)
                output = examples[pos_type][i]
                assert output is not None
                positives["input"].append(question)
                positives["output"].append(_strip_pal_question(output, t))

                if examples[f"{t}_negative"][i] is not None:
                    question = format_sample(examples["question"][i], t)
                    output = examples[f"{t}_negative"][i]
                    neg_reward.append(0)  # TODO fine-grained
                else:
                    j = _sample_other_row(i, f"{t}_positive")
                    question = format_sample(examples["question"][j], t)
                    output = examples[f"{t}_positive"][j]
                    neg_reward.append(1)  # TODO fine-grained
                assert output is not None
                negatives["input"].append(question)
                negatives["output"].append(_strip_pal_question(output, t))
        _preview("positive", positives)
        _preview("negatives", negatives)
        data_type_size = len(data_types)
        num_examples = len(positives["input"]) // data_type_size
        positives = process_dataset(positives, tokenizer, data_args, model_args)
        negatives = process_dataset(negatives, tokenizer, data_args, model_args)
        features = {}
        for key in positives:
            features[key] = _group(positives[key], data_type_size, num_examples)
            if not model_args.no_hard_neg:
                features[f"negative_{key}"] = _group(
                    negatives[key], data_type_size, num_examples
                )
        if not model_args.no_hard_neg:
            features["reward"] = _group(neg_reward, data_type_size, num_examples)
            assert len(features["input_ids"]) == len(features["reward"])
        return features

    def process_single_modal():
        positives = {"input": [], "output": []}
        negatives = {"input": [], "output": []}
        neg_reward = []
        num_pos = 2 if not model_args.no_add_pos else 1
        for i in range(num_questions):
            for t in data_types:
                own = examples[t][i]
                extra = examples[f"{t}_positive"][i]
                # Rows with neither an own nor an extra positive are skipped.
                if own is None and extra is None:
                    continue
                question = format_sample(examples["question"][i], t)
                output = own if own is not None else extra
                positives["input"].append(question)
                positives["output"].append(_strip_pal_question(output, t))

                if num_pos == 2:
                    # Second positive: prefer the extra one, else repeat own.
                    output = extra if extra is not None else own
                    assert output is not None
                    positives["input"].append(question)
                    positives["output"].append(_strip_pal_question(output, t))

                if examples[f"{t}_negative"][i] is not None:
                    question = format_sample(examples["question"][i], t)
                    output = examples[f"{t}_negative"][i]
                    neg_reward.append(0)  # TODO fine-grained
                else:
                    j = _sample_other_row(i, t)
                    question = format_sample(examples["question"][j], t)
                    output = examples[t][j]
                    neg_reward.append(1)
                assert output is not None
                negatives["input"].append(question)
                negatives["output"].append(_strip_pal_question(output, t))
        _preview("positive", positives)
        _preview("negatives", negatives)
        num_examples = len(positives["input"]) // num_pos
        positives = process_dataset(positives, tokenizer, data_args, model_args)
        negatives = process_dataset(negatives, tokenizer, data_args, model_args)
        features = {}
        for key in positives:
            features[key] = _group(positives[key], num_pos, num_examples)
            if not model_args.no_hard_neg:
                # Exactly one negative per group.
                features[f"negative_{key}"] = _group(negatives[key], 1, num_examples)
        if not model_args.no_hard_neg:
            features["reward"] = _group(neg_reward, 1, num_examples)
            assert len(features["input_ids"]) == len(features["reward"])
        return features

    # Output [[positives], [negatives]]
    if len(data_types) > 1 and not model_args.single_modal:
        return process_cross_modal()
    return process_single_modal()


def process_dataset_cls(examples, tokenizer, data_args, model_args):
    """Flatten all labelled outputs of a batch and encode them with rewards.

    For every question and data type the following outputs are emitted, in
    this order: the row's own output (reward 1, if present), its
    ``{type}_positive`` output(s) (reward 1) and its ``{type}_negative``
    output(s) (reward 0). Positive / negative columns may hold a single
    string or a list of strings. Labels of reward-0 samples are replaced by
    a list of ``-100`` of the same length.
    """

    def _as_outputs(value):
        # A lone string is one output, a list is several, anything else none.
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            return value
        return []

    flat = {"input": [], "output": []}
    rewards = []
    data_types = DATA_TYPE_DICT[model_args.data_type]
    for row, raw_question in enumerate(examples["question"]):
        for data_type in data_types:
            labelled = []
            own = examples[data_type][row]
            if own is not None:
                labelled.append((own, 1))
            for output in _as_outputs(examples[f"{data_type}_positive"][row]):
                labelled.append((output, 1))
            for output in _as_outputs(examples[f"{data_type}_negative"][row]):
                labelled.append((output, 0))
            for output, reward in labelled:
                flat["input"].append(format_sample(raw_question, data_type))
                flat["output"].append(output)
                rewards.append(reward)

    features = process_dataset(flat, tokenizer, data_args, model_args)
    features["reward"] = rewards
    for idx, reward in enumerate(rewards):
        if reward == 0:
            features["labels"][idx] = [-100] * len(features["labels"][idx])
    return features


def process_dataset_trans(examples, tokenizer, data_args, model_args):
    """Encode a batch and locate a repeated question span in token space.

    Prompts are rendered with the ``no_inst`` template after stripping the
    literal ``Question:`` from each input. For every example the function
    determines the token indices covering an occurrence of the bare question
    inside ``source + target`` and returns them as ``repeat_starts`` /
    ``repeat_ends`` next to ``input_ids``, masked ``labels`` and the
    tokenized questions (``question_ids``).

    NOTE(review): ``find_nth`` is defined elsewhere; it is presumably
    ``find_nth(haystack, needle, n)`` returning the character index of the
    n-th occurrence -- confirm. ``model_args`` is not used here.
    """
    prompt_input = PROMPT_DICT["no_inst"]["prompt_input"]
    # [Question]\n<wrap>[Question]<wrapp>\nCoT
    sources = [
        prompt_input.format_map(dict(instruction=e.replace("Question:", "").strip()))
        for e in examples["input"]
    ]
    # Bare question text: everything before "# solution" / "Let's".
    questions = [
        e.replace("Question:", "").split("# solution")[0].split("Let's")[0].strip()
        for e in examples["input"]
    ]
    targets = examples["output"]
    solutions = [extract_solution(t) for t in targets]
    # From here on `examples` holds the concatenated source + target strings.
    examples = [s + t for s, t in zip(sources, targets)]
    # Character positions of the question (n=2) and the solution (n=1).
    question_starts = [find_nth(e, q, 2) for q, e in zip(questions, examples)]
    target_starts = [find_nth(e, q, 1) for q, e in zip(solutions, examples)]
    question_ends = [qs + len(q) for qs, q in zip(question_starts, questions)]
    examples_tokenized = _tokenize_fn(examples, tokenizer, return_offsets_mapping=True)
    sources_tokenized = _tokenize_fn(sources, tokenizer, add_eos_token=False)
    questions_tokenized = _tokenize_fn(
        questions, tokenizer, return_offsets_mapping=False, add_eos_token=False
    )
    input_ids = examples_tokenized["input_ids"]
    max_len = max([len(input_id) for input_id in input_ids])
    labels = copy.deepcopy(input_ids)
    qstart_ids, qend_ids = [], []
    ts_token_lens = []
    for offset_mapping, qi, q, qs, qe, ts in zip(
        examples_tokenized["offset_mapping"],
        questions_tokenized["input_ids"],
        questions,
        question_starts,
        question_ends,
        target_starts,
    ):
        qstart_id, qend_id = None, None
        ts_token_len = None
        # Map each character position to the token whose span contains it.
        for token_id, string_id in enumerate(offset_mapping):
            # NOTE(review): `_tokenize_fn` builds offset_mapping via
            # `.tolist()`, so entries look like lists and would never equal
            # the tuple (0, 0); such spans are empty and cannot match the
            # range checks below anyway -- confirm.
            if string_id == (0, 0):
                continue
            # Consider padding side: with left padding the index is shifted
            # by the distance to the longest sequence in the batch.
            if qs >= string_id[0] and qs < string_id[1]:
                qstart_id = (
                    token_id
                    if tokenizer.padding_side == "right"
                    else max_len - len(offset_mapping) + token_id
                )
            if qe >= string_id[0] and qe < string_id[1]:
                qend_id = (
                    token_id
                    if tokenizer.padding_side == "right"
                    else max_len - len(offset_mapping) + token_id
                )
            if ts >= string_id[0] and ts < string_id[1]:
                ts_token_len = token_id
        # NOTE(review): the start is moved back by one token for questions
        # beginning with a numeric character; presumably the tokenizer emits
        # an extra leading token in that case -- confirm.
        if q[0].isnumeric():
            qstart_id -= 1
        qstart_ids.append(qstart_id)
        qend_ids.append(qend_id)
        ts_token_lens.append(ts_token_len)
        # The located span must line up with the stand-alone tokenization.
        assert qend_id - qstart_id == len(qi) - 1

    if data_args.apply_partial_tgt_loss:
        # Mask everything before the token that contains the solution start.
        for label, ts_token_len in zip(labels, ts_token_lens):
            if ts_token_len <= 1:
                ts_token_len = 0
            label[:ts_token_len] = IGNORE_INDEX
    elif not data_args.apply_src_loss:
        # Mask the source prefix only.
        for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
            if source_len <= 1:
                source_len = 0
            label[:source_len] = IGNORE_INDEX
    return dict(
        input_ids=input_ids,
        labels=labels,
        question_ids=questions_tokenized["input_ids"],
        repeat_starts=qstart_ids,
        repeat_ends=qend_ids,
    )


def process_dataset_rank(examples, tokenizer, data_args, model_args):
    """Encode every candidate output of each input and keep them grouped.

    Each row contributes one list of encoded candidates per feature. With
    ``model_args.rank_type`` equal to ``"cls"`` or ``"aft"`` the rewards are
    binarised (non-zero -> 1); otherwise they are passed through unchanged.
    """
    grouped = {"input_ids": [], "rewards": [], "labels": []}
    for row, prompt in enumerate(examples["input"]):
        candidates = examples["output"][row]
        encoded = process_dataset(
            {"input": [prompt] * len(candidates), "output": candidates},
            tokenizer,
            data_args,
            model_args,
        )
        if model_args.rank_type in ("cls", "aft"):
            encoded["rewards"] = [
                1 if reward != 0 else 0 for reward in examples["rewards"][row]
            ]
        else:
            encoded["rewards"] = examples["rewards"][row]
        for name, value in encoded.items():
            grouped[name].append(value)
    return grouped


def process_dataset_rank_com(examples, tokenizer, data_args, model_args):
    """Like ``process_dataset_rank`` but with additional strategy features.

    Besides the grouped candidates, each row yields its ``strategy_ids``, a
    ``strategy_mask`` per strategy (ones over the first
    ``len(strategy tokens)`` positions of the referenced candidate, zeros
    after) and ``strategy_rewards``. Rewards are binarised when
    ``rank_type`` / ``rank_com_type`` equals ``"cls"``.
    """
    grouped = {
        "input_ids": [],
        "rewards": [],
        "labels": [],
        "strategy_ids": [],
        "strategy_mask": [],
        "strategy_rewards": [],
    }
    for row, prompt in enumerate(examples["input"]):
        candidates = examples["output"][row]
        encoded = process_dataset(
            {"input": [prompt] * len(candidates), "output": candidates},
            tokenizer,
            data_args,
            model_args,
        )
        if model_args.rank_type == "cls":
            encoded["rewards"] = [
                1 if reward != 0 else 0 for reward in examples["rewards"][row]
            ]
        else:
            encoded["rewards"] = examples["rewards"][row]

        strategy_tokens = _tokenize_fn(
            examples["strategy_input"][row], tokenizer, add_eos_token=False
        )["input_ids"]
        masks = []
        for pos, prefix_ids in enumerate(strategy_tokens):
            # The mask has the length of the candidate this strategy refers to.
            candidate_ids = encoded["input_ids"][examples["strategy_ids"][row][pos]]
            mask = torch.ones(len(candidate_ids), dtype=torch.long)
            mask[len(prefix_ids) :] = 0
            masks.append(mask)

        grouped["strategy_ids"].append(examples["strategy_ids"][row])
        grouped["strategy_mask"].append(masks)
        if model_args.rank_com_type == "cls":
            binarised = [
                1 if reward != 0 else 0 for reward in examples["strategy_rewards"][row]
            ]
            grouped["strategy_rewards"].append(binarised)
        else:
            grouped["strategy_rewards"].append(examples["strategy_rewards"][row])
        for name, value in encoded.items():
            grouped[name].append(value)
    return grouped


def preprocess_choice(
    sources: Sequence[str],
    targets: Sequence[str],
    choice_learns: bool,
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    model_args,
) -> Dict:
    """Tokenize source/target pairs and build two label sets.

    ``labels`` supervise everything after the prompt and the
    ``## Choice ...`` header; ``choice_labels`` supervise only that header,
    and only for samples whose entry in ``choice_learns`` is truthy.

    Returns:
        A dict with ``input_ids``, ``labels`` and ``choice_labels``.
    """
    choice_header = re.compile(r"## Choice\n(?:Code|Chain-of-Thought)\n\n")

    def extract_choice(source):
        found = choice_header.search(source)
        return "" if found is None else found.group()

    examples = [source + target for source, target in zip(sources, targets)]
    choices = [extract_choice(target) for target in targets]

    examples_tokenized = _tokenize_fn(examples, tokenizer)
    sources_tokenized = _tokenize_fn(sources, tokenizer, add_eos_token=False)
    choice_tokenized = _tokenize_fn(
        choices, tokenizer, add_special_tokens=False, add_eos_token=False
    )

    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    choice_labels = copy.deepcopy(labels)

    per_sample = zip(
        labels,
        choice_labels,
        sources_tokenized["input_ids_lens"],
        choice_tokenized["input_ids_lens"],
        choice_learns,
        choices,
    )
    for label, choice_label, source_len, choice_len, learn_choice, choice in per_sample:
        # A source of at most one token is treated as "no prompt".
        prompt_len = source_len if source_len > 1 else 0
        # Without a choice header there is nothing extra to skip.
        header_len = choice_len if choice else 0
        boundary = prompt_len + header_len

        label[:boundary] = IGNORE_INDEX
        if learn_choice:
            # Keep only the header tokens as supervision.
            choice_label[:prompt_len] = IGNORE_INDEX
            choice_label[boundary:] = IGNORE_INDEX
        else:
            choice_label[:] = IGNORE_INDEX
    return {"input_ids": input_ids, "labels": labels, "choice_labels": choice_labels}


def process_dataset_choice(examples, tokenizer, data_args, model_args):
    """Format a batch with the instruction template and run ``preprocess_choice``.

    Empty inputs are mapped to an empty source string instead of a formatted
    prompt. Outputs are used verbatim as targets; no EOS text is appended here
    because tokenization of the full example appends the EOS token id.
    """
    template = PROMPT_DICT[data_args.inst_type]["prompt_input"]
    sources = []
    for instruction in examples["input"]:
        if len(instruction) > 0:
            sources.append(template.format_map({"instruction": instruction}))
        else:
            sources.append("")
    targets = [f"{output}" for output in examples["output"]]
    return preprocess_choice(
        sources, targets, examples["choice_learn"], tokenizer, data_args, model_args
    )


def preprocess_choice_cls(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    model_args,
) -> Dict:
    """Tokenize source/target pairs and build labels plus a prefix input mask.

    Returns:
        A dict with ``input_ids``, ``labels`` (prompt positions set to
        ``IGNORE_INDEX``) and ``input_mask`` (1 for the first ``choice_len``
        positions of each sequence, 0 afterwards).
    """
    cot_marker = "## Chain-of-Thought Solution"
    code_marker = "## Code Solution"

    def extract_choice(source):
        # Everything up to and including the solution heading; if neither
        # heading is present the code heading is appended to the whole text.
        marker = cot_marker if cot_marker in source else code_marker
        return source.partition(marker)[0] + marker

    examples = [source + target for source, target in zip(sources, targets)]
    choices = [extract_choice(target) for target in targets]

    examples_tokenized = _tokenize_fn(examples, tokenizer)
    sources_tokenized = _tokenize_fn(sources, tokenizer, add_eos_token=False)
    choice_tokenized = _tokenize_fn(choices, tokenizer, add_eos_token=False)

    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    input_mask = []
    for label, source_len, choice_len in zip(
        labels,
        sources_tokenized["input_ids_lens"],
        choice_tokenized["input_ids_lens"],
    ):
        # A source of at most one token is treated as "no prompt".
        if source_len > 1:
            label[:source_len] = IGNORE_INDEX
        # NOTE(review): choice_len is measured on the target prefix alone, so
        # the mask boundary does not include the source length - confirm this
        # is intended.
        prefix_mask = torch.zeros(len(label), dtype=torch.long)
        prefix_mask[:choice_len] = 1
        input_mask.append(prefix_mask)
    return {"input_ids": input_ids, "labels": labels, "input_mask": input_mask}


def process_dataset_choice_cls(examples, tokenizer, data_args, model_args):
    """Prepare a batch for choice classification.

    Both plan headings ("## Plan for Code Solution" and
    "## Plan for Chain-of-Thought Solution") are rewritten to the neutral
    "## Plan" in prompts and targets before tokenization. The per-sample
    ``examples["reward"]`` values are attached as ``choice_rewards``.
    """

    def _neutral_plan_heading(text):
        return text.replace("## Plan for Code Solution", "## Plan").replace(
            "## Plan for Chain-of-Thought Solution", "## Plan"
        )

    template = PROMPT_DICT[data_args.inst_type]["prompt_input"]
    sources = []
    for instruction in examples["input"]:
        if len(instruction) > 0:
            formatted = template.format_map({"instruction": instruction})
            sources.append(_neutral_plan_heading(formatted))
        else:
            sources.append("")
    targets = [_neutral_plan_heading(f"{output}") for output in examples["output"]]

    data_dict = preprocess_choice_cls(
        sources, targets, tokenizer, data_args, model_args
    )
    data_dict["choice_rewards"] = examples["reward"]
    return data_dict


def find_all_token_indices_with_offsets(input_ids, substring, tokenizer):
    """Locate every occurrence of ``substring``'s tokens inside ``input_ids``.

    Args:
        input_ids: Token ids of the full text.
        substring: Text to search for.
        tokenizer: Object providing ``decode``, ``convert_ids_to_tokens``,
            ``tokenize`` and ``batch_decode``.

    Returns:
        A list of ``(start, end)`` token index pairs (end exclusive).

    Raises:
        ValueError: If the token sequence does not occur in ``input_ids``.
    """
    # When the substring is preceded by a newline in the decoded text, it is
    # tokenized together with that newline so the pieces match in context.
    preceded_by_newline = f"\n{substring}" in tokenizer.decode(input_ids)
    if preceded_by_newline:
        substring = f"\n{substring}"

    haystack = tokenizer.convert_ids_to_tokens(input_ids)
    needle = tokenizer.tokenize(substring)
    if preceded_by_newline:
        # Presumably drops the tokens produced for the prepended newline
        # (two for the tokenizer in use) - TODO confirm.
        needle = needle[2:]

    width = len(needle)
    index_tuples = [
        (begin, begin + width)
        for begin in range(len(haystack) - width + 1)
        if haystack[begin : begin + width] == needle
    ]
    if index_tuples:
        return index_tuples

    # Dump the inputs to stdout to help diagnose the mismatch before failing.
    print(haystack)
    print(needle)
    print(tokenizer.batch_decode([input_ids]), substring)
    raise ValueError("Substring token sequence not found in original text.")


def get_token_mask(
    model_args,
    input_ids,
    input_tokens,
    sample_substrings,
    offset_mappings,
    mask_value,
    output_mask_value=None,
):
    """Build one per-token mask per sample, marking the given substrings.

    Args:
        model_args: Provides ``rg_multiple_ok`` (forwarded to
            ``get_token_positions``).
        input_ids: Per-sample token id sequences; only their lengths are used.
        input_tokens: Per-sample full text in which substrings are searched.
        sample_substrings: Per-sample list of substrings to mark.
        offset_mappings: Per-sample token offset mappings.
        mask_value: Value written over matched tokens. A ``bool`` yields a
            boolean mask; anything else a default-dtype tensor.
        output_mask_value: If given, used instead of ``mask_value`` for the
            last substring of each sample.

    Returns:
        A list with one tensor per sample.
    """
    boolean_mask = isinstance(mask_value, bool)
    token_mask = []
    for ids, text, substrings, offsets in zip(
        input_ids, input_tokens, sample_substrings, offset_mappings
    ):
        if boolean_mask:
            mask = torch.zeros(len(ids), dtype=torch.bool)
        else:
            mask = torch.zeros(len(ids))

        last_pos = len(substrings) - 1
        for pos, substring in enumerate(substrings):
            use_output_value = output_mask_value is not None and pos == last_pos
            fill = output_mask_value if use_output_value else mask_value
            spans = get_token_positions(
                text,
                substring,
                offsets,
                multiple_ok=model_args.rg_multiple_ok,
            )
            for span in spans:
                mask[span[0] : span[1]] = fill
        token_mask.append(mask)
    return token_mask


def _solutions_from_inputs(input_tokens):
    """Return the stripped text after the last "## Solution" marker of each input."""
    return [text.split("## Solution")[-1].strip() for text in input_tokens]


def get_qk_token_mask(
    model_args,
    input_ids,
    questions,
    input_tokens,
    input_tokenized,
    is_train=False,
    solutions=None,
):
    """Build the key-side and query-side token masks for attention amplification.

    Args:
        model_args: Provides ``amplify_k_scope``, ``amplify_q_scope``,
            ``amplify_factor`` and, for the ``question_solution`` key scope
            during training, ``amplify_output_factor``.
        input_ids: Per-sample token id sequences.
        questions: Per-sample question text.
        input_tokens: Per-sample full input text.
        input_tokenized: Tokenizer output containing ``"offset_mapping"``
            (only read when a substring-based scope is used).
        is_train: Whether solutions are part of the input. Outside training,
            substring-based scopes only ever mark the question.
        solutions: Optional per-sample solution text for the key mask; when
            omitted it is taken from the text after "## Solution". The query
            mask always derives the solution from ``input_tokens``.

    Returns:
        ``(k_token_mask, q_token_mask)``. The ``"all"``/``"solution"``
        shortcuts return nested Python lists; the substring-based scopes
        return the list of tensors produced by ``get_token_mask``.

    Raises:
        ValueError: If a scope value is not supported. Previously this either
            raised ``UnboundLocalError`` or silently reused the substrings
            built for the key mask when building the query mask.
    """
    output_factor = None
    k_scope = model_args.amplify_k_scope
    if k_scope == "all":
        k_token_mask = [
            [model_args.amplify_factor] * len(cur_input_ids)
            for cur_input_ids in input_ids
        ]
    elif k_scope == "solution":
        k_token_mask = [[0.0] * len(cur_input_ids) for cur_input_ids in input_ids]
    else:
        if k_scope == "question" or not is_train:
            substrings = [[question] for question in questions]
        elif k_scope == "question_solution":
            if solutions is None:
                solutions = _solutions_from_inputs(input_tokens)
            substrings = [
                [question, solution] for question, solution in zip(questions, solutions)
            ]
            # The solution (last substring) gets its own amplification factor.
            output_factor = model_args.amplify_output_factor
        else:
            raise ValueError(f"Unsupported amplify_k_scope: {k_scope!r}")
        k_token_mask = get_token_mask(
            model_args,
            input_ids,
            input_tokens,
            substrings,
            offset_mappings=input_tokenized["offset_mapping"],
            mask_value=model_args.amplify_factor,
            output_mask_value=output_factor,
        )

    q_scope = model_args.amplify_q_scope
    if q_scope == "all":
        q_token_mask = [[1] * len(cur_input_ids) for cur_input_ids in input_ids]
    elif q_scope == "solution" and not is_train:
        # No solution tokens exist in the input outside training.
        q_token_mask = [[0] * len(cur_input_ids) for cur_input_ids in input_ids]
    else:
        if q_scope == "question" or (q_scope == "question_solution" and not is_train):
            substrings = [[question] for question in questions]
        elif q_scope == "solution":
            # Only reachable with is_train=True (handled above otherwise).
            substrings = [
                [solution] for solution in _solutions_from_inputs(input_tokens)
            ]
        elif q_scope == "question_solution":
            substrings = [
                [question, solution]
                for question, solution in zip(
                    questions, _solutions_from_inputs(input_tokens)
                )
            ]
        else:
            raise ValueError(f"Unsupported amplify_q_scope: {q_scope!r}")
        q_token_mask = get_token_mask(
            model_args,
            input_ids,
            input_tokens,
            substrings,
            offset_mappings=input_tokenized["offset_mapping"],
            mask_value=True,
        )

    return k_token_mask, q_token_mask


def construct_few_shot_prompt(
    data_args, tokenizer, examples, data_names, is_chat=False
):
    """Assemble few-shot prompts (demonstrations + question) for a batch.

    Args:
        data_args: Provides ``prompt_format``, ``inst_type`` and ``data_path``.
        tokenizer: Unused by this function.
        examples: Batch with ``"question"`` (and ``"task_name"`` for bbh).
        data_names: Dataset name of each sample.
        is_chat: When true, the demonstrations and question are wrapped in the
            instruction template before the answer opening is appended.

    Returns:
        A list with one prompt string per sample.
    """

    def render_question(question):
        return INFERENCE_PATTERN[data_args.prompt_format][0].replace(
            "{question}", question
        )

    def answer_opening():
        # Answer pattern with an empty chain of thought; an emptied python
        # code fence is reduced to just its opening line.
        return (
            INFERENCE_PATTERN[data_args.prompt_format][1]
            .replace("{chain_of_thought}", "")
            .replace("```python\n\n```", "```python\n")
        )

    def with_trailing_newline(text):
        return text if text.endswith("\n") else text + "\n"

    # NOTE(review): set ordering is arbitrary, so with mixed datasets the
    # "first" name checked below is not well defined - confirm batches are
    # homogeneous when bbh is involved.
    unique_names = list(set(data_names))
    questions = examples["question"]
    sources = [render_question(question) for question in questions]
    demo_dicts = load_few_shot_demos(unique_names, data_args.data_path)
    chat_template = PROMPT_DICT[data_args.inst_type]["prompt_input"]

    if unique_names[0] == "bbh":
        cot_cue = "\nA: Let's think step by step.\n"
        prompts = []
        for question, task_name in zip(questions, examples["task_name"]):
            demos = demo_dicts["bbh"][task_name].strip()
            if is_chat:
                instruction = (
                    demos
                    + "\n\nFollow the above examples and answer the question\n\nQ: "
                    + question
                )
                prompts.append(
                    chat_template.format_map({"instruction": instruction}) + cot_cue
                )
            else:
                prompts.append(demos + "\n\nQ: " + question + cot_cue)
        return prompts
    # TODO: mmlu has no dedicated handling yet and uses the generic path below.

    # TODO pure no inst
    separator = "\n\n"
    demo_blocks = {}
    for data_name in unique_names:
        rendered = []
        for demo in demo_dicts[data_name]:
            demo_prompt = with_trailing_newline(render_question(demo["input"]))
            demo_answer = INFERENCE_PATTERN[data_args.prompt_format][1].replace(
                "{chain_of_thought}", demo["output"]
            )
            rendered.append(demo_prompt + demo_answer)
        demo_blocks[data_name] = separator.join(rendered)

    prompts = []
    for source, data_name in zip(sources, data_names):
        body = demo_blocks[data_name] + separator + with_trailing_newline(source)
        if is_chat:
            body = chat_template.format_map({"instruction": body})
        prompts.append(body + answer_opening())
    return prompts


def tokenize_attn_amplify(tokenizer, sources, questions, model_args, no_encode=False):
    """Tokenize prompts and attach attention-amplification masks.

    Returns:
        ``{"prompt": sources}`` when ``no_encode`` is true; only
        ``input_ids`` when ``model_args.amplify_factor`` is 0.0; otherwise
        ``input_ids`` with ``k_token_mask``/``q_token_mask`` and, if any of
        the threshold/exclude-self/top-k options is enabled, ``tgt_index``
        holding the prompt lengths.
    """
    if no_encode:
        return {"prompt": sources}

    encoded = _tokenize_fn(
        sources, tokenizer, add_eos_token=False, return_offsets_mapping=True
    )
    input_ids = encoded["input_ids"]
    prompt_lens = encoded["input_ids_lens"]
    if model_args.amplify_factor == 0.0:
        return {"input_ids": input_ids}

    k_token_mask, q_token_mask = get_qk_token_mask(
        model_args, input_ids, questions, sources, encoded
    )
    data_dict = {
        "input_ids": input_ids,
        "k_token_mask": k_token_mask,
        "q_token_mask": q_token_mask,
    }
    wants_tgt_index = (
        model_args.amplify_total_threshold_output
        or model_args.amplify_exclude_self
        or model_args.amplify_total_threshold_upper
        or model_args.amplify_topk
    )
    if wants_tgt_index:
        data_dict["tgt_index"] = prompt_lens
    return data_dict


def process_dataset_attn_amplify(
    examples, tokenizer, data_args, model_args, no_encode=False
):
    """Build inference prompts for a batch and tokenize them with amplify masks.

    Args:
        examples: Batch with ``"question"`` and optionally ``"type"`` (the
            dataset name of each sample; required when
            ``model_args.few_shot`` is set).
        tokenizer: Tokenizer forwarded to ``tokenize_attn_amplify``.
        data_args: Provides ``inst_type`` and ``prompt_format``.
        model_args: Provides ``few_shot``, ``model_name_or_path`` and the
            amplification options.
        no_encode: When true, return the raw prompts without tokenizing.

    Returns:
        The dict produced by ``tokenize_attn_amplify``.
    """
    # Dataset types answered with the generic "reasoning" system prompt
    # instead of the math one.
    reasoning_types = (
        "csqa",
        "csqa_dev",
        "strategyqa",
        "strategyqa_dev",
        "coin_flip",
        "last_letters",
        "bbh",
        "bbh_symbol",
        "bbh_symbol_dev",
        "arcc",
        "arce",
    )
    prompt_input = PROMPT_DICT[data_args.inst_type]["prompt_input"]
    questions = examples["question"]
    has_type = "type" in examples
    if model_args.few_shot:
        # Only for inference
        assert "question" in examples
        assert "type" in examples
        model_name = model_args.model_name_or_path.lower()
        sources = construct_few_shot_prompt(
            data_args,
            tokenizer,
            examples,
            examples["type"],
            is_chat="chat" in model_name or "instruct" in model_name,
        )
    else:
        sources = [
            INFERENCE_PATTERN[data_args.prompt_format][0].replace(
                "{question}", question
            )
            for question in questions
        ]
        if has_type:
            reasoning_prompt = prompt_input.replace(
                "You are an expert for math problem solving.",
                "You are an expert for reasoning.",
            )
            prompt_inputs = [
                reasoning_prompt if t in reasoning_types else prompt_input
                for t in examples["type"]
            ]
            sources = [
                pi.format_map(dict(instruction=source))
                for source, pi in zip(sources, prompt_inputs)
            ]
        else:
            sources = [
                prompt_input.format_map(dict(instruction=source)) for source in sources
            ]
    # humaneval samples are prompted with the bare question. This override
    # needs the per-sample type, so it is skipped for batches without "type"
    # (previously such batches raised KeyError here even though the branch
    # above explicitly supports them).
    if has_type:
        sources = [
            source if data_name != "humaneval" else question
            for source, question, data_name in zip(
                sources, questions, examples["type"]
            )
        ]
    # The remaining steps (optional no_encode shortcut, tokenization, masks,
    # tgt_index) are exactly what tokenize_attn_amplify implements.
    return tokenize_attn_amplify(
        tokenizer, sources, questions, model_args, no_encode=no_encode
    )


def process_dataset_attn_amplify_train(
    examples, tokenizer, data_args, model_args, add_eos_token=True
):
    """Build training inputs, labels and amplify masks for a batch.

    Questions come from ``examples["question"]`` when present, otherwise from
    ``examples["input"]`` with the question headings stripped. Targets come
    from the first entry of each ``examples["generation"]`` item when present
    (trimmed to the text after "[/INST]" or "## Solution", decided by the
    first generation), otherwise from ``examples["output"]``.

    Returns:
        A dict with ``input_ids`` and ``labels``; unless
        ``model_args.amplify_factor`` is 0.0 also ``k_token_mask`` and
        ``q_token_mask``, plus ``tgt_index`` (prompt lengths) when any of the
        threshold/exclude-self/top-k options is enabled.
    """
    template = PROMPT_DICT[data_args.inst_type]["prompt_input"]

    if "question" in examples:
        questions = examples["question"]
        raw_sources = [
            INFERENCE_PATTERN[data_args.prompt_format][0].replace(
                "{question}", question
            )
            for question in questions
        ]
    else:
        raw_sources = examples["input"]
        questions = [
            text.replace("## Question", "").replace("Question:", "").strip()
            for text in raw_sources
        ]
    sources = [
        template.format_map({"instruction": text}) if len(text) > 0 else ""
        for text in raw_sources
    ]

    if "generation" not in examples:
        targets = [f"{output}" for output in examples["output"]]
    else:
        generations = [candidates[0] for candidates in examples["generation"]]
        first = generations[0]
        if "[/INST]" in first:
            # Keep a single leading space between the prompt and the answer.
            targets = [" " + g.split("[/INST]")[-1].strip() for g in generations]
        elif "## Solution" in first:
            targets = [g.split("## Solution")[-1].strip() for g in generations]
        else:
            targets = generations

    input_tokens = [source + target for source, target in zip(sources, targets)]
    input_tokenized = _tokenize_fn(
        input_tokens,
        tokenizer,
        return_offsets_mapping=True,
        add_eos_token=add_eos_token,
    )
    source_lens = _tokenize_fn(sources, tokenizer, add_eos_token=False)[
        "input_ids_lens"
    ]
    input_ids = input_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, source_lens):
        # A source of at most one token is treated as "no prompt".
        if source_len > 1:
            label[:source_len] = IGNORE_INDEX

    if model_args.amplify_factor == 0.0:
        return {"input_ids": input_ids, "labels": labels}

    k_token_mask, q_token_mask = get_qk_token_mask(
        model_args,
        input_ids,
        questions,
        input_tokens,
        input_tokenized,
        is_train=True,
        solutions=targets,
    )
    data_dict = {
        "input_ids": input_ids,
        "k_token_mask": k_token_mask,
        "q_token_mask": q_token_mask,
        "labels": labels,
    }
    wants_tgt_index = (
        model_args.amplify_total_threshold_output
        or model_args.amplify_exclude_self
        or model_args.amplify_total_threshold_upper
        or model_args.amplify_topk
    )
    if wants_tgt_index:
        data_dict["tgt_index"] = source_lens
    return data_dict


def get_input_id_indices(offset_mapping, start_end_pairs):
    """Translate character spans into token index pairs.

    Args:
        offset_mapping: Per-token ``(char_start, char_end)`` offsets.
        start_end_pairs: Character spans ``(start, end)`` to translate.

    Returns:
        One ``(start_token, end_token)`` tuple per span. ``start_token`` is
        the token whose offsets contain ``start`` (``None`` if no token does);
        ``end_token`` is the token whose offsets contain ``end``, or
        ``len(offset_mapping)`` if no token does.
    """
    token_indices_pairs = []
    total_tokens = len(offset_mapping)
    for char_start, char_end in start_end_pairs:
        first_token = None
        stop_token = None
        for token_idx, span in enumerate(offset_mapping):
            span_lo, span_hi = span[0], span[1]
            if span_lo <= char_start < span_hi:
                first_token = token_idx
            if span_lo <= char_end < span_hi:
                stop_token = token_idx
            # Stop scanning as soon as both ends have been located.
            if not (first_token is None or stop_token is None):
                break
        if stop_token is None:
            # The span runs to (or past) the end of the tokenized text.
            stop_token = total_tokens
        token_indices_pairs.append((first_token, stop_token))
    return token_indices_pairs


def get_token_positions(input_tokens, step_text, offset_mapping, multiple_ok=False):
    """Find the token spans covering ``step_text`` inside ``input_tokens``.

    Args:
        input_tokens: Full input text (a string, despite the name).
        step_text: Substring to locate.
        offset_mapping: Per-token character offsets of ``input_tokens``.
        multiple_ok: When true, return a span for every non-overlapping
            occurrence (possibly none); otherwise exactly the first
            occurrence is required.

    Returns:
        A list of ``(start_token, end_token)`` tuples as produced by
        ``get_input_id_indices``.

    Raises:
        AssertionError: If ``multiple_ok`` is false and ``step_text`` does not
            occur in ``input_tokens``.
    """

    def find_all_substrings(input_string, substring):
        # An empty substring matches at every position without consuming any
        # characters; the previous scan never advanced in that case and
        # looped forever. Treat it as "no occurrences".
        if not substring:
            return []
        pairs = []
        start = input_string.find(substring)
        while start != -1:
            end = start + len(substring)
            pairs.append([start, end])
            # Continue after the match so occurrences never overlap.
            start = input_string.find(substring, end)
        return pairs

    if multiple_ok:
        pairs = find_all_substrings(input_tokens, step_text)
        if not pairs:
            return []
    else:
        start_idx = input_tokens.find(step_text)
        assert start_idx != -1, f"substring not found in input: {step_text!r}"
        pairs = [(start_idx, start_idx + len(step_text))]
    return get_input_id_indices(offset_mapping, pairs)


def process_dataset_generation(examples, tokenizer, data_args, model_args):
    """Turn model generations into supervised (prompt, target) training data.

    The first entry of each ``examples["generation"]`` item is used as the
    raw generation. Based on the first generation of the batch, targets are
    the stripped text after the last "## Solution" marker, else the stripped
    text after the last "[/INST]" marker, else the raw generation.

    Returns:
        The dict produced by ``preprocess``.
    """
    prompt_input = PROMPT_DICT[data_args.inst_type]["prompt_input"]
    sources = [
        prompt_input.format_map(dict(instruction=question))
        if len(question) > 0
        else ""
        for question in examples["question"]
    ]
    generations = [g[0] for g in examples["generation"]]
    # These cases must be a single if/elif/else chain: with two independent
    # `if` statements the "[/INST]" result was always overwritten by the
    # trailing `else`. "## Solution" keeps precedence when both markers occur,
    # matching the previous outcome for that case.
    if "## Solution" in generations[0]:
        targets = [g.split("## Solution")[-1].strip() for g in generations]
    elif "[/INST]" in generations[0]:
        targets = [g.split("[/INST]")[-1].strip() for g in generations]
    else:
        targets = generations
    data_dict = preprocess(sources, targets, tokenizer, data_args, model_args)
    return data_dict


def process_dataset_reasoning_graph(examples, tokenizer, data_args, model_args):
    """Tokenize a batch and map its reasoning-graph points to token spans.

    For each sample, every non-``None`` entry of ``examples["points"]`` is
    located in the full input text and annotated in place with a
    ``"token_position"`` list. Steps that reference other points then yield a
    pair of (reference token spans, step token span).

    Returns:
        A dict with ``input_ids``, ``labels`` (prompt positions set to
        ``IGNORE_INDEX``), ``all_reference_tuples`` and ``all_step_tuples``
        (one list per sample, aligned with each other), and
        ``amplify_factor`` repeated once per sample.
    """
    prompt_input = PROMPT_DICT[data_args.inst_type]["prompt_input"]
    sources = [
        prompt_input.format_map(dict(instruction=input)) if len(input) > 0 else ""
        for input in examples["input"]
    ]
    targets = [f"{output}" for output in examples["output"]]
    # Bare question text; "input_position" offsets below index into this.
    questions = [e.replace("## Question", "").strip() for e in examples["input"]]
    input_tokens = [s + t for s, t in zip(sources, targets)]
    input_tokenized = _tokenize_fn(input_tokens, tokenizer, return_offsets_mapping=True)
    sources_tokenized = _tokenize_fn(sources, tokenizer, add_eos_token=False)
    source_lens = sources_tokenized["input_ids_lens"]
    input_ids = input_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, source_lens):
        # A source of at most one token is treated as "no prompt".
        if source_len <= 1:
            source_len = 0
        label[:source_len] = IGNORE_INDEX

    all_reference_tuples = []
    all_step_tuples = []
    for i, (
        cur_input_ids,
        cur_input_tokens,
        points,
        question,
        offset_mapping,
    ) in enumerate(
        zip(
            input_ids,
            input_tokens,
            examples["points"],
            questions,
            input_tokenized["offset_mapping"],
        )
    ):
        # First pass: locate every point in the input text and store its token
        # span(s) on the point itself (this mutates examples["points"]).
        for step_id in points:
            if points[step_id] is None:
                continue
            input_position = points[step_id]["input_position"]
            # Points with an input_position are slices of the question;
            # all others are located by their own text.
            step_text = (
                points[step_id]["text"].strip()
                if input_position is None
                else question[input_position[0] : input_position[1]].strip()
            )
            step_index_tuples = get_token_positions(
                cur_input_tokens, step_text, offset_mapping, model_args.rg_multiple_ok
            )
            points[step_id]["token_position"] = step_index_tuples
        # Second pass: collect, per step, the token spans of the points it
        # references. Relies on "token_position" set in the first pass.
        sample_reference_tuples = []
        sample_step_tuples = []
        for step_id in points:
            if points[step_id] is None:
                continue
            reference_index_tuples = []
            question_id = examples["last_question_id"][i]
            for reference_id in points[step_id]["reference_id"]:
                # Optionally keep only references up to the last question point.
                if model_args.rg_only_question_graph and reference_id > question_id:
                    continue
                # Points are looked up by the string form of the reference id.
                reference_index_tuples.extend(
                    points[str(reference_id)]["token_position"]
                )
            # Optionally make every step that already has references also
            # reference the last question point.
            if model_args.rg_force_ask and len(reference_index_tuples) > 0:
                question_id = examples["last_question_id"][i]
                if question_id not in points[step_id]["reference_id"]:
                    reference_index_tuples.extend(
                        points[str(question_id)]["token_position"]
                    )
            # Steps without any references are dropped from both outputs.
            if len(reference_index_tuples) > 0:
                sample_reference_tuples.append(reference_index_tuples)
                # TODO maybe can for
                # Only the first occurrence of the step text is used.
                step_tuple = points[step_id]["token_position"][0]
                if model_args.rg_shift_mask:
                    # Shift the step span one token to the left.
                    step_tuple = [step_tuple[0] - 1, step_tuple[1] - 1]
                sample_step_tuples.append(step_tuple)
        all_reference_tuples.append(sample_reference_tuples)
        all_step_tuples.append(sample_step_tuples)

    return dict(
        input_ids=input_ids,
        labels=labels,
        all_reference_tuples=all_reference_tuples,
        all_step_tuples=all_step_tuples,
        amplify_factor=[model_args.amplify_factor] * len(input_ids),
    )


def process_dataset_grad(examples, tokenizer, data_args, model_args):
    """Build training inputs, labels and a question+solution token mask.

    Questions come from ``examples["question"]`` when present, otherwise from
    ``examples["input"]`` with the "## Question" heading stripped. The boolean
    ``token_mask`` marks the tokens of the question and of the text following
    the last "## Solution" marker in each full input.

    Returns:
        A dict with ``input_ids``, ``token_mask`` and ``labels``.
    """
    template = PROMPT_DICT[data_args.inst_type]["prompt_input"]

    if "question" in examples:
        questions = examples["question"]
        raw_sources = [
            INFERENCE_PATTERN[data_args.prompt_format][0].replace(
                "{question}", question
            )
            for question in questions
        ]
    else:
        raw_sources = examples["input"]
        questions = [text.replace("## Question", "").strip() for text in raw_sources]
    sources = [
        template.format_map({"instruction": text}) if len(text) > 0 else ""
        for text in raw_sources
    ]
    targets = [f"{output}" for output in examples["output"]]
    input_tokens = [source + target for source, target in zip(sources, targets)]

    input_tokenized = _tokenize_fn(input_tokens, tokenizer, return_offsets_mapping=True)
    source_lens = _tokenize_fn(sources, tokenizer, add_eos_token=False)[
        "input_ids_lens"
    ]
    input_ids = input_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, source_lens):
        # A source of at most one token is treated as "no prompt".
        if source_len > 1:
            label[:source_len] = IGNORE_INDEX

    substrings = [
        [question, text.split("## Solution")[-1].strip()]
        for question, text in zip(questions, input_tokens)
    ]
    token_mask = get_token_mask(
        model_args,
        input_ids,
        input_tokens,
        substrings,
        offset_mappings=input_tokenized["offset_mapping"],
        mask_value=True,
    )
    return {"input_ids": input_ids, "token_mask": token_mask, "labels": labels}


def process_dataset_repeat(examples, tokenizer, data_args, model_args):
    """Tokenize inference prompts and flag the question span in the mask.

    Returns:
        A dict with ``input_ids`` and ``attention_mask``. Each mask is 1 for
        non-pad tokens and 0 for pad tokens, except that the tokens covering
        the first occurrence of the question text are set to -1.
    """
    template = PROMPT_DICT[data_args.inst_type]["prompt_input"]
    questions = examples["question"]
    rendered = [
        INFERENCE_PATTERN[data_args.prompt_format][0].replace("{question}", question)
        for question in questions
    ]
    sources = [template.format_map({"instruction": text}) for text in rendered]

    sources_tokenized = _tokenize_fn(
        sources, tokenizer, add_eos_token=False, return_offsets_mapping=True
    )
    input_ids = sources_tokenized["input_ids"]

    attention_mask = []
    for prompt_ids, prompt_text, question, offsets in zip(
        input_ids, sources, questions, sources_tokenized["offset_mapping"]
    ):
        mask = prompt_ids.ne(tokenizer.pad_token_id).to(torch.long)
        question_span = get_token_positions(
            prompt_text, question, offsets, multiple_ok=False
        )[0]
        # -1 distinguishes the question tokens from ordinary attended tokens.
        mask[question_span[0] : question_span[1]] = -1
        attention_mask.append(mask)

    return {"input_ids": input_ids, "attention_mask": attention_mask}
