from typing import List, Optional, Tuple, Union
import torch
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from src.model.loss import ContrastiveCrossEntropyLoss


class CringeLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM trained with ``ContrastiveCrossEntropyLoss``.

    The parent model is only used to produce logits (its own language-modeling
    loss is disabled); the training loss is computed here from the shifted
    logits, the token ``labels`` and the extra ``classifier_labels`` input.
    """

    def __init__(self, config, *model_args, **model_kargs):
        """Build the model.

        Args:
            config: Model configuration forwarded to ``LlamaForCausalLM``.
            *model_args: Accepted for call compatibility; unused.
            **model_kargs: May contain ``model_args``, an object exposing
                ``cringe_type`` and ``cringe_add_reduce_loss``. It is only
                required when a loss is computed (``labels`` passed to
                ``forward``).
        """
        super().__init__(config)
        # `.get` instead of indexing so the model can still be instantiated
        # for pure inference; the loss path validates it before use.
        self.model_args = model_kargs.get("model_args")

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        classifier_labels: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model and, if ``labels`` is given, the cringe loss.

        Args:
            classifier_labels: Required whenever ``labels`` is given. Either
                1-D (one value per example, repeated over every target
                position) or laid out like ``labels`` (then shifted by one
                position the same way). The meaning of the values is defined
                by ``ContrastiveCrossEntropyLoss``.
            labels: Target token ids; positions equal to ``-100`` are excluded
                from the token counts used to normalize the loss.
            return_dict: When ``False`` (or ``None`` with
                ``config.use_return_dict`` false) a plain tuple is returned.
            Remaining arguments are forwarded unchanged to
            ``LlamaForCausalLM.forward``.

        Returns:
            ``CausalLMOutputWithPast`` (or its tuple form) whose ``loss`` is
            ``None`` when ``labels`` is ``None``.

        Raises:
            ValueError: If ``labels`` is given without ``classifier_labels``
                or the model was built without ``model_args``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Keyword arguments keep this call correct even if the parent
        # signature changes between transformers releases. `labels=None`
        # disables the parent's loss; `return_dict=True` guarantees that
        # `.logits` below exists regardless of what the caller asked for.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        logits = outputs.logits
        loss = None
        if labels is not None:
            loss, ce_loss, ct_loss = self._cringe_loss(
                logits, labels, classifier_labels
            )
            # NOTE(review): this prints on every training forward pass and
            # each `.item()` copies a scalar to the host; consider routing it
            # through a logger or the trainer's metrics instead.
            print(
                {
                    "loss": loss.item(),
                    "ce_loss": ce_loss.item(),
                    "ct_loss": ct_loss.item(),
                }
            )

        result = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return result if return_dict else result.to_tuple()

    def _cringe_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        classifier_labels: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return scalar ``(loss, ce_loss, ct_loss)`` for one batch.

        Each term is the sum of its per-token values divided by the number of
        non-ignored tokens that the loss function's mask selects for it.
        """
        if classifier_labels is None:
            raise ValueError(
                "`classifier_labels` must be provided when `labels` is given."
            )
        if self.model_args is None:
            raise ValueError(
                "`model_args` must be passed to the constructor to compute "
                "the loss."
            )

        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        # Enable model parallelism: every target lives on the logits' device.
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        classifier_labels = classifier_labels.to(shift_logits.device)

        if classifier_labels.dim() == 1:
            # One value per example: repeat it over all target positions.
            classifier_labels = classifier_labels.unsqueeze(1).repeat(
                1, shift_labels.shape[-1]
            )
        else:
            # Per-token values: shift them exactly like the labels.
            classifier_labels = classifier_labels[..., 1:].contiguous()

        loss_fct = ContrastiveCrossEntropyLoss(
            ct_loss_weight=1.0,
            num_pos_predictions=5,
            detach_positives_during_ct=False,  # TODO
            train_ct_on_positive_examples=(
                self.model_args.cringe_type == "gsm"
            ),
            train_ce_on_positive_examples=True,
            reduction="none",
            add_reduce_loss=self.model_args.cringe_add_reduce_loss,
        )
        # Flatten the tokens
        loss, ce_loss, ct_loss, ce_mask, ct_mask = loss_fct(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1),
            classifier_labels=classifier_labels.view(-1),
        )

        batch_size, target_len = shift_logits.size(0), shift_logits.size(1)

        def total(per_token: torch.Tensor) -> torch.Tensor:
            # Viewing as (batch, target_len) first keeps the original shape
            # check on what the loss function returned.
            return per_token.view(batch_size, target_len).sum(dim=1).sum()

        # Number of tokens that count for cross entropy and/or cringe loss.
        not_ignored = shift_labels.ne(-100)
        ce_mask = torch.logical_and(
            not_ignored, ce_mask.view(-1, shift_labels.shape[-1])
        )
        ct_mask = torch.logical_and(
            not_ignored, ct_mask.view(-1, shift_labels.shape[-1])
        )
        ce_tokens = ce_mask.long().sum()
        ct_tokens = ct_mask.long().sum()
        all_tokens = torch.logical_or(ce_mask, ct_mask).long().sum()

        # clamp(min=1): a batch with no tokens for one of the terms used to
        # divide by zero and turn the loss into NaN. This presumes the
        # unreduced loss is zero wherever its mask is off -- TODO confirm
        # against ContrastiveCrossEntropyLoss.
        ce_loss = total(ce_loss) / ce_tokens.clamp(min=1)
        ct_loss = total(ct_loss) / ct_tokens.clamp(min=1)

        if self.model_args.cringe_add_reduce_loss:
            loss = ce_loss + ct_loss
        else:
            loss = total(loss) / all_tokens.clamp(min=1)

        return loss, ce_loss, ct_loss
