from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from src.common.templates import DATA_TYPE_DICT
from src.model.modeling_utils import concat_all_gather


def get_embedding(embed, x):
    """Embed ``x`` with the embedding layer ``embed``.

    Args:
        embed: An ``nn.Embedding``-like module; its ``weight`` and lookup
            options (``padding_idx``, ``max_norm``, ...) are used.
        x: Either an integer tensor of token indices (int32/int64, on any
            device), or a non-integer tensor whose last dimension is
            multiplied against ``embed.weight`` (e.g. a softmax /
            Gumbel-Softmax distribution over the vocabulary).

    Returns:
        The looked-up embeddings for index input, or ``x @ embed.weight``
        for distribution input.
    """
    # Dispatch on dtype rather than `isinstance(x, torch.LongTensor)`: the
    # latter only matches CPU int64 tensors, so CUDA indices would wrongly be
    # sent to the matmul branch.
    if x.dtype in (torch.int32, torch.int64):
        return F.embedding(
            x,
            embed.weight,
            embed.padding_idx,
            embed.max_norm,
            embed.norm_type,
            embed.scale_grad_by_freq,
            embed.sparse,
        )
    # Gumbel-Softmax (or any soft distribution): weighted sum of embedding rows.
    return torch.matmul(x, embed.weight)


def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model``.

    The parameters are modified in place; nothing is returned.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)


class QTransBaseLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM that computes its own cross-entropy loss.

    The parent's built-in loss is bypassed (``labels=None`` is passed to it);
    instead the logits are compared with ``labels`` position by position, so
    callers must supply labels that are already aligned with the logits.
    """

    def __init__(self, config, *model_args, **model_kwargs) -> None:
        # Extra positional/keyword arguments are accepted but not used here.
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the parent forward pass and attach an unshifted loss.

        Args:
            labels: Optional targets flattened and compared against the
                flattened logits with ``CrossEntropyLoss`` (default settings).
            return_dict: ``True`` for a ``CausalLMOutputWithPast``, ``False``
                for a plain tuple, ``None`` to follow
                ``config.use_return_dict``.
            Remaining arguments are forwarded unchanged to
            ``LlamaForCausalLM.forward``.

        Returns:
            ``CausalLMOutputWithPast`` or, in tuple mode,
            ``(loss?, logits, *other_outputs)`` where ``loss`` is present only
            when ``labels`` was given.
        """
        # Resolve the tri-state flag once; previously `None` silently meant
        # "tuple" and `False` crashed on `outputs.logits`.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Always ask the parent for a structured output so attribute access
        # below is valid regardless of the caller's `return_dict`.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,  # the loss is computed below instead
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        logits = outputs.logits
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size),
                labels.view(-1).to(logits.device),
            )

        if not return_dict:
            # to_tuple() drops None fields; element 0 is the logits.
            output = (logits,) + outputs.to_tuple()[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class QTransLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM that overwrites a span of its input with a translated question.

    Just for training and evaluation.

    When ``question_ids`` is given, a translator model produces per-position
    vocabulary logits for the question. These are turned into a (soft or
    Gumbel-sampled) distribution, embedded with this model's token embedding,
    and written into ``input_ids``' embeddings over ``[repeat_starts,
    repeat_ends)`` before the normal forward pass.

    Required ``model_kwargs``:
        model_args: object exposing ``trans_self``, ``trans_freeze``,
            ``trans_sep_lm_head``, ``trans_softmax`` and ``trans_hard``.
        is_eval / trans_config: when ``trans_self`` is false and ``is_eval``
            is truthy, a separate ``LlamaForCausalLM(trans_config)`` is built
            as the translator; otherwise ``trans_model`` is left ``None`` and
            must be assigned before ``question_ids`` is used.
    """

    def __init__(self, config, *model_args, **model_kwargs) -> None:
        super().__init__(config)
        self.model_args = model_kwargs["model_args"]
        # Default temperature so forward() works even if
        # set_gumbel_temperature() has not been called yet.
        self.tau = 1.0
        if self.model_args.trans_self:
            # The model translates with its own decoder (a bare LlamaModel).
            self.trans_model = self.model
        elif model_kwargs.get("is_eval", False):
            self.trans_model = LlamaForCausalLM(model_kwargs["trans_config"])
        else:
            self.trans_model = None
        if self.model_args.trans_freeze:
            if not self.model_args.trans_self:
                # Freeze the main model
                freeze_model(self.model)
            else:
                # Freezing is not supported when the translator is the main model.
                raise NotImplementedError()
        if self.model_args.trans_sep_lm_head:
            self.trans_lm_head = nn.Linear(
                config.hidden_size, config.vocab_size, bias=False
            )
        else:
            self.trans_lm_head = None

    def set_gumbel_temperature(self, tau):
        """Set the temperature used by the softmax / Gumbel-Softmax in ``forward``."""
        self.tau = tau

    def _translator_modules(self):
        """Return the ``(decoder, lm_head)`` pair used to translate the question."""
        trans_model = self.trans_model
        if trans_model is None:
            raise RuntimeError(
                "trans_model is not set; assign a translator model before "
                "passing question_ids"
            )
        # A full LlamaForCausalLM exposes its decoder as `.model`. With
        # trans_self the attribute already *is* the bare decoder, which has
        # neither `.model` nor `.lm_head`, so fall back to the decoder itself
        # and to this model's own lm_head.
        decoder = getattr(trans_model, "model", trans_model)
        if self.trans_lm_head is not None:
            return decoder, self.trans_lm_head
        return decoder, getattr(trans_model, "lm_head", self.lm_head)

    def forward(
        self,
        input_ids: torch.LongTensor = None,  # [start, question, padding, output, end]
        question_ids: Optional[torch.LongTensor] = None,  # Add start index
        repeat_starts: Optional[torch.LongTensor] = None,
        repeat_ends: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_input_type: Optional[str] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast, torch.Tensor]:
        """Optionally translate the question, splice it in, and run the LM.

        Args:
            input_ids: Token ids of the full sequence.
            question_ids: Token ids fed to the translator; the logits at
                position 0 are discarded. ``None`` skips translation entirely.
            repeat_starts / repeat_ends: Per-sample start/end indices of the
                span in ``input_ids`` that receives the translated question.
            return_input_type: ``None`` runs the full forward pass. ``"id"``
                returns a copy of ``input_ids`` with the span replaced by the
                argmax translated tokens. Any other value returns the spliced
                ``inputs_embeds``; ``"gumbel"`` keeps Gumbel-Softmax sampling
                while other values use a plain softmax (unless
                ``model_args.trans_softmax`` is set).
            Remaining arguments are forwarded to ``LlamaForCausalLM.forward``.

        Returns:
            The parent model's output, or a tensor as described under
            ``return_input_type``.

        Raises:
            ValueError: ``question_ids`` is given without ``input_ids``,
                ``repeat_starts`` or ``repeat_ends``.
            RuntimeError: ``question_ids`` is given but no translator is set.
        """
        # Transform the question into one-hot embedding
        if question_ids is not None:
            if input_ids is None or repeat_starts is None or repeat_ends is None:
                raise ValueError(
                    "input_ids, repeat_starts and repeat_ends are required "
                    "when question_ids is given"
                )
            decoder, lm_head = self._translator_modules()
            question_hidden_states = decoder(question_ids)[0]
            # Drop the first position's logits (the start token).
            question_logits = lm_head(question_hidden_states)[:, 1:]
            # NOTE: an auxiliary "copy question" loss was once computed here;
            # it is currently disabled.
            if self.model_args.trans_softmax:
                question_trans = F.softmax(question_logits / self.tau, dim=-1)
            elif return_input_type is not None and return_input_type != "gumbel":
                question_trans = F.softmax(question_logits, dim=-1)
            else:
                question_trans = F.gumbel_softmax(
                    question_logits,
                    tau=self.tau,
                    hard=self.model_args.trans_hard,
                    dim=-1,
                )
            # TODO: not work for left padding
            if return_input_type == "id":
                qtrans_ids = torch.argmax(question_trans, dim=-1)
                # Work on a copy so the caller's batch is not modified.
                input_ids = input_ids.clone()
                for i, (s, e) in enumerate(zip(repeat_starts, repeat_ends)):
                    input_ids[i, s:e] = qtrans_ids[i, : (e - s)]
                return input_ids
            # Embed the full sequence, then overwrite the span with the
            # translated question's embeddings.
            question_embeds = get_embedding(self.model.embed_tokens, question_trans)
            inputs_embeds = self.model.embed_tokens(input_ids)
            for i, (s, e) in enumerate(zip(repeat_starts, repeat_ends)):
                inputs_embeds[i, s:e] = question_embeds[i][: (e - s)]
            if return_input_type is not None:
                return inputs_embeds
            # The parent forward accepts only one of input_ids / inputs_embeds.
            input_ids = None
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs
