from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.distributed as dist
from torch import LongTensor, Tensor, FloatTensor
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from src.common.templates import DATA_TYPE_DICT
from src.model.modeling_utils import concat_all_gather
from src.model.modeling_rank import AlignmentLlamaForCausalLM


class AlignmentComLlamaForCausalLM(AlignmentLlamaForCausalLM):
    """Alignment Llama model with an additional "com" ranking loss.

    On top of whatever loss the parent ``forward`` returns, this class scores
    the batch rows selected by ``strategy_ids`` and adds a ranking (or
    classification) loss computed against ``strategy_rewards``. The auxiliary
    loss is only active when ``labels`` is given and
    ``model_args.rank_com_weight`` is positive.
    """

    def __init__(self, config, *model_args, **model_kargs):
        """Build the parent model and, when needed, the ``com_score`` head.

        NOTE(review): ``self.model_args`` is read here but not assigned in
        this class; presumably the parent ``__init__`` sets it -- confirm.
        """
        super().__init__(config, *model_args, **model_kargs)
        self.num_labels = 2
        # The linear scoring head is only created for the hidden-state pooler
        # or the "cls" objective; the "seq_prob" pooler scores from LM logits.
        if (
            self.model_args.rank_com_pooler_type == "hidden_state"
            or self.model_args.rank_com_type == "cls"
        ):
            self.com_score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def forward(
        self,
        input_ids: LongTensor = None,
        attention_mask: Tensor | None = None,
        rewards: FloatTensor | None = None,
        strategy_mask: Tensor | None = None,
        strategy_ids: Tensor | None = None,
        strategy_rewards: FloatTensor | None = None,
        position_ids: LongTensor | None = None,
        past_key_values: List[FloatTensor] | None = None,
        inputs_embeds: FloatTensor | None = None,
        labels: LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> Tuple | CausalLMOutputWithPast:
        """Run the parent forward pass and optionally add the "com" loss.

        Args:
            strategy_mask: Passed to the pooling helper together with the
                selected hidden states (unused by the "seq_prob" pooler).
            strategy_ids: Index used to select rows of ``logits``, ``labels``,
                ``input_ids`` and the last hidden state for scoring.
            strategy_rewards: Targets handed to the selected ranking loss.
            All other arguments are forwarded positionally to the parent
            ``forward``.

        Returns:
            A ``CausalLMOutputWithPast`` carrying the parent's logits, cache,
            hidden states and attentions. ``loss`` is the parent's loss, plus
            ``rank_com_weight * rank_loss`` when the auxiliary loss is active.

        Raises:
            ValueError: If ``model_args.rank_com_type`` is not one of
                "stable", "pangu", "cross" or "cls" while the auxiliary loss
                is active.
        """
        outputs = super().forward(
            input_ids,
            attention_mask,
            rewards,
            position_ids,
            past_key_values,
            inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
        # Default to the parent's results so that inference calls (no labels)
        # or a disabled auxiliary loss still return a usable loss and logits
        # instead of hitting an unbound local / dropping the logits.
        loss = outputs.loss
        logits = outputs.logits

        if labels is not None and self.model_args.rank_com_weight > 0.0:
            rank_com_type = self.model_args.rank_com_type
            # Fail fast on a misconfigured objective, before any scoring work.
            if rank_com_type not in ("stable", "pangu", "cross", "cls"):
                raise ValueError(f"Unsupported rank_com_type: {rank_com_type!r}")

            if rank_com_type == "cls":
                # NOTE(review): relies on outputs.hidden_states being
                # populated by the parent forward -- confirm it forces
                # output_hidden_states when this loss is enabled.
                scores = self.get_pooled_logits(
                    strategy_mask,
                    outputs.hidden_states[-1][strategy_ids],
                    self.com_score,
                )
            elif self.model_args.rank_com_pooler_type == "seq_prob":
                # Standard next-token shift: position t predicts token t + 1.
                shift_logits = logits[strategy_ids][..., :-1, :].contiguous()
                shift_labels = labels[strategy_ids][..., 1:].contiguous()
                shift_labels = shift_labels.to(shift_logits.device)
                scores = self.get_sequence_prob_score(
                    input_ids[strategy_ids], shift_logits, shift_labels
                )
            else:
                # NOTE(review): self.com_score only exists when the pooler
                # type is "hidden_state" (see __init__); any other non
                # "seq_prob" pooler value would fail here -- confirm config.
                scores = self.get_eos_score(
                    strategy_mask,
                    outputs.hidden_states[-1][strategy_ids],
                    self.com_score,
                )

            if rank_com_type == "stable":
                rank_loss = self.stable_alignment(scores, strategy_rewards)
            elif rank_com_type == "pangu":
                rank_loss = self.pangu_loss(scores, strategy_rewards)
            elif rank_com_type == "cross":
                rank_loss = self.listwise_loss(scores, strategy_rewards)
            else:  # "cls", guaranteed by the validation above
                rank_loss = self.cls_loss(scores, strategy_rewards)

            # Log from a single process. get_rank() raises when no process
            # group has been initialised, so treat that case as "rank 0".
            is_distributed = dist.is_available() and dist.is_initialized()
            if not is_distributed or dist.get_rank() == 0:
                print({"rank_com": rank_loss.item()})

            # Scale by the same weight that gates this branch; the previous
            # code multiplied by ``rank_weight``, which belongs to a different
            # loss term and could silently zero out or mis-scale this one.
            loss = outputs.loss + self.model_args.rank_com_weight * rank_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
