# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from fairseq2.data import Collater
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.data.text.text_tokenizer import TextTokenizer
from fairseq2.data.typing import StringLike
from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
from fairseq2.generation.logits_processor import LogitsProcessor
from fairseq2.memory import MemoryBlock
from fairseq2.typing import DataType, Device
from torch import Tensor
from enum import Enum, auto

from . import (
    UnitYGenerator,
    UnitYModel,
    load_m4t_model,
    load_m4t_tokenizer,
)


class NGramRepeatBlockProcessor(LogitsProcessor):
    """Logits processor that blocks repeated n-grams during decoding.

    For every hypothesis, any token that would complete an n-gram of size
    ``no_repeat_ngram_size`` already present earlier in that same hypothesis
    has its next-step log probability set to ``-inf``.
    """

    def __init__(self, no_repeat_ngram_size: int) -> None:
        """
        :param no_repeat_ngram_size:
            Length of the n-grams that may not repeat. Must be >= 1; a value
            of 1 bans every token that already occurs in the hypothesis.

        :raises ValueError:
            If ``no_repeat_ngram_size`` is smaller than 1.
        """
        if no_repeat_ngram_size < 1:
            raise ValueError(
                f"`no_repeat_ngram_size` must be >= 1, but is {no_repeat_ngram_size}."
            )
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def __call__(self, seqs: Tensor, lprobs: Tensor) -> None:
        """Ban tokens that would repeat an n-gram, modifying ``lprobs`` in place.

        :param seqs: Tokens generated so far, shaped (N, B, S).
        :param lprobs: Next-step log probabilities, shaped (N, B, V).
        """
        batch_size, beam_size, vocab_size = lprobs.size()
        step_nr = seqs.size(2) - 1
        # (N, B, S) -> (N * B, S); only read, so a copy would be harmless.
        seqs = seqs.reshape(-1, seqs.size(2))
        # (N, B, V) -> (N * B, V); must stay a `view` so the in-place writes in
        # `_no_repeat_ngram` reach the caller's tensor.
        lprobs = lprobs.view(-1, vocab_size)
        self._no_repeat_ngram(seqs, lprobs, batch_size, beam_size, step_nr)

    def _no_repeat_ngram(
        self,
        seqs: Tensor,
        lprobs: Tensor,
        batch_size: int,
        beam_size: int,
        step_nr: int,
    ) -> Tensor:
        """Set to -inf the lprobs of tokens that would complete a seen n-gram.

        :param seqs: The generated sequences of tokens for the first
            `step_nr` steps of decoding (N * B, step_nr + 1)
        :param lprobs: The next-step log probability reshaped to (N * B, V);
            modified in place.
        :param batch_size: The batch size.
        :param beam_size: The beam size.
        :param step_nr: Step number for decoding.

        :returns:
            modified lprobs tensor with banned tokens set to -inf
        """
        ngram_size = self.no_repeat_ngram_size

        # Number of complete n-grams in a sequence of length `step_nr + 1`.
        num_ngrams = step_nr + 2 - ngram_size
        if num_ngrams <= 0:
            # Too few tokens generated for any n-gram to be repeated yet.
            return lprobs

        cpu_tokens: List[List[int]] = seqs.cpu().tolist()
        for hyp_idx in range(batch_size * beam_size):
            tokens = cpu_tokens[hyp_idx]

            # The trailing (n - 1) tokens. Slicing from `len - (n - 1)` rather
            # than `-(n - 1)` keeps this correct for n == 1, where `tokens[-0:]`
            # would wrongly yield the whole sequence instead of [].
            suffix = tokens[len(tokens) - (ngram_size - 1) :]

            # Every earlier occurrence of `suffix` bans the token that followed it.
            banned = [
                tokens[i + ngram_size - 1]
                for i in range(num_ngrams)
                if tokens[i : i + ngram_size - 1] == suffix
            ]
            if banned:
                lprobs[hyp_idx, banned] = -torch.inf

        return lprobs



class Task(Enum):
    """Inference tasks accepted by ``Translator.predict``.

    Each task fixes an (input, output) modality pair:

    - S2ST: speech in, speech out
    - S2TT: speech in, text out
    - T2ST: text in, speech out
    - T2TT: text in, text out
    - ASR:  speech in, text out (handled like S2TT)
    """

    # Explicit values match what `auto()` assigns (1, 2, 3, ...).
    S2ST = 1
    S2TT = 2
    T2ST = 3
    T2TT = 4
    ASR = 5


class Modality(Enum):
    """Kind of data on one side (input or output) of a task.

    The member's string ``value`` is what gets handed to the generator.
    """

    # Audio: a file path or waveform tensor on input.
    SPEECH = "speech"
    # Text: a plain string on input.
    TEXT = "text"


class Translator(nn.Module):
    """Bundle an M4T model with the pre-processing needed to run inference.

    Holds the model, its text tokenizer, an audio decoder, a waveform-to-fbank
    converter and a padding collater, and exposes :meth:`predict` as the single
    entry point for every task in :class:`Task`.
    """

    def __init__(
        self,
        model_path: str,
        device: Device = torch.device("cuda"),
        dtype: DataType = torch.float16,
    ):
        """
        :param model_path:
            Model name/path passed to ``load_m4t_model`` and
            ``load_m4t_tokenizer``.
        :param device:
            Device the model and the audio/fbank pre-processing run on.
        :param dtype:
            Data type of the model and of the fbank features. Forced to
            ``torch.float32`` whenever ``device`` is a CPU device.
        """
        super().__init__()
        # Force float32 on CPU (the float16 default targets GPU execution).
        # Compare the device *type* so "cpu:0" or the string "cpu" also match.
        if torch.device(device).type == "cpu":
            dtype = torch.float32
        # Load the model.
        self.model: UnitYModel = load_m4t_model(model_path, device, dtype)
        self.text_tokenizer = load_m4t_tokenizer(model_path)
        self.device = device
        # Audio is decoded in float32 regardless of `dtype`; the fbank
        # converter below casts to `dtype`.
        self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
        self.convert_to_fbank = WaveformToFbankConverter(
            num_mel_bins=80,
            waveform_scale=2**15,
            channel_last=True,
            standardize=True,
            device=device,
            dtype=dtype,
        )
        self.collate = Collater(
            pad_idx=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
        )

    @classmethod
    def get_prediction(
        cls,
        model: UnitYModel,
        text_tokenizer: TextTokenizer,
        src: Dict[str, Tensor],
        input_modality: Modality,
        output_modality: Modality,
        tgt_lang: str,
        ngram_filtering: bool = False,
        text_max_len_a: int = 1,
        text_max_len_b: int = 200,
    ) -> SequenceToTextOutput:
        """Run beam-search generation (beam size 5) on a collated source batch.

        :param model:
            The model to generate with.
        :param text_tokenizer:
            Tokenizer handed to the generator.
        :param src:
            Collated batch; must contain ``"seqs"`` and ``"seq_lens"``.
        :param input_modality:
            Modality of ``src``.
        :param output_modality:
            Modality to generate.
        :param tgt_lang:
            Target language to decode into.
        :param ngram_filtering:
            If True, block repeated 4-grams during text decoding.
        :param text_max_len_a:
        :param text_max_len_b:
            Passed as ``soft_max_seq_len=(a, b)`` to the text generator options
            (presumably max length = a * source length + b — confirm in fairseq2).

        :returns:
            The generator's output.
        """
        text_opts = SequenceGeneratorOptions(
            beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
        )

        if ngram_filtering:
            text_opts.logits_processor = NGramRepeatBlockProcessor(
                no_repeat_ngram_size=4
            )
        generator = UnitYGenerator(
            model,
            text_tokenizer,
            tgt_lang,
            text_opts=text_opts,
        )
        return generator(
            src["seqs"],
            src["seq_lens"],
            input_modality.value,
            output_modality.value,
            ngram_filtering=ngram_filtering,
        )

    def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
        """Return the ``(input_modality, output_modality)`` pair for ``task``.

        Anything not matched explicitly (i.e. ``Task.T2ST``) maps to
        text -> speech.
        """
        if task == Task.S2ST:
            return Modality.SPEECH, Modality.SPEECH
        # ASR is treated as S2TT with src_lang == tgt_lang
        elif task == Task.S2TT or task == Task.ASR:
            return Modality.SPEECH, Modality.TEXT
        elif task == Task.T2TT:
            return Modality.TEXT, Modality.TEXT
        else:
            return Modality.TEXT, Modality.SPEECH

    @torch.inference_mode()
    def predict(
        self,
        input: Union[str, Tensor],
        task_str: str,
        tgt_lang: str,
        src_lang: Optional[str] = None,
        ngram_filtering: bool = False,
        sample_rate: int = 16000,
        text_max_len_a: int = 1,
        text_max_len_b: int = 200,
    ) -> StringLike:
        """
        The main method used to perform inference on all tasks.

        :param input:
            For T2ST/T2TT, the source text. For S2ST/S2TT/ASR, either a path
            to an audio file or an audio waveform Tensor.
        :param task_str:
            String representing the task (case-insensitive).
            Valid choices are "S2ST", "S2TT", "T2ST", "T2TT", "ASR"
        :param tgt_lang:
            Target language to decode into.
        :param src_lang:
            Source language of input, only required for T2ST, T2TT tasks.
        :param ngram_filtering:
            If True, block repeated 4-grams during text decoding.
        :param sample_rate:
            Sample rate of ``input``; only used when ``input`` is a Tensor.
        :param text_max_len_a:
        :param text_max_len_b:
            Forwarded to :meth:`get_prediction` as the soft maximum text
            sequence length parameters.

        :returns:
            The decoded text for the first (and only) input in the batch.
            NOTE(review): no waveform or sample rate is returned, even for
            S2ST/T2ST.

        :raises ValueError:
            If ``task_str`` is not a known task, or ``src_lang`` is missing
            for a text-input task.
        """
        try:
            task = Task[task_str.upper()]
        except KeyError:
            # The KeyError adds nothing beyond the message below.
            raise ValueError(f"Unsupported task: {task_str}") from None

        input_modality, output_modality = self.get_modalities_from_task(task)

        if input_modality == Modality.SPEECH:
            audio = input
            if isinstance(audio, str):
                with Path(audio).open("rb") as fb:
                    block = MemoryBlock(fb.read())
                decoded_audio = self.decode_audio(block)
            else:
                decoded_audio = {
                    "waveform": audio,
                    "sample_rate": sample_rate,
                    "format": -1,
                }
            src = self.collate(self.convert_to_fbank(decoded_audio))["fbank"]
        else:
            if src_lang is None:
                raise ValueError("src_lang must be specified for T2ST, T2TT tasks.")

            text = input
            # Rebuilt on every call because the encoder depends on `src_lang`.
            self.token_encoder = self.text_tokenizer.create_encoder(
                task="translation", lang=src_lang, mode="source", device=self.device
            )
            src = self.collate(self.token_encoder(text))

        result = self.get_prediction(
            self.model,
            self.text_tokenizer,
            src,
            input_modality,
            output_modality,
            tgt_lang=tgt_lang,
            ngram_filtering=ngram_filtering,
            text_max_len_a=text_max_len_a,
            text_max_len_b=text_max_len_b,
        )

        return result.sentences[0]

