from transformers import LlamaForCausalLM


class LlamaForCausalLMRepeat(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose generation step understands a ``-1`` mask value.

    Entries of ``attention_mask`` equal to ``-1`` are treated two ways:

    * as ``0`` (not attended) in the mask handed to the model;
    * as ``1`` (a real token) when position ids are derived via cumulative sum,
      so they still advance the position counter for later tokens.

    NOTE(review): ``-1`` presumably marks tokens that occupy a position slot but
    must not be attended to; confirm against the code that builds the masks.
    """

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        """Build the keyword arguments for one forward pass during generation.

        Args:
            input_ids: Token id tensor indexed as ``[:, seq]``. When
                ``past_key_values`` is truthy, only the last column is kept.
            past_key_values: Cached state from earlier steps, or ``None`` on
                the first step.
            attention_mask: Optional mask tensor. ``-1`` entries are described
                in the class docstring; other values are passed through.
            inputs_embeds: Optional embeddings. They are used instead of
                ``input_ids`` only when ``past_key_values`` is ``None``.
            **kwargs: ``position_ids`` (used as-is when given) and ``use_cache``
                are read from here.

        Returns:
            A dict containing either ``inputs_embeds`` or ``input_ids``, plus
            ``position_ids``, ``past_key_values``, ``use_cache`` and
            ``attention_mask``. The mask has ``-1`` replaced by ``0``, or is
            ``None`` when no mask was given.
        """
        if past_key_values:
            # With a cache, only the newest token has to be fed forward.
            input_ids = input_ids[:, -1:]

        # The mask the model actually sees: -1 entries become 0 (not attended).
        # It is computed outside the position-id branch so it is defined even
        # when no mask is passed or the caller supplies position_ids.
        # (Previously those paths raised NameError.)
        if attention_mask is not None:
            model_attention_mask = attention_mask.masked_fill(attention_mask == -1, 0)
        else:
            model_attention_mask = None

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation. -1 entries
            # count as 1 here, so they still advance the position counter.
            position_ids = (
                attention_mask.masked_fill(attention_mask == -1, 1).long().cumsum(-1)
                - 1
            )
            # Positions the model does not attend to (original 0 or -1) get a
            # dummy id of 1.
            position_ids.masked_fill_(model_attention_mask == 0, 1)
            if past_key_values:
                # Keep only the newest token's position, matching input_ids.
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": model_attention_mask,
            }
        )
        return model_inputs
