from typing import List, Tuple
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn import CrossEntropyLoss
from torch import LongTensor, Tensor, FloatTensor
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast


class LlamaForCausalLMChoiceCls(LlamaForCausalLM):
    """Llama causal LM with an additional two-way classification head.

    During training (``labels`` given) the loss is the language-modelling
    cross-entropy computed only on samples whose ``choice_rewards`` entry is
    positive, plus ``model_args.cls_weight`` times a classification
    cross-entropy computed from one pooled position per sequence.
    """

    def __init__(self, config, *model_args, **model_kargs):
        """Build the base model and attach the classification head.

        Args:
            config: Model configuration; ``hidden_size`` sizes the head.
            *model_args: Accepted for signature compatibility; unused.
            **model_kargs: Must contain the key ``"model_args"``. The stored
                object is read in ``forward`` for its ``cls_weight`` attribute.

        Raises:
            KeyError: If ``"model_args"`` is not supplied.
        """
        super().__init__(config)
        self.model_args = model_kargs["model_args"]
        self.num_labels = 2
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        # Called again so the freshly added ``score`` layer goes through the
        # same post-construction hook as the rest of the model.
        self.post_init()

    def cls_loss(self, scores, cls_labels):
        """Return the cross-entropy between ``scores`` and ``cls_labels``.

        Args:
            scores: Classification logits; flattened to ``(-1, num_labels)``.
            cls_labels: Class indices (0 or 1); flattened to one dimension.
                Floating-point values are cast to integer indices because the
                class-index form of cross-entropy requires an integer target.

        Returns:
            Scalar loss tensor (mean reduction).
        """
        targets = cls_labels.view(-1).to(device=scores.device, dtype=torch.long)
        loss_fct = CrossEntropyLoss()
        return loss_fct(scores.view(-1, self.num_labels), targets)

    def get_pooled_logits(self, input_ids, hidden_states, scorer):
        """Score every position and keep one position per sequence.

        The kept position is ``(number of non-pad entries in input_ids) - 1``
        for each row, i.e. the last non-pad token when padding is on the
        right. NOTE(review): with left padding this index does not point at
        the last real token - confirm the padding side used by callers.

        Args:
            input_ids: Tensor compared against ``config.pad_token_id`` to
                count non-pad entries per row. May be ``None``, in which case
                the last position of every row is used.
            hidden_states: Tensor whose first dimension is the batch and whose
                last dimension is accepted by ``scorer``.
            scorer: Callable mapping hidden states to per-position logits.

        Returns:
            Tensor with one row of logits per batch element.

        Raises:
            ValueError: If no pad token is configured and batch size is not 1.
        """
        cls_logits = scorer(hidden_states)
        batch_size = hidden_states.shape[0]
        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if pad_token_id is None or input_ids is None:
            # Without a pad token or ids there is nothing to count; fall back
            # to the final position.
            sequence_lengths = -1
        else:
            sequence_lengths = (
                torch.ne(input_ids, pad_token_id).sum(-1) - 1
            ).to(cls_logits.device)
        batch_index = torch.arange(batch_size, device=cls_logits.device)
        return cls_logits[batch_index, sequence_lengths]

    def forward(
        self,
        input_ids: LongTensor | None = None,
        attention_mask: Tensor | None = None,
        input_mask: Tensor | None = None,
        choice_rewards: FloatTensor | None = None,
        position_ids: LongTensor | None = None,
        past_key_values: List[FloatTensor] | None = None,
        inputs_embeds: FloatTensor | None = None,
        labels: LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> Tuple | CausalLMOutputWithPast:
        """Run the base model and, when ``labels`` is given, compute the loss.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache, output_attentions: Forwarded to the
                parent ``forward``.
            input_mask: Tensor handed to ``get_pooled_logits`` to locate the
                pooled position of each sequence. When ``None``,
                ``input_ids`` is used instead.
                NOTE(review): presumably a copy of the token ids with pad
                tokens marking ignored positions - confirm against the data
                collator.
            choice_rewards: Per-sample values; entries ``> 0`` select the
                samples that contribute to the LM loss, and the same tensor is
                the target of the classification loss. Required when
                ``labels`` is given. Assumed to have one entry per batch
                element - TODO confirm.
            labels: Token targets for the LM loss (shifted internally).
            output_hidden_states: Accepted for interface compatibility; hidden
                states are always requested because the classification head
                needs them.
            return_dict: Accepted for interface compatibility; a
                ``CausalLMOutputWithPast`` is always returned.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` set to ``None`` when
            ``labels`` is not provided.

        Raises:
            ValueError: If ``labels`` is given without ``choice_rewards``.
        """
        # Keyword arguments keep the call independent of the parent's
        # positional parameter order. ``return_dict=True`` is forced because
        # the code below reads fields by name.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        logits = outputs.logits
        loss = None

        if labels is not None:
            if choice_rewards is None:
                raise ValueError(
                    "`choice_rewards` is required when `labels` is provided."
                )

            # Standard next-token shift: position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

            # Only samples with a positive reward contribute to the LM loss.
            sample_mask = (choice_rewards > 0).to(shift_logits.device)
            kept_logits = shift_logits[sample_mask]
            kept_labels = shift_labels[sample_mask]
            if kept_labels.numel() > 0:
                loss_fct = CrossEntropyLoss()
                lm_loss = loss_fct(
                    kept_logits.view(-1, self.config.vocab_size),
                    kept_labels.view(-1),
                )
            else:
                # No positive sample in this batch: a mean over zero elements
                # would be NaN. Use a zero that stays attached to the graph.
                lm_loss = shift_logits.sum() * 0.0

            cls_weight = self.model_args.cls_weight
            if cls_weight > 0.0:
                pooling_ids = input_mask if input_mask is not None else input_ids
                scores = self.get_pooled_logits(
                    pooling_ids, outputs.hidden_states[-1], self.score
                )
                cls_loss = self.cls_loss(scores, choice_rewards)
                # Log from rank 0 only; also log when no process group is
                # initialised, where ``get_rank`` would raise.
                distributed = dist.is_available() and dist.is_initialized()
                if not distributed or dist.get_rank() == 0:
                    print({"lm_loss": lm_loss.item(), "cls": cls_loss.item()})
                loss = lm_loss + cls_weight * cls_loss.to(lm_loss.device)
            else:
                loss = lm_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
