# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import math

import torch
import torchaudio

from glob import glob
from datasets import load_dataset
from mmengine.config import Config, ConfigDict
from torch.utils.data import Dataset
from tqdm import tqdm

from xtuner.registry import BUILDER
from .huggingface import process_hf_dataset
from .utils import expand2square


class LLaMASpeechDataset(Dataset):
    """Dataset pairing tokenized speech-instruction text with audio features.

    Text samples are produced by ``process_hf_dataset`` (tokenization plus the
    optional ``dataset_map_fn`` / ``template_map_fn``). When a
    ``speech_processor`` is configured, ``__getitem__`` additionally loads the
    referenced audio file, resamples it to 16 kHz and attaches the processor
    output as ``mel_feature`` together with a frame count ``mel_len``.

    Args:
        data_path (str): Path to a ``.json`` annotation file, or anything else
            accepted by ``datasets.load_dataset`` (its ``audio`` column is
            dropped; audio is read from disk in ``__getitem__`` instead).
        tokenizer: Tokenizer config/object. It is forwarded to
            ``process_hf_dataset`` and built with ``BUILDER``.
        split (str): Split to use. Splits whose name contains ``'test'`` also
            get ``wav_id`` and ``split`` fields in every returned sample.
        speech_processor: Audio feature extractor, or a config (``dict`` /
            ``Config`` / ``ConfigDict``) built with ``BUILDER``. ``None``
            disables audio loading.
        max_dataset_length (int, optional): Forwarded to
            ``process_hf_dataset``.
        dataset_map_fn: Forwarded to ``process_hf_dataset``.
        template_map_fn: Forwarded to ``process_hf_dataset``.
        max_length (int): Forwarded to ``process_hf_dataset``.
        feature_device: Device the extracted features are moved to. The
            default ``'cuda'`` keeps the previous behaviour. Pass ``None`` to
            leave them on CPU, e.g. when using DataLoader worker processes,
            where CUDA cannot be initialised in forked children.
    """

    # Every waveform is resampled to this rate before feature extraction.
    _SAMPLE_RATE = 16000
    # Used to derive ``mel_len`` from the waveform length; presumably the
    # speech encoder's output frame rate (50 frames/s) -- TODO confirm.
    _FEATURE_FRAMES_PER_SECOND = 50

    def __init__(self,
                 data_path,
                 tokenizer,
                 split,
                 speech_processor=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 feature_device='cuda'):
        super().__init__()

        if os.path.isfile(data_path) and data_path.endswith('.json'):
            speech_dataset = load_dataset(
                'json', data_files={split: data_path})
        else:
            speech_dataset = load_dataset(data_path)
            # The waveform is re-read from disk in __getitem__, so the
            # in-dataset audio column is not needed for text preprocessing.
            speech_dataset = speech_dataset.remove_columns('audio')

        self.text_data = process_hf_dataset(
            dataset=speech_dataset,
            tokenizer=tokenizer,
            max_length=max_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            split=split,
            max_dataset_length=max_dataset_length,
            remove_unused_columns=False,
            pack_to_max_length=False,
            with_image_token=False,
            with_speech_token=True)

        self.split = split
        self.is_test = 'test' in split

        self.tokenizer = BUILDER.build(tokenizer)

        if isinstance(speech_processor, (dict, Config, ConfigDict)):
            self.speech_processor = BUILDER.build(speech_processor)
        else:
            self.speech_processor = speech_processor
        self.feature_device = feature_device

    @property
    def modality_length(self):
        """List with the number of ``input_ids`` of every sample, in order.

        Iterates over the whole dataset (with a progress bar) on each access.
        """
        return [
            len(data_dict['input_ids']) for data_dict in tqdm(self.text_data)
        ]

    def __len__(self):
        return len(self.text_data)

    @staticmethod
    def _resolve_wav_path(data_dict):
        """Return the on-disk audio path referenced by a sample.

        Uses ``data_dict['wav_path']`` when present. Otherwise the file is
        searched two directory levels below the directory of
        ``data_dict['file']``, under ``<speaker_id>/<chapter_id>/<file name>``.

        Raises:
            FileNotFoundError: If the search pattern matches no file.
        """
        if 'wav_path' in data_dict:
            return data_dict['wav_path']
        paths = data_dict['file'].split('/')
        pattern = (f"{'/'.join(paths[:-1])}/*/*/{data_dict['speaker_id']}/"
                   f"{data_dict['chapter_id']}/{paths[-1]}")
        # Sort so the chosen file does not depend on filesystem listing order.
        matches = sorted(glob(pattern))
        if not matches:
            raise FileNotFoundError(f'No audio file matches {pattern!r}')
        return matches[0]

    @staticmethod
    def _derive_wav_id(data_dict):
        """Return the audio file name without its extension.

        The name is taken from ``data_dict['file']``, or from
        ``data_dict['wav_path']`` when there is no ``file`` entry.
        """
        source = (data_dict['file']
                  if 'file' in data_dict else data_dict['wav_path'])
        return os.path.splitext(source.split('/')[-1])[0]

    def __getitem__(self, index):
        """Return sample ``index``, augmented with audio features and ids.

        Added keys:
            mel_feature / mel_len: only when a speech processor is set.
            wav_id / split: only for test splits; ``wav_id`` is ``utt_id``
                when present, else the audio file name without extension.

        Raises:
            FileNotFoundError: If the audio file of the sample cannot be found.
        """
        data_dict = self.text_data[index]
        if self.speech_processor is not None:
            wav_path = self._resolve_wav_path(data_dict)
            wav_signals, sr = torchaudio.load(wav_path)
            if sr != self._SAMPLE_RATE:
                wav_signals = torchaudio.functional.resample(
                    wav_signals, orig_freq=sr, new_freq=self._SAMPLE_RATE)
            # Only the first channel is fed to the processor; the call
            # signature assumes a HF-style feature extractor -- TODO confirm.
            mel_feature = self.speech_processor(
                wav_signals[0],
                return_tensors='pt',
                sampling_rate=self._SAMPLE_RATE).input_features
            if self.feature_device is not None:
                mel_feature = mel_feature.to(self.feature_device)
            # NOTE(review): presumably needed so gradients/checkpointing flow
            # through a frozen speech encoder -- confirm with the model code.
            mel_feature.requires_grad = True
            data_dict['mel_feature'] = mel_feature
            # size(1) is the number of samples per channel after resampling.
            data_dict['mel_len'] = math.ceil(
                self._FEATURE_FRAMES_PER_SECOND * wav_signals.size(1) /
                self._SAMPLE_RATE)
        if self.is_test:
            if 'utt_id' in data_dict:
                data_dict['wav_id'] = data_dict['utt_id']
            else:
                data_dict['wav_id'] = self._derive_wav_id(data_dict)
            data_dict['split'] = self.split

        return data_dict
