import json
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
import random

import torch
import torch.distributed as dist
import transformers
import datasets
from transformers import (
    LlamaForCausalLM,
    Trainer,
    set_seed,
    AutoConfig,
    AutoModelForCausalLM,
    MistralForCausalLM,
)
from datasets import disable_caching

disable_caching()

from src.utils.file_utils import load_jsonl, load_jsonl_ml
from src.data.process import (
    process_dataset,
    process_dataset_cringe,
    process_dataset_cringe_gsm,
    process_dataset_mle_aug,
    process_dataset_cl,
    process_dataset_trans,
    process_dataset_cls,
    process_dataset_rank,
    process_dataset_rank_com,
    process_dataset_choice,
    process_dataset_choice_cls,
    process_dataset_reasoning_graph,
    process_dataset_attn_amplify_train,
    process_dataset_generation,
    process_dataset_grad,
)
from src.data.collator import (
    DataCollatorForSupervisedDataset,
    DataCollatorForSupervisedDatasetCls,
    DataCollatorForSupervisedDatasetCringe,
    DataCollatorForSupervisedDatasetMLEAug,
    DataCollatorForSupervisedDatasetCL,
    DataCollatorForSupervisedDatasetQTrans,
    DataCollatorForSupervisedDatasetRank,
    DataCollatorForSupervisedDatasetRankCom,
    DataCollatorForSupervisedDatasetChoice,
    DataCollatorForSupervisedDatasetChoiceCls,
    DataCollatorForSupervisedDatasetReasonGraph,
    DataCollatorForSupervisedDatasetAttn,
    DataCollatorForSupervisedDatasetGrad,
)
from src.model.cringe_llama import CringeLlamaForCausalLM
from src.model.modeling_mle_aug import MLEAugLlamaForCausalLM
from src.model.modeling_gold import GoldLlamaForCausalLM
from src.model.modeling_cl import ContraLlamaForCausalLM
from src.model.modeling_cls import ClsLlamaForCausalLM
from src.model.modeling_qtrans_llama import (
    QTransLlamaForCausalLM,
    QTransBaseLlamaForCausalLM,
)
from src.model.modeling_rank import AlignmentLlamaForCausalLM
from src.model.modeling_rank_com import AlignmentComLlamaForCausalLM
from src.model.modeling_choice import ChoiceLlamaForCausalLM
from src.model.modeling_choice_cls import LlamaForCausalLMChoiceCls
from src.model.reason_graph_trainer import ReasonGraphTrainer
from src.model.attn_amplify_trainer import AttnAmplifyTrainer
from src.model.grad_trainer import GradTrainer
from src.model.modeling_llama_attn import LlamaForCausalLMAttn
from src.model.modeling_llama_uncert_attn import LlamaForCausalLMUncertAttn
from src.model.modeling_llama_adapt_mask import LlamaForCausalLMAdaptMask
from src.model.modeling_llama_grad import LlamaForCausalLMGrad
from src.model.configuration_llama_attn import LlamaConfigAttn
from src.model.configuration_llama_adapt_mask import LlamaConfigAdaptMask
from src.model.concrete_gate import ConcreteGate

# from src.model.modeling_mistral import MistralForCausalLM
from src.model.modeling_mistral_attn import MistralForCausalLMAttn
from src.model.modeling_mistral_uncert_attn import MistralForCausalLMUncertAttn
from src.model.configuration_mistral_attn import MistralConfigAttn
from src.model.phi.modeling_phi import PhiForCausalLM
from src.utils.train_utils import (
    GumbelTempCallback,
    smart_tokenizer_and_embedding_resize,
    NoShuffleSeq2SeqTrainer,
)

# Default special-token strings. None of these constants is referenced
# elsewhere in this file as shown: get_model_tokenizer() pads with the
# tokenizer's own unk/eos token instead of DEFAULT_PAD_TOKEN.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
# PROMPT_DICT = {
#     # "prompt_input": (
#     #     "Below is an instruction that describes a task, paired with an input that provides further context. "
#     #     "Write a response that appropriately completes the request.\n\n"
#     #     "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
#     # ),
#     # "prompt_no_input": (
#     #     "Below is an instruction that describes a task. "
#     #     "Write a response that appropriately completes the request.\n\n"
#     #     "### Instruction:\n{instruction}\n\n### Response:"
#     # ),
#     "prompt_input": ("{instruction}\n\n### Response:"),
#     "prompt_no_input": ("{instruction}\n\n### Response:"),
# }


@dataclass
class ModelArguments:
    """Model-variant switches and hyper-parameters parsed from the command line.

    The fields are read by ``get_model_tokenizer()``,
    ``make_supervised_data_module()`` and ``train()`` in this file. The whole
    object is also handed to the dataset processing functions (through
    ``fn_kwargs``) and to several custom model constructors (``model_args=``),
    so fields that are not read here may still be consumed there.

    Every field has a default, so the class can be instantiated without
    arguments.
    """

    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    tokenizer_path: str = field(default=None)
    cringe_type: str = field(default=None)  # vanilla, gsm
    cringe_add_reduce_loss: bool = field(default=False)
    mle_aug: str = field(default=None)  # None, cot_pal, cot_pal_expr
    mle_aug_norm: bool = field(default=False)
    gold: bool = field(default=False)
    gold_alpha: float = field(default=0.5)
    gold_beta: float = field(default=0.05)
    cl: bool = field(default=False)
    cl_data_type: str = field(default=None)
    cl_weight: float = field(default=0.0)
    rank_weight: float = field(default=0.0)
    cl_length_penalty: int = field(default=1)
    pooler_type: str = field(default="mean_no_mlp")  # mean mean_no_mlp last
    use_bn: bool = field(default=False)
    temp: float = field(default=0.05)
    rank_margin: float = field(default=0.0)
    convert_bf: bool = field(default=False)
    single_modal: bool = field(default=False)
    no_hard_neg: bool = field(default=False)
    no_add_pos: bool = field(default=False)
    trans_question: bool = field(default=False)
    trans_hard: bool = field(default=False)
    trans_self: bool = field(default=False)
    trans_freeze: bool = field(default=False)
    trans_sep_lm_head: bool = field(default=False)
    trans_softmax: bool = field(default=False)
    trans_model_path: str = field(default=None)
    # copy_question: bool = field(default=False)
    do_cls: bool = field(default=False)
    cls_weight: float = field(default=0.0)
    cls_data_type: str = field(default=None)
    rank_type: str = field(default=None)
    rank_pooler_type: str = field(default="seq_prob")
    rank_detach: bool = field(default=False)
    rank_com_type: str = field(default=None)
    rank_com_pooler_type: str = field(default="seq_prob")
    rank_com_weight: float = field(default=0.0)
    add_tokens_path: str = field(default=None)
    add_special_tokens: bool = field(default=False)
    choice_loss: bool = field(default=False)
    choice_weight: float = field(default=1.0)
    do_curriculum: bool = field(default=False)
    choice_cls: bool = field(default=False)
    rg_weight: float = field(default=0.0)
    rg_teacher_type: str = field(default="reweight")
    rg_distill_loss: str = field(default="kl")
    rg_only_question: bool = field(default=False)
    rg_only_question_graph: bool = field(default=False)
    rg_align: bool = field(default=False)
    rg_batch_avg: bool = field(default=False)
    rg_force_ask: bool = field(default=False)
    # Declared exactly once; a second identical declaration further down the
    # class was redundant and has been removed (position and default unchanged).
    rg_multiple_ok: bool = field(default=False)
    rg_shift_mask: bool = field(default=False)
    rg_shift_layer: int = field(default=15)
    rg_amplify: bool = field(default=False)
    rg_grad: str = field(default=None)  # layer
    rg_grad_abs: str = field(default=None)  # before, after
    rg_grad_topk: int = field(default=10)
    rg_grad_norm: str = field(default=None)  # l1, l2, minmax, z
    rg_grad_head_score: str = field(default=None)  # qk, v
    # post attn
    do_post_attn: bool = field(default=False)
    do_reason_graph: bool = field(default=False)
    amplify_factor: float = field(default=0.0)
    use_target_mask: bool = field(default=False)
    amplify_k_scope: str = field(
        default="question"
    )  # all / question / question_solution
    amplify_q_scope: str = field(default="all")  # all / solution / question_solution
    amplify_topk: int = field(default=None)
    amplify_topk_head: int = field(default=None)
    amplify_head_id: int = field(default=None)
    amplify_layer_head_path: str = field(default=None)
    amplify_layer_path: str = field(default=None)
    amplify_reverse: bool = field(default=False)
    amplify_total_topk: int = field(default=None)
    amplify_decay: float = field(default=None)
    amplify_total_threshold: float = field(default=None)
    amplify_total_threshold_output: float = field(default=None)
    amplify_total_threshold_upper: float = field(default=None)
    amplify_upper_position: int = field(default=None)
    amplify_upper_question: bool = field(default=False)
    amplify_layer_threshold: bool = field(default=False)
    amplify_adapt_threshold: str = field(default=None)  # mean_std / mean
    amplify_adapt_threshold_factor: float = field(default=1.0)
    amplify_adapt_threshold_factor_upper: float = field(default=None)
    amplify_norm_threshold: bool = field(default=False)
    amplify_output_factor: float = field(default=None)
    amplify_exclude_self: bool = field(default=False)
    amplify_exclude_cal: str = field(default=None)
    amplify_uncert_threshold: float = field(default=None)
    amplify_uncert_type: str = field(default="constant")  # constant / neg_logp
    amplify_uncert_constant: float = field(default=1.0)
    amplify_uncert_upper: float = field(default=1.0)
    amplify_uncert_score_type: str = field(default="neg_logp")  # delta
    amplify_skip_stopwords: str = field(default=None)
    amplify_skip_penalty: float = field(default=1.0)
    amplify_use_sink: bool = field(default=False)
    amplify_smooth_window: int = field(default=None)
    amplify_soft_mask: int = field(default=None)
    amplify_soft_mask_activation: bool = field(default=False)
    amplify_uncert_percentage: float = field(default=None)
    amplify_uncert_percentage_output: float = field(default=None)
    amplify_upper_apply_input: bool = field(default=False)
    start_layer: int = field(default=0)
    end_layer: int = field(default=31)
    keep_only_attn: bool = field(default=False)
    do_adapt_attn_mask: bool = field(default=False)
    cnn_kernel_size: int = field(default=3)
    do_repeat_attn: bool = field(default=False)
    few_shot: bool = field(default=False)
    flash_attention: bool = field(default=False)


@dataclass
class DataArguments:
    """Options describing the training data and its pre-processing.

    In this file, ``data_path``, ``max_samples``, ``only_output`` and
    ``preprocessing_num_workers`` are read by ``make_supervised_data_module()``
    and ``copy_question`` by ``get_model_tokenizer()``. The whole object is
    also forwarded to the dataset processing functions, which presumably
    consume the remaining fields.
    """

    # Either a .json/.jsonl file or a directory saved with `datasets`.
    data_path: str = field(
        metadata={"help": "Path to the training data."}, default=None
    )
    apply_src_loss: bool = False
    apply_partial_tgt_loss: bool = False
    inst_type: str = "inst"
    remove_pal_question: bool = False
    copy_question: bool = False
    # When set, only the first `max_samples` rows are kept.
    max_samples: int = None
    result_path: str = None
    prompt_format: str = None
    # Number of worker processes used by the dataset `map` calls.
    preprocessing_num_workers: int = 32
    # When true, the "input" column of every row is blanked before processing.
    only_output: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face training arguments extended with script-specific options.

    ``cache_dir`` and ``model_max_length`` are read by
    ``get_model_tokenizer()``. The ``trans_*`` fields are not read in this
    file; presumably they drive the temperature schedule of
    ``GumbelTempCallback`` -- TODO confirm.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
        default=512,
    )
    trans_initial_tau: float = 2.0
    trans_minimum_tau: float = 0.0
    trans_anneal_rate: float = 3e-4


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Gather the model weights and write them to ``output_dir``.

    The state dict is fetched on every process, but only processes for which
    ``trainer.args.should_save`` is true move the tensors to CPU and save them.
    """
    full_state = trainer.model.state_dict()
    if not trainer.args.should_save:
        return

    host_state = {}
    for name, tensor in full_state.items():
        host_state[name] = tensor.cpu()
    # Drop the reference to the original tensors before writing to disk.
    del full_state
    trainer._save(output_dir, state_dict=host_state)  # noqa


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    training_args,
    model_args,
    is_eval=False,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Loads the raw training data from ``data_args.data_path`` (a ``.json`` /
    ``.jsonl`` file, a saved ``DatasetDict`` whose ``"train"`` split is used,
    or a saved ``Dataset``), selects the processing function and collator
    class that match the options in ``model_args``, and maps the processing
    function over the raw data.

    Args:
        tokenizer: Tokenizer passed to the processing function and collator.
        data_args: Data options (``data_path``, ``max_samples``,
            ``only_output``, ``preprocessing_num_workers`` are read here).
        training_args: Only ``local_rank`` is read, for rank synchronisation.
        model_args: Model options; for the ``mle_aug``, ``cl`` and ``do_cls``
            modes a ``data_type`` attribute is set on it as a side effect.
        is_eval: When true, the raw dataset is also returned under
            ``"samples"``.

    Returns:
        A dict with ``train_dataset``, ``eval_dataset`` (always ``None``) and
        ``data_collator``, plus ``samples`` when ``is_eval`` is true.
    """

    if data_args.data_path.endswith((".json", ".jsonl")):
        try:
            samples = load_jsonl(data_args.data_path)
        except Exception:
            # Fall back to the alternative reader when the default one fails.
            # `Exception` (not a bare except) so Ctrl-C / SystemExit propagate.
            samples = load_jsonl_ml(data_args.data_path)
        raw_train_dataset = datasets.Dataset.from_list(samples)
    elif os.path.exists(os.path.join(data_args.data_path, "dataset_dict.json")):
        # if "gsm" in data_args.data_path and "gsm8k_cringe" not in data_args.data_path:
        raw_train_dataset = datasets.DatasetDict.load_from_disk(data_args.data_path)[
            "train"
        ]
    else:
        raw_train_dataset = datasets.Dataset.load_from_disk(data_args.data_path)
    # train_dataset = train_dataset.select(range(10))
    if data_args.max_samples is not None:
        # Clamp so that asking for more rows than exist keeps the whole
        # dataset instead of failing on out-of-range indices.
        num_selected = min(data_args.max_samples, len(raw_train_dataset))
        raw_train_dataset = raw_train_dataset.select(range(num_selected))
    if data_args.only_output:
        # Blank the "input" column of every row.
        raw_train_dataset = raw_train_dataset.map(
            lambda x: {"input": ""},
            num_proc=data_args.preprocessing_num_workers,
        )

    # Pick the processing function / collator pair; the first matching mode wins.
    if model_args.cringe_type == "vanilla":
        process_dataset_fn = process_dataset_cringe
        collator_class = DataCollatorForSupervisedDatasetCringe
    elif model_args.cringe_type == "gsm":
        process_dataset_fn = process_dataset_cringe_gsm
        collator_class = DataCollatorForSupervisedDatasetCringe
    elif model_args.mle_aug is not None:
        model_args.data_type = model_args.mle_aug
        process_dataset_fn = process_dataset_mle_aug
        collator_class = DataCollatorForSupervisedDatasetMLEAug
    elif model_args.cl:
        model_args.data_type = model_args.cl_data_type
        process_dataset_fn = process_dataset_cl
        collator_class = DataCollatorForSupervisedDatasetCL
    elif model_args.trans_question:
        process_dataset_fn = process_dataset_trans
        collator_class = DataCollatorForSupervisedDatasetQTrans
    elif model_args.do_cls:
        model_args.data_type = model_args.cls_data_type
        process_dataset_fn = process_dataset_cls
        collator_class = DataCollatorForSupervisedDatasetCls
    elif model_args.rank_com_type is not None:
        process_dataset_fn = process_dataset_rank_com
        collator_class = DataCollatorForSupervisedDatasetRankCom
    elif model_args.rank_type is not None:
        process_dataset_fn = process_dataset_rank
        collator_class = DataCollatorForSupervisedDatasetRank
    elif model_args.choice_loss:
        process_dataset_fn = process_dataset_choice
        collator_class = DataCollatorForSupervisedDatasetChoice
    elif model_args.choice_cls:
        process_dataset_fn = process_dataset_choice_cls
        collator_class = DataCollatorForSupervisedDatasetChoiceCls
    elif model_args.rg_weight > 0.0:
        if (
            model_args.rg_only_question
            or model_args.rg_amplify
            or model_args.rg_grad_head_score is not None
            or model_args.amplify_soft_mask is not None
        ):
            process_dataset_fn = process_dataset_attn_amplify_train
            collator_class = DataCollatorForSupervisedDatasetAttn
        elif model_args.rg_grad is not None:
            process_dataset_fn = process_dataset_grad
            collator_class = DataCollatorForSupervisedDatasetGrad
        else:
            process_dataset_fn = process_dataset_reasoning_graph
            collator_class = DataCollatorForSupervisedDatasetReasonGraph
    elif (
        model_args.do_post_attn
        or model_args.rg_grad_head_score is not None
        or model_args.amplify_soft_mask is not None
    ):
        if model_args.do_reason_graph:
            process_dataset_fn = process_dataset_reasoning_graph
            collator_class = DataCollatorForSupervisedDatasetReasonGraph
        else:
            process_dataset_fn = process_dataset_attn_amplify_train
            collator_class = DataCollatorForSupervisedDatasetAttn
    else:
        process_dataset_fn = process_dataset
        collator_class = DataCollatorForSupervisedDataset

    print("process_fn", process_dataset_fn)

    fn_kwargs = {
        "tokenizer": tokenizer,
        "data_args": data_args,
        "model_args": model_args,
    }

    # Local rank 0 runs the map first; the other local ranks wait at this
    # barrier and run it only after rank 0 reaches the matching barrier below.
    if hasattr(training_args, "local_rank") and training_args.local_rank > 0:
        torch.distributed.barrier()
    train_dataset = raw_train_dataset.map(
        process_dataset_fn,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=raw_train_dataset.column_names,
        desc="Running tokenizer on train dataset",
        fn_kwargs=fn_kwargs,
    )

    if hasattr(training_args, "local_rank") and training_args.local_rank == 0:
        torch.distributed.barrier()

    print(len(train_dataset))
    if (hasattr(training_args, "local_rank") and training_args.local_rank == 0) or (
        dist.is_initialized() and dist.get_rank() == 0
    ):
        # Preview at most three random examples; clamped because
        # random.sample() raises ValueError when asked for more items than exist.
        num_preview = min(3, len(train_dataset))
        for index in random.sample(range(len(train_dataset)), num_preview):
            example = train_dataset[index]
            print(f"Sample {index} of the training set: {example}.")
            # Some processing modes store a list of sequences per example;
            # decode only the first one in that case.
            if isinstance(example["input_ids"][0], list):
                print(tokenizer.decode(example["input_ids"][0]))
            else:
                print(tokenizer.decode(example["input_ids"]))

    data_collator = collator_class(tokenizer=tokenizer)
    if is_eval:
        return dict(
            train_dataset=train_dataset,
            eval_dataset=None,
            data_collator=data_collator,
            samples=raw_train_dataset,
        )
    return dict(
        train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
    )


# Names of model_args attributes that get_model_tokenizer() forwards unchanged
# (same keyword name) to the attention-amplification config.
_ATTN_CONFIG_FORWARDED_ARGS = (
    "amplify_topk",
    "amplify_k_scope",
    "amplify_q_scope",
    "amplify_topk_head",
    "amplify_head_id",
    "amplify_reverse",
    "amplify_total_topk",
    "amplify_total_threshold",
    "amplify_total_threshold_output",
    "amplify_total_threshold_upper",
    "amplify_upper_position",
    "amplify_upper_question",
    "amplify_layer_threshold",
    "amplify_adapt_threshold",
    "amplify_adapt_threshold_factor",
    "amplify_adapt_threshold_factor_upper",
    "amplify_norm_threshold",
    "amplify_output_factor",
    "amplify_decay",
    "amplify_exclude_self",
    "amplify_exclude_cal",
    "amplify_uncert_threshold",
    "amplify_uncert_type",
    "amplify_uncert_constant",
    "amplify_uncert_upper",
    "amplify_uncert_score_type",
    "amplify_skip_stopwords",
    "amplify_skip_penalty",
    "amplify_use_sink",
    "amplify_factor",
    "amplify_smooth_window",
    "amplify_soft_mask",
    "amplify_soft_mask_activation",
    "amplify_uncert_percentage",
    "amplify_uncert_percentage_output",
    "amplify_upper_apply_input",
    "start_layer",
    "end_layer",
    "rg_grad_head_score",
)


def _read_json_file(path):
    """Return the parsed JSON content of ``path``, or ``None`` if ``path`` is ``None``."""
    if path is None:
        return None
    with open(path, "r") as f:
        return json.load(f)


def get_model_tokenizer(model_args, data_args, training_args, is_eval=False):
    """Build the model and tokenizer selected by the command-line options.

    The model class is chosen from ``model_args`` / ``data_args`` (first
    matching mode wins); when no custom mode is selected, a stock class is
    chosen from substrings of ``model_args.model_name_or_path``. The tokenizer
    is then loaded, given a pad token if it has none, optionally extended with
    extra tokens, and for the gate / soft-mask modes all parameters except the
    gate / threshold parameters are frozen.

    Args:
        model_args: Model options (see ``ModelArguments``).
        data_args: Data options; only ``copy_question`` is read here.
        training_args: ``cache_dir``, ``model_max_length`` and ``local_rank``
            are read here.
        is_eval: When true, the model is loaded in bfloat16 and the tokenizer
            pads on the left instead of the right.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_class, config_class = None, None
    if model_args.cringe_type is not None:
        model_class = CringeLlamaForCausalLM
    elif model_args.mle_aug is not None:
        model_class = MLEAugLlamaForCausalLM
    elif model_args.gold:
        model_class = GoldLlamaForCausalLM
    elif model_args.cl:
        model_class = ContraLlamaForCausalLM
    elif data_args.copy_question:
        model_class = QTransBaseLlamaForCausalLM
    elif model_args.trans_question:
        model_class = QTransLlamaForCausalLM
    elif model_args.do_cls:
        model_class = ClsLlamaForCausalLM
    elif model_args.rank_com_type is not None:
        model_class = AlignmentComLlamaForCausalLM
    elif model_args.rank_type is not None:
        model_class = AlignmentLlamaForCausalLM
    elif model_args.choice_loss:
        model_class = ChoiceLlamaForCausalLM
    elif model_args.choice_cls:
        model_class = LlamaForCausalLMChoiceCls
    elif model_args.rg_grad is not None:
        model_class = LlamaForCausalLMGrad
    elif (
        model_args.do_post_attn
        or model_args.rg_amplify
        or model_args.rg_grad_head_score is not None
        or model_args.amplify_soft_mask is not None
    ):
        # NOTE(review): this check is case-sensitive ("Mistral"), whereas the
        # stock-model fallback below lower-cases the path first.
        if "Mistral" in model_args.model_name_or_path:
            # model_class = MistralForCausalLMAttn
            model_class = MistralForCausalLMUncertAttn
            config_class = MistralConfigAttn
        else:
            # model_class = LlamaForCausalLMAttn
            model_class = LlamaForCausalLMUncertAttn
            config_class = LlamaConfigAttn
    elif model_args.do_adapt_attn_mask:
        model_class = LlamaForCausalLMAdaptMask
        config_class = LlamaConfigAdaptMask

    # elif model_args.copy_question:  # TODO: order

    config = None
    if config_class is not None:
        # `amplify_soft_mask` is part of this condition because it can select
        # the attention model class on its own (see above); without it
        # `config` was never assigned for that configuration.
        if (
            model_args.do_post_attn
            or model_args.rg_amplify
            or model_args.rg_grad_head_score is not None
            or model_args.amplify_soft_mask is not None
        ):
            config = config_class.from_pretrained(
                model_args.model_name_or_path,
                amplify_layer_head=_read_json_file(model_args.amplify_layer_head_path),
                amplify_layer=_read_json_file(model_args.amplify_layer_path),
                **{
                    name: getattr(model_args, name)
                    for name in _ATTN_CONFIG_FORWARDED_ARGS
                },
            )
            if hasattr(config, "_attn_implementation"):
                config._attn_implementation = "eager"
        elif model_args.do_adapt_attn_mask:
            config = config_class.from_pretrained(
                model_args.model_name_or_path,
                cnn_kernel_size=model_args.cnn_kernel_size,
            )

    kwargs = {}
    if is_eval:
        kwargs["local_files_only"] = False
        kwargs["torch_dtype"] = torch.bfloat16
    if model_args.trans_question and not model_args.trans_self:
        kwargs["trans_config"] = AutoConfig.from_pretrained(model_args.trans_model_path)
    if config is not None:
        kwargs["config"] = config
    kwargs["is_eval"] = is_eval

    if model_class is not None:
        # Classes that come with a custom config are always loaded without the
        # `is_eval` / `model_args` keywords. `config_class is not None` extends
        # that to the soft-mask-only configuration, which previously could not
        # reach this point.
        if (
            model_args.do_post_attn
            or model_args.rg_amplify
            or model_args.do_adapt_attn_mask
            or model_args.rg_grad_head_score is not None
            or config_class is not None
        ):
            del kwargs["is_eval"]
            model = model_class.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                **kwargs,
            )
        elif not data_args.copy_question:
            model = model_class.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                model_args=model_args,
                **kwargs,
            )
        else:
            model = model_class.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                **kwargs,
            )
    else:
        del kwargs["is_eval"]
        if "mistral" in model_args.model_name_or_path.lower():
            # Merge instead of passing `torch_dtype` twice: in eval mode
            # `kwargs` already carries a dtype, and a duplicated keyword
            # raises TypeError. The eval dtype takes precedence.
            load_kwargs = {"torch_dtype": "auto", **kwargs}
            model = MistralForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=(
                    "flash_attention_2" if model_args.flash_attention else None
                ),
                **load_kwargs,
            )
        elif "phi" in model_args.model_name_or_path:
            # Same duplicate-keyword protection as for Mistral.
            load_kwargs = {"torch_dtype": torch.bfloat16, **kwargs}
            model = PhiForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=(
                    "flash_attention_2" if model_args.flash_attention else None
                ),
                **load_kwargs,
            )
        elif "stable" in model_args.model_name_or_path.lower():
            # NOTE(review): `kwargs` (the eval-mode overrides) is not forwarded
            # in this branch; kept as in the original.
            model = AutoModelForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                torch_dtype="auto",
                cache_dir=training_args.cache_dir,
                attn_implementation=(
                    "flash_attention_2" if model_args.flash_attention else None
                ),
                trust_remote_code=True,
            )
        else:
            model = LlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                **kwargs,
            )

    if model_args.trans_question and not model_args.trans_self and not is_eval:
        # These keys are meant for the main model only. `pop` tolerates
        # `is_eval` having been removed already by the loading code above.
        kwargs.pop("trans_config", None)
        kwargs.pop("is_eval", None)
        model.trans_model = LlamaForCausalLM.from_pretrained(
            model_args.trans_model_path, **kwargs
        )

    if model_args.keep_only_attn:
        # Freeze the query/key projection weights.
        for name, param in model.named_parameters():
            if "q_proj" in name or "k_proj" in name:
                param.requires_grad = False

    if (hasattr(training_args, "local_rank") and training_args.local_rank == 0) or (
        dist.is_initialized() and dist.get_rank() == 0
    ):
        print("Model parameters:", model.num_parameters())

    use_fast_tokenizer = bool(
        model_args.cringe_type == "gsm"
        or model_args.trans_question
        or model_args.rg_weight > 0.0
        or model_args.rg_grad_head_score is not None
        or model_args.amplify_soft_mask is not None
        or "stable" in model_args.model_name_or_path
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        (
            model_args.model_name_or_path
            if model_args.tokenizer_path is None
            else model_args.tokenizer_path
        ),
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right" if not is_eval else "left",
        use_fast=use_fast_tokenizer,
        trust_remote_code=True,
    )

    if tokenizer.pad_token is None:
        # Reuse an existing token for padding: EOS for codegen / deepseek
        # checkpoints, UNK otherwise.
        if (
            "codegen" in model_args.model_name_or_path
            or "deepseek" in model_args.model_name_or_path
        ):
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
            model.config.pad_token = tokenizer.eos_token
            model.config.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token = tokenizer.unk_token
            tokenizer.pad_token_id = tokenizer.unk_token_id
            model.config.pad_token = tokenizer.unk_token
            model.config.pad_token_id = tokenizer.unk_token_id

    if (
        model_args.amplify_exclude_cal is not None
        or model_args.amplify_skip_stopwords is not None
        or model_args.amplify_uncert_threshold is not None
    ):
        # These options make the tokenizer available on the model object.
        model.tokenizer = tokenizer

    if model_args.add_tokens_path is not None:
        with open(model_args.add_tokens_path, "r") as f:
            add_tokens = json.load(f)
        smart_tokenizer_and_embedding_resize(
            add_tokens, tokenizer, model, is_special=model_args.add_special_tokens
        )

    if model_args.rg_grad_head_score is not None and (
        "concrete" in model_args.rg_grad_head_score
        or "discrete" in model_args.rg_grad_head_score
    ):
        # Freeze everything, then re-enable only the gate parameters.
        model.enable_input_require_grads()
        for param in model.parameters():
            if param.requires_grad:
                param.requires_grad = False
        if "concrete" in model_args.rg_grad_head_score:
            for module in model.modules():
                if isinstance(module, ConcreteGate):
                    for param in module.parameters():
                        param.requires_grad = True
        elif "discrete" in model_args.rg_grad_head_score:
            for layer in model.model.layers:
                layer.self_attn.soft_gate.requires_grad = True
    if model_args.amplify_soft_mask is not None:
        # Freeze everything, then re-enable only the soft-threshold parameters.
        model.enable_input_require_grads()
        for param in model.parameters():
            if param.requires_grad:
                param.requires_grad = False
        for layer in model.model.layers:
            for param in layer.self_attn.soft_threshold.parameters():
                param.requires_grad = True

    return model, tokenizer


def train():
    """Parse the command line, build model and data, train, and save.

    After saving, the ``architectures`` entry of the written ``config.json``
    is replaced by ``["LlamaForCausalLM"]`` unless it already contains that
    name.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Seed before the model is built so that any parameters that are newly
    # initialised (rather than loaded from the checkpoint) are reproducible too.
    set_seed(training_args.seed)
    random.seed(training_args.seed)

    model, tokenizer = get_model_tokenizer(model_args, data_args, training_args)

    data_module = make_supervised_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
        training_args=training_args,
        model_args=model_args,
    )
    # from LLaMA-X: Tell Trainer not to attempt DataParallel
    model.is_parallelizable = True
    model.model_parallel = True
    callbacks = []
    if model_args.trans_question:
        callbacks.append(GumbelTempCallback)
    if model_args.do_curriculum:
        trainer_class = NoShuffleSeq2SeqTrainer
    elif model_args.rg_weight > 0.0:
        if model_args.rg_amplify:
            trainer_class = AttnAmplifyTrainer
        elif model_args.rg_grad is not None:
            trainer_class = GradTrainer
        else:
            trainer_class = ReasonGraphTrainer
        # The custom trainers additionally receive the model options.
        data_module["model_args"] = model_args
    else:
        trainer_class = Trainer
    print("trainer_class", trainer_class)
    trainer = trainer_class(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        callbacks=callbacks or None,
        **data_module,
    )
    model.config.use_cache = False
    # torch.autograd.set_detect_anomaly(True)
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_state()
    # safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
    trainer.save_model(output_dir=training_args.output_dir)

    # Patch config.json only on processes that are responsible for saving:
    # a plain `local_rank == 0` test skips single-process runs (local_rank is
    # -1 there) and may run on nodes where no config.json was written.
    if training_args.should_save:
        config_path = os.path.join(training_args.output_dir, "config.json")
        with open(config_path, "r") as f:
            config_dict = json.load(f)
        # NOTE(review): this also rewrites the entry for non-Llama models
        # (e.g. Mistral checkpoints) -- confirm that is intended.
        if "LlamaForCausalLM" not in (config_dict.get("architectures") or []):
            config_dict["architectures"] = ["LlamaForCausalLM"]
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)


# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    train()
