from typing import Callable, Dict, List, Optional, Tuple, Union, Any
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import Trainer
from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments


def entropy_loss(attn_weights, token_mask):
    """Return the mean entropy of attention rows, optionally masked.

    Computes ``-sum(p * log(p + eps))`` over the last dimension and averages
    the result over all remaining dimensions. ``eps`` is the machine epsilon
    of ``attn_weights.dtype`` and guards against ``log(0)``.

    Args:
        attn_weights: Attention probabilities; the last dimension is the one
            the entropy is summed over.
        token_mask: Optional boolean mask. It is expanded along dimension 1
            to ``attn_weights.shape[1]`` (so it must be 4-D), and positions
            where it is ``False`` contribute zero. ``None`` disables masking.

    Returns:
        A scalar tensor.
    """
    eps = torch.finfo(attn_weights.dtype).eps
    # The product is a fresh tensor, so zeroing entries in place below does
    # not modify the caller's ``attn_weights``.
    weighted_log = attn_weights * torch.log(attn_weights + eps)
    if token_mask is not None:
        dropped = ~token_mask.expand(-1, weighted_log.shape[1], -1, -1)
        weighted_log[dropped] = 0
    row_entropy = -weighted_log.sum(dim=-1)
    return row_entropy.mean()


def entropy_loss_v2(attn_weights, token_mask):
    """Return the mean entropy of masked, row-normalized attention weights.

    Positions where ``token_mask`` is ``False`` are zeroed, every row is
    divided by the row sum of the *unmasked* ``attn_weights``, and
    ``-sum(p * log(p + eps))`` over the last dimension is averaged over all
    remaining dimensions.

    Args:
        attn_weights: Attention weights; the last dimension is reduced.
        token_mask: Boolean mask broadcastable to ``attn_weights``, or
            ``None`` to keep every position.

    Returns:
        A scalar tensor.
    """
    # ``~None`` raises TypeError, so the unmasked case needs its own path.
    if token_mask is None:
        kept = attn_weights
    else:
        kept = attn_weights.masked_fill(~token_mask, 0)
    # NOTE(review): the denominator is the sum of the unmasked weights, so
    # masked rows are not renormalized to sum to 1. This matches the previous
    # behavior; confirm whether the masked sum was intended.
    norm_attn_weights = kept / attn_weights.sum(dim=-1, keepdim=True)
    eps = torch.finfo(norm_attn_weights.dtype).eps  # prevents log(0)
    log_probs = torch.log(norm_attn_weights + eps)
    entropy = -torch.sum(norm_attn_weights * log_probs, dim=-1)
    return entropy.mean()


def entropy_loss_v3(attn_weights, token_mask):
    """Return the mean negative log attention weight summed over kept positions.

    Computes ``-sum(log(p + eps))`` over the last dimension, counting only
    positions where ``token_mask`` is ``True``, then averages over all
    remaining dimensions.

    Args:
        attn_weights: Attention weights; the last dimension is reduced.
        token_mask: Boolean mask broadcastable to ``attn_weights``, or
            ``None`` to keep every position.

    Returns:
        A scalar tensor.
    """
    eps = torch.finfo(attn_weights.dtype).eps  # prevents log(0)
    log_probs = torch.log(attn_weights + eps)
    # ``~None`` raises TypeError, so only mask when a mask was supplied.
    if token_mask is not None:
        log_probs = log_probs.masked_fill(~token_mask, 0)
    return -log_probs.sum(dim=-1).mean()


def negative_p_norm_loss(attention_weights, p, token_mask):
    """Return the negated mean p-norm of attention rows.

    Args:
        attention_weights: Attention weights; the norm is taken over the last
            dimension.
        p: Norm order passed to ``torch.norm`` (e.g. ``2`` or
            ``float("inf")``).
        token_mask: Boolean mask broadcastable to ``attention_weights``;
            positions where it is ``False`` are zeroed before the norm.
            ``None`` keeps every position.

    Returns:
        A scalar tensor equal to ``-mean(norm_p(row))``.
    """
    # ``~None`` raises TypeError, so only mask when a mask was supplied.
    if token_mask is not None:
        attention_weights = attention_weights.masked_fill(~token_mask, 0)
    p_norm = torch.norm(attention_weights, p=p, dim=-1)
    return -p_norm.mean()


class AttnAmplifyTrainer(Trainer):
    """Trainer that adds an attention-based auxiliary loss to the LM loss.

    The total loss is ``outputs.loss + model_args.rg_weight * attn_loss``,
    where ``attn_loss`` is accumulated over the attention maps returned by the
    model. Which auxiliary loss is used is selected by
    ``model_args.rg_distill_loss`` and ``model_args.rg_teacher_type``.
    """

    def __init__(self, model_args=None, *args, **kwargs):
        """Store ``model_args`` and forward everything else to ``Trainer``."""
        super().__init__(*args, **kwargs)
        self.model_args = model_args

    def kl_loss(self, student_attn, teacher_attn, epsilon=1e-10):
        """Return the element-wise KL terms between teacher and student.

        Both inputs are clamped to at least ``epsilon`` (to avoid NaN from
        ``log(0)``) and renormalized along the last dimension before being
        passed to ``nn.KLDivLoss(reduction="none")``.

        Returns:
            A tensor with the same shape as the inputs (no reduction applied).
        """
        kl_loss_fn = nn.KLDivLoss(reduction="none")
        teacher_attn = teacher_attn.clamp(min=epsilon)
        student_attn = student_attn.clamp(min=epsilon)
        teacher_attn = teacher_attn / teacher_attn.sum(dim=-1, keepdim=True)
        student_attn = student_attn / student_attn.sum(dim=-1, keepdim=True)
        # KLDivLoss expects log-probabilities as input, probabilities as target.
        return kl_loss_fn(student_attn.log(), teacher_attn)

    def mse_loss(self, student_attn, teacher_attn):
        """Return the element-wise squared error (no reduction applied)."""
        mse_loss_fn = nn.MSELoss(reduction="none")
        return mse_loss_fn(student_attn, teacher_attn)

    def cl_loss(self, attn_weights, token_mask):
        """Return the mean masked negative log attention weight.

        Both tensors are flattened to 2-D, keeping the last dimension. For
        each row, ``-log(p + eps) * mask`` is summed and divided by the row's
        mask sum (1 is used when the mask sum is 0); the rows are then
        averaged.

        Returns:
            A scalar tensor.
        """
        attn_weights_flat = attn_weights.view(-1, attn_weights.size(-1))
        mask_flat = token_mask.view(-1, token_mask.size(-1))
        log_attn_weights = torch.log(
            attn_weights_flat + torch.finfo(attn_weights_flat.dtype).eps
        )
        positive_neg_log_probs = -(log_attn_weights * mask_flat)
        mask_sum = mask_flat.sum(dim=-1, keepdim=True)
        # Avoid dividing by zero for rows with no marked positions.
        robust_mask_sum = torch.where(
            mask_sum == 0, torch.ones_like(mask_sum), mask_sum
        )
        loss = positive_neg_log_probs.sum(dim=-1) / robust_mask_sum
        return loss.mean()

    def cl_loss_v2(
        self, attn_weights, query_states, key_states, token_mask, temperature
    ):
        """Return the negative mean log-softmax score at marked positions.

        The scores are recomputed as ``query_states @ key_states^T /
        temperature`` (transposing dims 2 and 3 of ``key_states``); the
        ``attn_weights`` argument is accepted for interface compatibility but
        its value is not used. The log-softmax is evaluated in float32 and
        cast back to ``query_states.dtype``.

        Returns:
            A scalar tensor: minus the mean log-probability over positions
            where ``token_mask != 0``.
        """
        scores = torch.matmul(query_states, key_states.transpose(2, 3)) / temperature
        log_attn_weights = nn.functional.log_softmax(
            scores, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        log_attn_weights = log_attn_weights.view(-1, scores.size(-1))
        mask_flat = token_mask.view(-1, token_mask.size(-1))
        positive_log_probs = log_attn_weights.masked_select(mask_flat != 0)
        return -positive_log_probs.mean()

    def get_distill_loss(self, student_attn, teacher_attn, ele_mask):
        """Return the masked, reduced distillation loss.

        Args:
            student_attn: Student attention weights.
            teacher_attn: Teacher attention weights.
            ele_mask: Boolean mask selecting the elements that count.

        Returns:
            A scalar tensor. With ``model_args.rg_batch_avg`` the masked mean
            is taken per leading-dimension entry and then averaged; otherwise
            a single masked mean over all elements is returned. An all-False
            mask yields 0 rather than NaN.

        Raises:
            ValueError: If ``model_args.rg_distill_loss`` is neither ``"kl"``
                nor ``"mse"``.
        """
        kind = self.model_args.rg_distill_loss
        if kind == "kl":
            loss = self.kl_loss(student_attn, teacher_attn)
        elif kind == "mse":
            loss = self.mse_loss(student_attn, teacher_attn)
        else:
            raise ValueError(
                f"Unsupported rg_distill_loss for distillation: {kind!r}"
            )
        if self.model_args.rg_batch_avg:
            # First average within each leading-dimension entry, then across.
            batch = loss.shape[0]
            masked_sum = (loss * ele_mask).reshape(batch, -1).sum(dim=-1)
            # clamp keeps an empty mask from producing 0 / 0 = NaN.
            mask_count = ele_mask.reshape(batch, -1).sum(dim=-1).clamp(min=1)
            return (masked_sum / mask_count).mean()
        return loss.masked_fill(~ele_mask, 0).sum() / ele_mask.sum().clamp(min=1)

    def get_teacher_attn(self, attention, token_mask):
        """Build the teacher attention for distillation.

        * ``"reweight"``: ``attention * (1 + token_mask)``, renormalized along
          the last dimension.
        * ``"oracle"``: ``token_mask`` divided by its sum along the last
          dimension (1 is used when that sum is 0). With
          ``model_args.rg_align``, rows whose mask sum is 0 fall back to the
          student ``attention``.

        Raises:
            ValueError: If ``model_args.rg_teacher_type`` is neither
                ``"reweight"`` nor ``"oracle"``.
        """
        teacher_type = self.model_args.rg_teacher_type
        if teacher_type == "reweight":
            teacher_attn = attention * (1 + token_mask)
            return teacher_attn / teacher_attn.sum(dim=-1, keepdim=True)
        if teacher_type == "oracle":
            mask_sum = token_mask.sum(dim=-1, keepdim=True)
            robust_mask_sum = torch.where(
                mask_sum == 0, torch.ones_like(mask_sum), mask_sum
            )
            teacher_attn = token_mask / robust_mask_sum
            if self.model_args.rg_align:
                # Rows without any marked position copy the student attention.
                teacher_attn = torch.where(mask_sum == 0, attention, teacher_attn)
            return teacher_attn
        raise ValueError(f"Unsupported rg_teacher_type: {teacher_type!r}")

    def get_shift_teacher_attn(self, attentions):
        """Build a shifted teacher from the first ``rg_shift_layer`` layers.

        The first ``model_args.rg_shift_layer`` attention maps are averaged
        over layers and over dimension 1, the diagonal of the last two
        dimensions is zeroed, the first row is dropped, rows are renormalized,
        and a zero row is appended so the row count is unchanged.

        Returns:
            A tensor with dimension 1 of the inputs reduced away.
        """
        n_layers = self.model_args.rg_shift_layer
        avg_attention = (sum(attentions[:n_layers]) / n_layers).mean(dim=1)
        # Zero the diagonal of the last two dimensions.
        diagonal = torch.eye(avg_attention.shape[-1], device=avg_attention.device) == 1
        avg_attention = avg_attention.masked_fill(diagonal[None, :, :], 0)
        # Shift by one row: row i now holds what was row i + 1.
        avg_attention = avg_attention[:, 1:, :]
        avg_attention = avg_attention / avg_attention.sum(dim=-1, keepdim=True)
        # Pad a zero row at the end to restore the original row count.
        return torch.cat(
            [avg_attention, torch.zeros_like(avg_attention[:, 0:1, :])], dim=-2
        )

    @staticmethod
    def _norm_style_loss(kind, attention, ele_mask):
        """Dispatch to an entropy / p-norm loss; return None if ``kind`` is not one."""
        if kind == "entropy_norm":
            return entropy_loss(attention, ele_mask)
        if kind == "entropy_norm_v2":
            return entropy_loss_v2(attention, ele_mask)
        if kind == "entropy_norm_v3":
            return entropy_loss_v3(attention, ele_mask)
        if kind == "p2_norm":
            return negative_p_norm_loss(attention, 2, ele_mask)
        if kind == "pmax_norm":
            return negative_p_norm_loss(attention, float("inf"), ele_mask)
        return None

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Compute the LM loss plus the weighted attention auxiliary loss.

        Args:
            model: The model to run; it may or may not be wrapped in an object
                exposing the real model as ``.module``.
            inputs: Mapping read for ``input_ids``, ``attention_mask``,
                ``labels``, ``k_token_mask``, ``q_token_mask``, ``tgt_index``.
            return_outputs: If true, also return ``{"outputs": outputs}``.
            num_items_in_batch: Accepted so that ``Trainer`` versions passing
                this keyword do not fail; it is not used here.

        Returns:
            ``loss`` or ``(loss, {"outputs": outputs})``.

        Raises:
            ValueError: If a ``base_*`` loss name is not recognized.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
        labels = inputs.get("labels")
        k_token_mask = inputs.get("k_token_mask")
        q_token_mask = inputs.get("q_token_mask")
        tgt_index = inputs.get("tgt_index")

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=True,
            return_dict=True,
        )
        attentions = outputs.attentions
        distill_kind = self.model_args.rg_distill_loss
        attn_loss = 0.0

        if distill_kind == "cl":
            for attention in attentions:
                attn_loss += self.cl_loss(attention, (k_token_mask != 0).int())
        elif self.model_args.rg_teacher_type == "shift":
            num_heads = attentions[0].shape[1]
            with torch.no_grad():
                teacher_attn = self.get_shift_teacher_attn(attentions)
                # The same teacher is used for every head and every layer.
                teacher_attn = (
                    teacher_attn[:, None, :, :]
                    .expand(-1, num_heads, -1, -1)
                    .detach()
                )
                cur_expanded_k_token_mask = k_token_mask[:, None, None, :].expand(
                    -1, num_heads, k_token_mask.shape[-1], -1
                )
                ele_mask = cur_expanded_k_token_mask != 0
                # The last row of the shifted teacher is zero padding.
                ele_mask[:, :, -1, :] = False
            for attention in attentions[self.model_args.rg_shift_layer :]:
                attn_loss += self.get_distill_loss(attention, teacher_attn, ele_mask)
        elif "base" in distill_kind:
            base_kind = distill_kind.replace("base_", "")
            for attention in attentions:
                layer_loss = self._norm_style_loss(base_kind, attention, None)
                if layer_loss is None:
                    raise ValueError(
                        f"Unsupported base rg_distill_loss: {distill_kind!r}"
                    )
                attn_loss += layer_loss
        else:
            # Use the wrapped model when there is one, else the model itself.
            core_model = getattr(model, "module", model)
            with torch.no_grad():
                do_two_forward = (
                    core_model.config.amplify_total_topk is not None
                    or core_model.config.amplify_uncert_threshold is not None
                )
                if do_two_forward:
                    total_attn_weights = (sum(attentions) / len(attentions)).mean(dim=1)
                    k_token_mask = core_model._prepare_expanded_token_mask(
                        k_token_mask,
                        total_attn_weights,
                        input_ids.shape[0],
                        total_attn_weights.shape[-2],
                        total_attn_weights.shape[-1],
                        tgt_index,
                        outputs[0],
                    )

            needs_teacher = "norm" not in distill_kind
            for i, attention in enumerate(attentions):
                with torch.no_grad():
                    self_attn = core_model.model.layers[i].self_attn
                    # NOTE(review): the non-soft function is chosen when
                    # amplify_soft_mask is set, and the soft one otherwise.
                    # This looks inverted — confirm against the model code.
                    prepare_token_mask_fn = (
                        self_attn._prepare_expanded_token_mask
                        if self.model_args.amplify_soft_mask is not None
                        else self_attn._prepare_expanded_token_mask_soft
                    )
                    cur_expanded_k_token_mask = prepare_token_mask_fn(
                        k_token_mask,  # dimension: (bsz, seq_len)
                        attention,  # dimension: (bsz, num_heads, q_len, kv_seq_len)
                        q_token_mask,  # dimension: (bsz, seq_len)
                        input_ids.shape[0],
                        input_ids.shape[-1],
                        input_ids.shape[-1],
                        tgt_index,
                    )
                    if needs_teacher:
                        teacher_attn = self.get_teacher_attn(
                            attention, cur_expanded_k_token_mask
                        )
                ele_mask = cur_expanded_k_token_mask != 0
                layer_loss = self._norm_style_loss(distill_kind, attention, ele_mask)
                if layer_loss is None:
                    layer_loss = self.get_distill_loss(
                        attention,
                        teacher_attn.detach(),
                        ele_mask,
                    )
                attn_loss += layer_loss

        loss = outputs.loss + self.model_args.rg_weight * attn_loss
        # attn_loss is still a plain float when no layer contributed a term.
        distill_value = (
            attn_loss.item() if torch.is_tensor(attn_loss) else float(attn_loss)
        )
        print(
            {
                "lm": outputs.loss.item(),
                "distill": distill_value,
                "total": loss.item(),
            }
        )
        return (loss, {"outputs": outputs}) if return_outputs else loss
