# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union
import logging

import torch
import math
from torch.utils.data import DataLoader
from mmengine.runner import IterBasedTrainLoop
from mmengine.evaluator import Evaluator
from mmengine.runner.base_loop import BaseLoop
from mmengine.runner.amp import autocast
from mmengine.logging import print_log
from xtuner.registry import BUILDER
from transformers import GenerationConfig, StoppingCriteriaList

from xtuner.model.utils import prepare_inputs_labels_for_multimodal


class TrainLoop(IterBasedTrainLoop):
    """Iteration-based train loop whose length may be given in epochs.

    Exactly one of ``max_iters`` and ``max_epochs`` must be provided. When
    ``max_epochs`` is given it is converted into an iteration budget of
    ``max_epochs * len(dataloader)`` (rounded to the nearest integer) and
    forwarded to ``IterBasedTrainLoop`` as ``max_iters``.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader. A dict is only built here when
            ``max_epochs`` is used, because its length is needed.
        max_iters (int, optional): Total number of training iterations.
        max_epochs (int or float, optional): Total number of training
            epochs; fractional values are allowed.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``IterBasedTrainLoop.__init__``.

    Raises:
        RuntimeError: If neither or both of ``max_iters`` and ``max_epochs``
            are specified.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 max_iters: Optional[int] = None,
                 max_epochs: Union[int, float] = None,
                 **kwargs) -> None:

        if max_iters is None and max_epochs is None:
            raise RuntimeError('Please specify the `max_iters` or '
                               '`max_epochs` in `train_cfg`.')
        if max_iters is not None and max_epochs is not None:
            raise RuntimeError('Only one of `max_iters` or `max_epochs` can '
                               'exist in `train_cfg`.')

        if max_iters is not None:
            iters = int(max_iters)
            assert iters == max_iters, ('`max_iters` should be a integer '
                                        f'number, but get {max_iters}')
        else:
            # The dataloader length is required to turn epochs into
            # iterations, so a config dict has to be built right away.
            if isinstance(dataloader, dict):
                diff_rank_seed = runner._randomness_cfg.get(
                    'diff_rank_seed', False)
                dataloader = runner.build_dataloader(
                    dataloader,
                    seed=runner.seed,
                    diff_rank_seed=diff_rank_seed)
            # A float `max_epochs` makes the product a float (possibly
            # non-integral, or off by float noise such as 0.1 * 30), which is
            # not a valid iteration count; round to the nearest integer.
            iters = round(max_epochs * len(dataloader))

        super().__init__(
            runner=runner, dataloader=dataloader, max_iters=iters, **kwargs)

class TestLoop(BaseLoop):
    """Loop for generation-based test.

    Every batch is decoded with ``model.generate`` and the decoded text is
    written, one line per batch, to a file named after the batch's ``split``
    inside ``work_dir``. Only a batch size of 1 is supported.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        tokenizer: Config handed to ``BUILDER.build`` to construct the
            tokenizer used for generation settings and decoding.
        work_dir: Directory in which the per-split output files are created.
        task_name (str): Either ``'asr'`` or ``'conv'``. For ``'asr'`` the
            decoded text is upper-cased and separated from the wav id by a
            space; for ``'conv'`` it is written as is and separated by a
            tab. Defaults to ``'asr'``.
        max_new_tokens (int): Maximum number of tokens to generate per
            sample. Defaults to 600.
        fp16 (bool): Whether to enable fp16 testing. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 tokenizer,
                 work_dir,
                 task_name='asr',
                 max_new_tokens=600,
                 fp16: bool = False):
        super().__init__(runner, dataloader)

        if isinstance(evaluator, (dict, list)):
            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator  # type: ignore

        dataset = self.dataloader.dataset
        if hasattr(dataset, 'metainfo'):
            self.evaluator.dataset_meta = dataset.metainfo
            self.runner.visualizer.dataset_meta = dataset.metainfo
        else:
            print_log(
                f'Dataset {dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        self.tokenizer = BUILDER.build(tokenizer)
        self.work_dir = work_dir
        self.task_name = task_name
        assert self.task_name in ["asr", "conv"], (
            f"`task_name` must be 'asr' or 'conv', but got {task_name!r}")

        # Default generation config: greedy decoding (no sampling). The pad
        # token falls back to EOS when the tokenizer defines no pad token.
        self.max_new_tokens = max_new_tokens
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        self.gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )

        self.stop_criteria = StoppingCriteriaList()

    def run(self) -> dict:
        """Launch test.

        Returns:
            dict: Metrics returned by ``self.evaluator.evaluate``.
        """
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()
        # One lazily opened output file per split value seen in the data.
        out_file = {}

        try:
            for idx, data_batch in enumerate(self.dataloader):
                outputs = self.run_iter(idx, data_batch)
                data = data_batch['data']
                split = data['split']
                if split not in out_file:
                    # Explicit UTF-8 so non-ASCII transcripts do not depend
                    # on the locale's default encoding.
                    out_file[split] = open(
                        f"{self.work_dir}/{split}", 'w', encoding='utf-8')
                handle = out_file[split]

                if self.task_name == 'asr':
                    print(f"{outputs.upper()} ({data['wav_id']})",
                          file=handle)
                elif self.task_name == 'conv':
                    print(f"{outputs}\t({data['wav_id']})", file=handle)

                # Flush after every sample so partial results are on disk
                # while the test is still running.
                handle.flush()
        finally:
            # Close every opened file even if an iteration raised.
            for handle in out_file.values():
                handle.close()

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch: dict) -> str:
        """Iterate one mini-batch.

        ``data_batch['data']`` is modified in place: ``input_ids`` and
        ``attention_mask`` are trimmed, and, when present, ``speech_repr``
        and ``repr_lens`` are replaced by their projected / pooled versions.

        Args:
            idx (int): Index of the batch within the dataloader.
            data_batch (dict): Batch of data from dataloader. Its ``'data'``
                entry must hold ``input_ids`` and ``attention_mask`` with a
                batch dimension of 1.

        Returns:
            str: The decoded, stripped generation result.

        Raises:
            NotImplementedError: If the model is neither a Llama-2 nor a
                Llama-3 model (judged by ``model.config._name_or_path``).
        """
        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)
        data = data_batch['data']
        with autocast(enabled=self.fp16):
            assert data['input_ids'].size(0) == 1, (
                'TestLoop only supports a batch size of 1')
            # Unwrap a wrapped model (one exposing `.module`) if necessary.
            model = getattr(self.runner.model, 'module', self.runner.model)

            # Model-family specific settings: how many trailing prompt tokens
            # to drop and which extra arguments `generate` needs.
            # NOTE(review): the trailing tokens are presumably template
            # tokens appended after the prompt -- confirm against the
            # dataset's prompt template.
            model_name = model.config._name_or_path
            if 'Llama-2' in model_name:
                num_trailing = 3
                extra_generate_kwargs = {}
            elif 'Llama-3' in model_name:
                num_trailing = 1
                # Stop on the regular EOS token as well as on `<|eot_id|>`.
                extra_generate_kwargs = {
                    'eos_token_id': [
                        self.tokenizer.eos_token_id,
                        self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                    ]
                }
            else:
                raise NotImplementedError(
                    'TestLoop only supports Llama-2 and Llama-3 models, '
                    f'but got {model_name!r}')

            has_speech = 'speech_repr' in data
            if has_speech:
                speech_outputs = model.forward_whisper_encoder(
                    data['speech_repr'])
                data['speech_repr'] = model.projector(speech_outputs)
                data['repr_lens'] = [
                    math.ceil(length / model.pooling)
                    for length in data['repr_lens']
                ]

            data['input_ids'] = data['input_ids'][:, :-num_trailing]
            data['attention_mask'] = \
                data['attention_mask'][:, :-num_trailing]

            if has_speech:
                data_for_llm = prepare_inputs_labels_for_multimodal(
                    llm=model.llm,
                    input_ids=data['input_ids'],
                    speech_repr=data['speech_repr'],
                    repr_lens=data['repr_lens'],
                )
            else:
                data_for_llm = {"input_ids": data['input_ids'].to('cuda')}

            generation_output = model.generate(
                **data_for_llm,
                max_new_tokens=self.max_new_tokens,
                generation_config=self.gen_config,
                bos_token_id=self.tokenizer.bos_token_id,
                stopping_criteria=self.stop_criteria,
                **extra_generate_kwargs)

            outputs = self.tokenizer.decode(
                generation_output[0], skip_special_tokens=True).strip()
            # Keep only the text after the last `[/INST]` marker, if any.
            if '[/INST]' in outputs:
                outputs = outputs.split('[/INST]')[-1].strip()
        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
        return outputs
