from typing import Callable, Dict, List, Optional, Tuple, Union, Any
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import Trainer
from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments


class ReasonGraphTrainer(Trainer):
    """Hugging Face ``Trainer`` that adds an attention-supervision loss.

    The training loss is ``outputs.loss + model_args.rg_weight * attn_loss``,
    where ``attn_loss`` is summed over every tensor in ``outputs.attentions``.
    For each layer the student attention is either compared against a teacher
    distribution built from the batch's ``k_token_mask``
    (``rg_distill_loss`` = ``"kl"`` or ``"mse"``), or scored by the negative
    log of the attention mass that falls on marked tokens
    (``rg_distill_loss`` = ``"cl"``).

    ``model_args`` is expected to expose ``rg_distill_loss``,
    ``rg_teacher_type`` (``"reweight"`` or ``"oracle"``), ``rg_align``,
    ``rg_batch_avg`` and ``rg_weight``.

    NOTE(review): ``model_args`` is the first positional parameter of
    ``__init__``; a positional first argument (e.g. the model) would be bound
    to it, so pass the regular ``Trainer`` arguments by keyword.
    """

    def __init__(self, model_args=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Holds the rg_* hyper-parameters read by the loss helpers below.
        self.model_args = model_args

    def kl_loss(self, student_attn, teacher_attn, ele_mask, epsilon=1e-10):
        """Return the unreduced, elementwise KL(teacher || student) terms.

        Both distributions are clamped to ``epsilon`` (so ``log`` never sees
        zero) and renormalised along the last dimension before comparison.
        ``ele_mask`` is unused here; it is accepted so the signature matches
        ``mse_loss`` — masking is applied by ``get_distill_loss``.
        """
        kl_loss_fn = nn.KLDivLoss(reduction="none")
        teacher_attn = teacher_attn.clamp(min=epsilon)
        student_attn = student_attn.clamp(min=epsilon)
        teacher_attn = teacher_attn / teacher_attn.sum(dim=-1, keepdim=True)
        student_attn = student_attn / student_attn.sum(dim=-1, keepdim=True)
        # KLDivLoss takes log-probabilities as input and probabilities as target.
        return kl_loss_fn(student_attn.log(), teacher_attn)

    def mse_loss(self, student_attn, teacher_attn, ele_mask):
        """Return the unreduced, elementwise squared error.

        ``ele_mask`` is unused here (masking happens in ``get_distill_loss``);
        it is kept for signature parity with ``kl_loss``.
        """
        mse_loss_fn = nn.MSELoss(reduction="none")
        return mse_loss_fn(student_attn, teacher_attn)

    def cl_loss(self, attn_weights, token_mask):
        """Mean negative log of the attention mass placed on masked tokens.

        Both inputs are flattened to rows over their last dimension. Only rows
        whose mask has a non-zero sum contribute. For each such row the loss
        is ``-log(sum(attn * mask) + 1e-9)``. Returns a scalar zero when no row
        contains a masked token (instead of the NaN ``mean()`` of an empty
        tensor would give).
        """
        # reshape (not view) so non-contiguous attention tensors also work.
        attn_flat = attn_weights.reshape(-1, attn_weights.size(-1))
        mask_flat = token_mask.reshape(-1, token_mask.size(-1))
        valid_rows = mask_flat.sum(dim=-1) != 0
        if not bool(valid_rows.any()):
            return attn_flat.new_zeros(())
        positive_mass = (attn_flat[valid_rows] * mask_flat[valid_rows]).sum(dim=-1)
        return (-torch.log(positive_mass + 1e-9)).mean()

    def get_distill_loss(self, student_attn, teacher_attn, ele_mask):
        """Reduce the elementwise distillation loss over ``ele_mask``.

        With ``rg_batch_avg`` the masked loss is first averaged within each
        sample (dimension 0), then across samples. Samples with an empty mask
        are skipped rather than producing 0/0 = NaN. Otherwise the masked loss
        is averaged over all selected elements. An entirely empty mask yields a
        scalar zero.

        Raises:
            ValueError: if ``model_args.rg_distill_loss`` is not ``"kl"`` or
                ``"mse"``.
        """
        loss_type = self.model_args.rg_distill_loss
        if loss_type == "kl":
            loss = self.kl_loss(student_attn, teacher_attn, ele_mask)
        elif loss_type == "mse":
            loss = self.mse_loss(student_attn, teacher_attn, ele_mask)
        else:
            raise ValueError(
                f"Unsupported rg_distill_loss {loss_type!r}; expected 'kl' or 'mse'"
            )

        masked_loss = loss * ele_mask
        if self.model_args.rg_batch_avg:
            bsz = loss.shape[0]
            per_sample_sum = masked_loss.reshape(bsz, -1).sum(dim=-1)
            per_sample_count = ele_mask.reshape(bsz, -1).sum(dim=-1)
            has_elements = per_sample_count != 0
            if not bool(has_elements.any()):
                return loss.new_zeros(())
            per_sample_mean = (
                per_sample_sum[has_elements] / per_sample_count[has_elements]
            )
            return per_sample_mean.mean()

        total_count = ele_mask.sum()
        if total_count == 0:
            return loss.new_zeros(())
        return masked_loss.sum() / total_count

    def get_teacher_attn(self, attention, token_mask):
        """Build the teacher attention distribution for one layer.

        * ``"reweight"``: the student attention with masked positions
          up-weighted by ``(1 + token_mask)``, renormalised over the last
          dimension.
        * ``"oracle"``: ``token_mask`` normalised to sum to one over the last
          dimension. Rows whose mask sums to zero stay all-zero, or — when
          ``rg_align`` is set — are replaced by the student attention.
          NOTE(review): ``compute_loss`` builds ``ele_mask`` as
          ``token_mask != 0``, so such rows are fully masked out of the
          distillation loss and the ``rg_align`` substitution does not change
          its value there.

        Raises:
            ValueError: if ``model_args.rg_teacher_type`` is not ``"reweight"``
                or ``"oracle"``.
        """
        teacher_type = self.model_args.rg_teacher_type
        if teacher_type == "reweight":
            teacher_attn = attention * (1 + token_mask)
            return teacher_attn / teacher_attn.sum(dim=-1, keepdim=True)
        if teacher_type == "oracle":
            mask_sum = token_mask.sum(dim=-1, keepdim=True)
            # Avoid 0/0 on rows that have no marked tokens.
            robust_mask_sum = torch.where(
                mask_sum == 0, torch.ones_like(mask_sum), mask_sum
            )
            teacher_attn = token_mask / robust_mask_sum
            if self.model_args.rg_align:
                teacher_attn = torch.where(mask_sum == 0, attention, teacher_attn)
            return teacher_attn
        raise ValueError(
            f"Unsupported rg_teacher_type {teacher_type!r}; "
            "expected 'reweight' or 'oracle'"
        )

    @staticmethod
    def _expand_token_mask(token_mask, num_heads):
        """Broadcast ``token_mask`` to 4-D, repeating it ``num_heads`` times on dim 1.

        A 3-D mask gets a new dimension 1. Any other mask is treated as
        ``(bsz, seq_len)`` and expanded to ``(bsz, num_heads, seq_len,
        seq_len)``, i.e. every query row sees the same key mask.
        """
        if token_mask.dim() == 3:
            return token_mask.unsqueeze(1).expand(-1, num_heads, -1, -1)
        bsz, seq_len = token_mask.shape[0], token_mask.shape[1]
        return token_mask[:, None, None, :].expand(bsz, num_heads, seq_len, seq_len)

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Return the LM loss plus the weighted attention-supervision loss.

        Args:
            model: called with ``input_ids``, ``attention_mask`` and ``labels``
                taken from ``inputs``, plus ``output_attentions=True``.
            inputs: batch mapping; must also contain ``"k_token_mask"``.
            return_outputs: when true, return ``(loss, {"outputs": outputs})``.
            num_items_in_batch: accepted for compatibility with newer
                ``transformers`` versions that pass it; it is not used.

        Raises:
            ValueError: if ``model_args`` is missing, ``k_token_mask`` is
                absent, or the model returns no attention weights.
        """
        if self.model_args is None:
            raise ValueError(
                "ReasonGraphTrainer requires model_args with the rg_* settings"
            )
        token_mask = inputs.get("k_token_mask")
        if token_mask is None:
            raise ValueError("inputs must contain a 'k_token_mask' tensor")

        outputs = model(
            input_ids=inputs.get("input_ids"),
            attention_mask=inputs.get("attention_mask"),
            labels=inputs.get("labels"),
            output_attentions=True,
            return_dict=True,
        )
        attentions = outputs.attentions
        if not attentions:
            raise ValueError(
                "model returned no attention weights; it must support "
                "output_attentions=True with its current attention implementation"
            )

        # Dimension 1 of each attention tensor (presumably the head axis)
        # sets how many times the mask is repeated.
        token_mask = self._expand_token_mask(token_mask, attentions[0].size(1))
        ele_mask = (token_mask != 0).to(token_mask.dtype)

        attn_loss = 0.0
        for attention in attentions:
            if self.model_args.rg_distill_loss == "cl":
                attn_loss += self.cl_loss(attention, ele_mask)
            else:
                teacher_attn = self.get_teacher_attn(attention, token_mask)
                # The teacher is a fixed target: no gradient flows through it.
                attn_loss += self.get_distill_loss(
                    attention, teacher_attn.detach(), ele_mask
                )

        loss = outputs.loss + self.model_args.rg_weight * attn_loss
        print(
            {
                "lm": outputs.loss.item(),
                "distill": attn_loss.item(),
                "total": loss.item(),
            }
        )
        return (loss, {"outputs": outputs}) if return_outputs else loss
