from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from .utils import MoLoRALinear

from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2Model, Qwen2ForCausalLM
from ming.model.modeling_internlm2 import InternLM2DecoderLayer, InternLM2MLP, InternLM2Model, InternLM2ForCausalLM, _import_flash_attn
from ming.model.configuration_internlm2 import InternLM2Config
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPast
import torch.nn as nn
import torch 
import torch.nn.functional as F 
from typing import Optional, Tuple, Union, List
import warnings
from transformers.utils import logging
from dataclasses import dataclass


# Module-level logger from transformers' logging utilities; used by
# MoLoRAInternLM2Model.forward to warn when use_cache is disabled under gradient checkpointing.
logger = logging.get_logger(__name__)



@dataclass
class BaseModelOutputWithPastLogitLoss(ModelOutput):
    """
    Model output with past key/value states plus per-layer MoLoRA logit statistics.

    The first four fields have the same names and meaning as transformers'
    ``BaseModelOutputWithPast``; ``logit_loss`` is the extra field.

    Args:
        last_hidden_state (`torch.FloatTensor`):
            Hidden states produced by the model (in ``MoLoRAInternLM2Model`` this is
            taken after the final norm).
        past_key_values (*optional*):
            Per-layer cached key/value states, present when ``use_cache`` is enabled.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden states of every layer, present when ``output_hidden_states`` is enabled.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Attention weights of every layer, present when ``output_attentions`` is enabled.
        logit_loss (`tuple(torch.FloatTensor)`, *optional*):
            One entry per decoder layer: the logit sum returned by that layer's MoLoRA
            feed-forward block. Present only when logit-loss output is enabled,
            otherwise ``None``.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # A tuple (one item per decoder layer), not a single tensor -- see the forward pass
    # of MoLoRAInternLM2Model, which accumulates it layer by layer.
    logit_loss: Optional[Tuple[torch.FloatTensor, ...]] = None

class MoLoRAQwenMLPDeploy(Qwen2MLP):
    """Qwen2 MLP whose gate/up/down projections are ``MoLoRALinear`` layers.

    The projections are configured from the LoRA / expert fields of ``config``.
    When ``config.output_logit_loss`` is truthy, each projection is called as
    returning an ``(output, logit_sum)`` pair and ``forward`` returns
    ``(output, mean_of_the_three_logit_sums)``. Otherwise ``forward`` defers to
    ``Qwen2MLP.forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        molora_kwargs = dict(
            r=config.r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            num_experts=config.num_experts,
            num_experts_per_token=config.num_experts_per_token,
            share_expert=getattr(config, "share_expert", False),
            expert_sampling=bool(config.expert_selection == 'sampling'),
            use_rslora=getattr(config, "use_rslora", False),
            use_logit_sum=getattr(config, "output_logit_loss", False),
        )
        self.use_logit_sum = molora_kwargs['use_logit_sum']

        in_dim, mid_dim = self.hidden_size, self.intermediate_size
        # Creation order (gate, up, down) is kept so parameter initialisation order is unchanged.
        self.gate_proj = MoLoRALinear(in_dim, mid_dim, bias=False, **molora_kwargs)
        self.up_proj = MoLoRALinear(in_dim, mid_dim, bias=False, **molora_kwargs)
        self.down_proj = MoLoRALinear(mid_dim, in_dim, bias=False, **molora_kwargs)

    def forward(self, x):
        """Apply the gated MLP; also return the averaged logit sum when enabled."""
        if not self.use_logit_sum:
            return super().forward(x)

        gate_out, gate_ls = self.gate_proj(x)
        up_out, up_ls = self.up_proj(x)
        down_out, down_ls = self.down_proj(self.act_fn(gate_out) * up_out)
        # Plain average of the three projections' logit sums; a non-trivial weighting
        # is left as future research.
        return down_out, (gate_ls + up_ls + down_ls) / 3
        
def _average_logit_sums(gate_logit_sum, up_logit_sum, down_logit_sum):
    """Logit grouping mode 1 (absolute loss): element-wise mean of the three logit sums."""
    return (gate_logit_sum + up_logit_sum + down_logit_sum) / 3


def _stack_logit_sums(gate_logit_sum, up_logit_sum, down_logit_sum):
    """Logit grouping mode 2 (relative loss): stack the three logit sums along a new dim 0."""
    return torch.stack([gate_logit_sum, up_logit_sum, down_logit_sum], dim=0)


class MoLoRAInternLM2MLP(InternLM2MLP):
    """InternLM2 MLP that can also return a combined logit sum from its projections.

    When ``config.output_logit_loss`` is truthy, ``w1`` (gate), ``w3`` (up) and
    ``w2`` (down) are each called as returning an ``(output, logit_sum)`` pair.
    ``forward`` then returns ``(output, grouped_logit_sum)``, where the grouping
    function is picked from ``LOGIT_GROUPING`` by the flag's value. Otherwise
    ``forward`` defers to ``InternLM2MLP.forward``.

    NOTE(review): unlike ``MoLoRAQwenMLPDeploy``, this class does not itself swap
    ``w1``/``w2``/``w3`` for ``MoLoRALinear`` (that code was commented out). The
    pair-returning projections are presumably installed elsewhere -- confirm.
    """

    # Maps the value of ``config.output_logit_loss`` to the function that combines
    # the (gate, up, down) logit sums: 1 for absolute loss, 2 for relative loss.
    # ``True`` selects mode 1 because ``True == 1``.
    # Kept at class level with named functions (instead of per-instance lambdas)
    # so the module stays picklable, e.g. for ``torch.save(model)``.
    LOGIT_GROUPING = {
        1: _average_logit_sums,
        2: _stack_logit_sums,
    }

    def __init__(self, config):
        super().__init__(config)
        self.output_logit_loss = getattr(config, "output_logit_loss", False)

    def forward(self, x):
        """Run the gated MLP.

        Returns:
            The MLP output alone when logit-loss output is disabled. Otherwise
            ``(output, logit_sum)``, where ``logit_sum`` is the mean (mode 1) or the
            dim-0 stack (mode 2) of the gate/up/down logit sums.

        Raises:
            ValueError: if ``output_logit_loss`` is truthy but not a key of
                ``LOGIT_GROUPING``.
        """
        if not self.output_logit_loss:
            return super().forward(x)

        # Validate the mode before doing any projection work.
        grouping_func = self.LOGIT_GROUPING.get(self.output_logit_loss)
        if grouping_func is None:
            raise ValueError(
                f"Unsupported output_logit_loss={self.output_logit_loss!r}; "
                f"expected one of {sorted(self.LOGIT_GROUPING)}"
            )

        gate_output, gate_logit_sum = self.w1(x)
        up_output, up_logit_sum = self.w3(x)
        down_output, down_logit_sum = self.w2(self.act_fn(gate_output) * up_output)

        return down_output, grouping_func(gate_logit_sum, up_logit_sum, down_logit_sum)

class MoLoRAInternLM2DecoderLayer(InternLM2DecoderLayer):
    """InternLM2 decoder layer whose feed-forward block is a ``MoLoRAInternLM2MLP``.

    On top of the parent layer's outputs, ``forward`` can append the feed-forward
    block's logit sum as the last element of the returned tuple.
    """

    def __init__(self, config: InternLM2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.feed_forward = MoLoRAInternLM2MLP(config)

        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        output_logit_loss: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.Tensor`, *optional*): attention mask as prepared by the
                calling model (the raw 2D padding mask or `None` for flash attention,
                otherwise the output of `_prepare_decoder_attention_mask`).
            position_ids (`torch.LongTensor`, *optional*): positions passed to the attention block.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_logit_loss (`bool`, *optional*): if truthy, append the feed-forward
                block's logit sum as the last element of the returned tuple.
            kwargs: forwarded unchanged to `self.attention`.

        Returns:
            `(hidden_states, [self_attn_weights], [present_key_value], [logit_sum])`.
            Each bracketed entry is included only when `output_attentions`,
            `use_cache` and `output_logit_loss` (respectively) are truthy.

        Raises:
            ValueError: if `output_logit_loss` is truthy but the feed-forward block
                returned a plain tensor instead of an `(output, logit_sum)` pair.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.attention_norm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        ffn_output = self.feed_forward(hidden_states)
        # MoLoRAInternLM2MLP decides whether to return an (output, logit_sum) pair from
        # its *config* flag, while this method is driven by the per-call flag. Inspect
        # the actual return value so that a disagreement between the two can neither
        # add a tuple to the residual nor silently unpack a tensor along its first
        # (batch) dimension.
        if isinstance(ffn_output, tuple):
            hidden_states, logit_sum = ffn_output
        else:
            hidden_states, logit_sum = ffn_output, None
        if output_logit_loss and logit_sum is None:
            raise ValueError(
                "output_logit_loss was requested but the feed-forward block did not return a "
                "logit sum; enable `output_logit_loss` in the model config."
            )
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        if output_logit_loss:
            outputs += (logit_sum,)
        return outputs
    

class MoLoRAInternLM2Model(InternLM2Model):
    """InternLM2 backbone whose decoder layers are ``MoLoRAInternLM2DecoderLayer``.

    ``forward`` re-implements the backbone pass (token embedding, attention-mask
    preparation, decoder stack, final norm). It also collects, per decoder layer,
    the logit sum produced by the MoLoRA feed-forward block when logit-loss output
    is enabled, and exposes it as ``logit_loss``.
    """
    def __init__(self, config: InternLM2Config):
        super().__init__(config)
        # Replace the decoder stack built by the parent constructor with MoLoRA layers.
        # NOTE(review): this runs after InternLM2Model.__init__ has returned. If the parent
        # applies weight initialisation / post-processing there, these new layers are not
        # covered by it -- confirm (pretrained weights loaded later would overwrite them).
        self.layers = nn.ModuleList(
            [MoLoRAInternLM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
    
    # @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_logit_loss: Optional[bool] = None,
    ) -> Union[Tuple, Union[BaseModelOutputWithPastLogitLoss, BaseModelOutputWithPast]]:
        """Run the embedding layer, the MoLoRA decoder stack and the final norm.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. ``None`` values
        of ``use_cache``, ``output_attentions``, ``output_hidden_states`` and
        ``return_dict`` fall back to the corresponding ``self.config`` settings.

        Args:
            output_logit_loss: whether to collect each layer's feed-forward logit sum.
                ``None`` falls back to ``config.output_logit_loss`` (0 if the config has
                no such attribute). The resolved value is passed unchanged to every
                decoder layer.

        Returns:
            When ``return_dict`` is truthy, a ``BaseModelOutputWithPastLogitLoss`` whose
            ``logit_loss`` is a tuple with one entry per layer (``None`` when disabled).
            Otherwise, a tuple of the non-``None`` values among ``(hidden_states,
            next_cache, all_hidden_states, all_self_attns, all_router_logits)`` in that
            order. Because ``None`` entries are dropped, positions shift with the flags.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        output_logit_loss = output_logit_loss if output_logit_loss is not None else getattr(self.config, "output_logit_loss", 0)
        
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.attn_implementation == "flash_attention_2":
            _import_flash_attn()

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Cached length is read from dim 2 of the first layer's first cache tensor.
        # Assumes the legacy per-layer (key, value) tuple layout -- TODO confirm.
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        # Default positions continue from the end of the cached prefix; shape (1, seq_length).
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)

        if self.config.attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            # (and dropped entirely when it contains no padding, i.e. no 0 entries)
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # Non-flash paths: default to an all-ones mask covering cache + new tokens,
            # then expand it via the parent's _prepare_decoder_attention_mask.
            if attention_mask is None:
                attention_mask = torch.ones(
                    (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
                )
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # No separate position embedding is added here; the layers start from the token embeddings.
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        # Despite the name, this collects each layer's feed-forward logit sum (returned as logit_loss).
        all_router_logits = () if output_logit_loss else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        # `inputs` = (hidden_states, attention_mask, position_ids, None) from the
                        # checkpoint call below; the remaining decoder-layer parameters are bound
                        # positionally in signature order: output_attentions, use_cache=None,
                        # output_logit_loss.
                        return module(*inputs, output_attentions, None, output_logit_loss)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    output_logit_loss=output_logit_loss
                )

            # Decoder-layer output layout: (hidden, [attn_weights], [present_kv], [logit_sum]).
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            
            # When requested, the logit sum is always the last element of the layer output.
            if output_logit_loss:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None)
        return BaseModelOutputWithPastLogitLoss(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            logit_loss=all_router_logits,
        )
        

class MoLoRAInternLM2ForCausalLM(InternLM2ForCausalLM):
    """InternLM2 causal-LM wrapper that uses ``MoLoRAInternLM2Model`` as its backbone.

    The parent constructor runs first; ``self.model`` is then (re)assigned to a
    MoLoRA backbone built from the same ``config``.
    """

    def __init__(self, config):
        super().__init__(config)
        molora_backbone = MoLoRAInternLM2Model(config)
        self.model = molora_backbone
    
