from dataclasses import dataclass
import warnings
from typing import Dict, List, Any, Optional, Union

import torch
from torch import nn
from transformers import PreTrainedModel, PreTrainedTokenizer, GenerationMixin
from transformers.generation.logits_process import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxNewTokensCriteria,
    StoppingCriteriaList,
)

from inference_time_alignment.scorer import BaseScorer
from inference_time_alignment.utils import StopOnStringCriteria, extract_responses, get_truncated_responses


@dataclass(kw_only=True)
class EFTPosthocGenerationMixin(GenerationMixin):
    """
    Steer a base language model at inference time with implicit reward models.

    args:
        `base`:
            the base language model to be steered at inference time.
        `tune_r`, `base_r`:
            `tune_r` is a model, or a list of N models, fine-tuned from `base_r`.
            Each model from `tune_r` is combined with `base_r` to define an implicit
            reward model r_i = logp_{tune_r[i]} - logp_{base_r}.
            `base_r` falls back to `base` when it is not given.
        `w`:
            the linear weight (or list of N weights) for combining the implicit
            reward models, r = sum_{i=1}^N w_i * r_i. Must have as many entries
            as `tune_r`.

    The next-token logits returned by `__call__` are
          logits_{base} + r
        = logits_{base} + sum_{i=1}^{N}(w_i * r_i),
    where r_i = logits_{tune_r[i]} - logits_{base_r}; i.e. the resulting sampling
    distribution is proportional to p_{base} * exp(r).

    Attributes that are not found on this object are looked up on `base`.
    """
    base:     PreTrainedModel
    tune_r:   List[PreTrainedModel] | PreTrainedModel
    base_r:   Optional[PreTrainedModel] = None
    w:        List[float] | float

    def __post_init__(self):
        """Normalise `tune_r` / `w` to lists and default `base_r` to `base`."""
        # Wrap scalars so the rest of the class only ever deals with lists.
        if isinstance(self.tune_r, (list, tuple)):
            self.tune_r = list(self.tune_r)
        else:
            self.tune_r = [self.tune_r]
        if isinstance(self.w, (list, tuple)):
            self.w = list(self.w)
        else:
            self.w = [self.w]
        if self.base_r is None:
            self.base_r = self.base
        assert len(self.tune_r) == len(self.w), (
            f"`tune_r` has {len(self.tune_r)} model(s) but `w` has {len(self.w)} weight(s)"
        )

    def __getattribute__(self, name: str) -> Any:
        """Return the attribute from this object, falling back to `self.base`."""
        try:
            return super().__getattribute__(name)
        except AttributeError:
            # Without this guard a missing `base` would recurse forever,
            # because the fallback below itself needs `self.base`.
            if name == 'base':
                raise
            return getattr(self.base, name)

    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        """
        Delegate input preparation to `base`.

        `base` only understands its own cache, so it is handed the `'base'`
        entry of the combined cache dict; the full combined dict is put back
        into the result so that `__call__` receives every model's cache.
        """
        combined_cache = model_kwargs.get('past_key_values')
        if combined_cache is not None:
            model_kwargs['past_key_values'] = combined_cache['base']
        result = self.base.prepare_inputs_for_generation(input_ids, **model_kwargs)
        if combined_cache is not None:
            result['past_key_values'] = combined_cache
        return result

    def _parallel_decode(self, *args, past_key_values_list, **kwargs):
        """
        Run every model in `tune_r` on the same inputs, each with its own cache.

        Despite the name, the models are evaluated one after another.
        Returns the list of model outputs, in `tune_r` order.
        """
        return [
            model(*args, past_key_values=past_key_values_list[i], **kwargs)
            for i, model in enumerate(self.tune_r)
        ]

    @torch.no_grad()
    def __call__(self, *args, past_key_values=None, **kwargs):
        """
        Compute the steered logits for the last position.

        `past_key_values`, when given, is the dict produced by a previous call:
        `{'base': ..., 'tune_r': [...], 'base_r': ...}`. The returned object is
        the output of `base` with `logits` replaced by the combined last-position
        logits (sequence dimension of size 1) and `past_key_values` replaced
        by the combined cache dict.
        """
        if not past_key_values:
            past_key_values = {'base': None, 'tune_r': [None] * len(self.tune_r), 'base_r': None}

        base_outputs = self.base(*args, past_key_values=past_key_values['base'], **kwargs)
        base_logits  = base_outputs.logits[:, -1, :]

        if self.base_r is self.base:
            # Same model: reuse the forward pass instead of running it twice.
            base_r_outputs = base_outputs
            base_r_logits  = base_logits
        else:
            base_r_outputs = self.base_r(*args, past_key_values=past_key_values['base_r'], **kwargs)
            base_r_logits  = base_r_outputs.logits[:, -1, :]

        tune_r_outputs_list = self._parallel_decode(*args, past_key_values_list=past_key_values['tune_r'], **kwargs)
        tune_r_logits_list  = [tune_out.logits[:, -1, :] for tune_out in tune_r_outputs_list]

        # Stack the N reward models on a trailing axis so `w` broadcasts over it.
        tune_r_logits = torch.stack(tune_r_logits_list, dim=-1)
        w = torch.tensor(self.w).to(tune_r_logits.device)
        r = (w * (tune_r_logits - base_r_logits.unsqueeze(-1))).sum(-1)

        # The last dimensions of `base_logits` and `r` may differ in size;
        # only the shared leading entries are combined.
        dim_min = min(base_logits.size(1), r.size(1))
        logits = base_logits[:, :dim_min] + r[:, :dim_min]

        # The cache dict is fully built before it overwrites
        # `base_outputs.past_key_values`, which it reads from.
        combined_cache = {
            'base':   base_outputs.past_key_values,
            'tune_r': [tune_out.past_key_values for tune_out in tune_r_outputs_list],
            'base_r': base_r_outputs.past_key_values,
        }
        outputs = base_outputs
        outputs.logits = logits.unsqueeze(-2)
        outputs.past_key_values = combined_cache
        return outputs



@dataclass
class BeamTuningPosthocGenerationMixin(GenerationMixin):
    """
    Block-wise best-of-N beam sampling on top of a base language model.

    args:
        `base`:
            the language model used to propose tokens.
        `tokenizer`:
            the tokenizer used to decode the prompt and the sampled responses
            before they are handed to the scorer.

    Attributes that are not found on this object (e.g. `generation_config`,
    `config`, and the forward call itself) are looked up on `base`.
    """
    base: PreTrainedModel
    tokenizer: PreTrainedTokenizer

    def __getattribute__(self, name: str) -> Any:
        """Return the attribute from this object, falling back to `self.base`."""
        try:
            return super().__getattribute__(name)
        except AttributeError:
            # Without this guard a missing `base` would recurse forever,
            # because the fallback below itself needs `self.base`.
            if name == "base":
                raise
            return getattr(self.base, name)

    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        """Delegate input preparation to `base`."""
        return self.base.prepare_inputs_for_generation(input_ids, **model_kwargs)

    def _reorder_cache(self, past_key_values, beam_idx):
        """Delegate cache reordering to `base`."""
        return self.base._reorder_cache(past_key_values, beam_idx)

    @torch.no_grad()
    def bon_beam_sample(
        self,
        input_ids: torch.LongTensor,
        scorer: BaseScorer,
        max_new_tokens: Optional[int] = None,
        max_length: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        eos_strings: Optional[List[str]] = None,
        split_by_prompt_text: Optional[bool] = True,
        **kwargs,
    ):
        """
        Build the logits warpers and stopping criteria, then run `_bon_beam_sample`.

        args:
            `input_ids`: the prompt token ids.
            `scorer`: scores candidate responses; see `_bon_beam_sample`.
            `max_new_tokens`, `max_length`: length limits; at most one may be set.
            `temperature`, `top_k`, `top_p`: sampling warpers; each is applied
                only when given a truthy value.
            `eos_strings`: strings that end a response; one `StopOnStringCriteria`
                is created per string.
            `split_by_prompt_text`: forwarded to `_bon_beam_sample`.
            `**kwargs`: forwarded to `_bon_beam_sample` (e.g. `num_beams`,
                `attention_mask`); `use_cache` is forced to True.

        Side effect: when `generation_config.pad_token_id` is unset it is
        filled in from `generation_config.eos_token_id` (first entry if a list).
        """
        logits_warper = []
        if temperature: logits_warper.append(TemperatureLogitsWarper(temperature))
        if top_k: logits_warper.append(TopKLogitsWarper(top_k))
        if top_p: logits_warper.append(TopPLogitsWarper(top_p))
        logits_warper = LogitsProcessorList(logits_warper)

        stopping_criteria = []
        if eos_strings:
            stopping_criteria.extend(
                StopOnStringCriteria(input_ids.size(1), eos_string, self.tokenizer)
                for eos_string in eos_strings
            )
        assert not (max_new_tokens and max_length)
        if max_length: stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_new_tokens: stopping_criteria.append(MaxNewTokensCriteria(start_length=input_ids.size(1), max_new_tokens=max_new_tokens))
        stopping_criteria = StoppingCriteriaList(stopping_criteria)

        # `is None` rather than truthiness: a pad token id of 0 is valid and
        # must not be overwritten with the EOS id.
        if self.generation_config.pad_token_id is None:
            self.generation_config.pad_token_id = self.generation_config.eos_token_id
            if isinstance(self.generation_config.pad_token_id, list):
                self.generation_config.pad_token_id = self.generation_config.pad_token_id[0]

        kwargs.update({"use_cache": True})
        return self._bon_beam_sample(
            input_ids,
            scorer,
            stopping_criteria=stopping_criteria,
            logits_warper=logits_warper,
            eos_strings=eos_strings,
            split_by_prompt_text=split_by_prompt_text,
            **kwargs
        )

    @torch.no_grad()
    def _bon_beam_sample(
        self,
        input_ids: torch.LongTensor,
        scorer: BaseScorer,
        num_beams: Optional[int] = 4,
        num_candidates: Optional[int] = 4,
        block_len: Optional[int] = 10,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        return_dict_in_generate: Optional[bool] = None,
        eos_strings: Optional[List[str]] = None,
        split_by_prompt_text: Optional[bool] = True,
        **model_kwargs,
    ) -> Union[Dict, torch.LongTensor]:
        """
        Sample in blocks of `block_len` tokens, pruning with `scorer` after each block.

        The prompt is replicated `num_beams * num_candidates` times. After
        every `block_len` sampled tokens (or as soon as generation finishes)
        all sequences are decoded and scored; the `num_beams` best ones are
        kept and each is replicated `num_candidates` times for the next block.

        args:
            `scorer`: called with `{"response": ..., "eos": ...}`; its result is
                passed to `torch.topk` along dim 0.
            `eos_token_id`, `pad_token_id`: default to the values in
                `generation_config`. `eos_strings` and a truthy `eos_token_id`
                are mutually exclusive.
            `split_by_prompt_text`: if true, responses are extracted using the
                decoded prompt text, otherwise using the prompt length in tokens.
            `**model_kwargs`: forwarded to the model; `attention_mask`, when
                present, is replicated alongside `input_ids`.

        returns:
            If `return_dict_in_generate` is truthy, a dict with `"output_ids"`
            (the `num_beams` selected sequences, best first) and `"scores"`
            (their scores, in the same order). Otherwise the best sequence,
            with a leading batch dimension of size 1.

        NOTE(review): the prompt text is taken from `input_ids[0]` only, so this
        looks like it expects a single prompt (batch size 1) - confirm with callers.
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        assert not (eos_strings and eos_token_id)
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        # repeat input_ids and attention_mask
        num_sequences = num_beams * num_candidates
        input_ids = input_ids.repeat(num_sequences, 1)
        if model_kwargs.get("attention_mask") is not None:
            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(num_sequences, 1)

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        blocks_to_use = block_len
        prompt, prompt_len = self.tokenizer.decode(input_ids[0]), input_ids.size(1)

        this_peer_finished = False
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # Explicit `self.__call__` lookup: this class defines no `__call__`,
            # so the attribute access goes through `__getattribute__`. An
            # implicit `self(...)` would look `__call__` up on the type and
            # bypass that hook.
            outputs = self.__call__(
                **model_inputs,
                return_dict=True,
            )

            next_token_logits = outputs.logits[:, -1, :]
            next_token_logits = logits_processor(input_ids, next_token_logits)
            next_token_logits = logits_warper(input_ids, next_token_logits)
            probs = nn.functional.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            blocks_to_use -= 1

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, None):
                this_peer_finished = True

            # Keep sampling until the current block is full or generation is over.
            if blocks_to_use > 0 and not this_peer_finished:
                continue

            blocks_to_use = block_len

            if split_by_prompt_text:
                responses = extract_responses(input_ids, self.tokenizer, prompt=prompt)
            else:
                responses = extract_responses(input_ids, self.tokenizer, prompt_len=prompt_len)
            if eos_strings:
                responses, unfinished_sequences = get_truncated_responses(responses, eos_strings)
            beam_scores = scorer(
                {
                    "response": responses,
                    "eos": unfinished_sequences == 0,
                },
            )

            # `top_scores` is kept so the returned scores line up with the
            # reordered `input_ids` (indexing `beam_scores` directly would
            # use the pre-reorder order).
            top_scores, beam_idx = torch.topk(
                beam_scores,
                num_beams, dim=0, largest=True, sorted=True
            )

            # repeat beam_idx by candidate numbers; the first `num_beams` rows
            # after reordering are the selected beams, best first
            beam_idx = beam_idx.repeat(num_candidates)

            # reorder states
            input_ids = input_ids[beam_idx, :]
            unfinished_sequences = unfinished_sequences[beam_idx]

            if unfinished_sequences.max().item() == 0:
                this_peer_finished = True

            if model_kwargs.get("past_key_values") is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            if this_peer_finished:
                break

        if return_dict_in_generate:
            return {
                "output_ids": input_ids[:num_beams],
                "scores": top_scores,
            }
        else:
            return input_ids[0, None]