# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def default_collate_fn(instances: Sequence[Dict],
                       pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                       return_hf_format: bool = False,
                       use_varlen_attn: bool = False):
    """Collate a list of tokenized samples into a single batch.

    Every instance must provide ``input_ids`` and ``labels``. Optional keys
    are picked up when present:

    * ``pixel_values`` -- stacked into ``data_dict['pixel_values']``.
    * ``mel_feature`` / ``mel_len`` -- features are concatenated along
      dim 0 into ``data_dict['speech_repr']``; the per-sample lengths are
      forwarded as a plain list in ``data_dict['repr_lens']``.
    * ``cumulative_len`` / ``indexes`` -- read only when
      ``use_varlen_attn`` is True.
    * ``wav_id`` / ``split`` / ``target`` -- read from the FIRST instance
      (``target`` is then read from every instance for test splits).

    Args:
        instances: Samples produced by the dataset.
        pad_index: Value used to right-pad ``input_ids`` when the batch has
            more than one sample. Every position equal to it is marked
            False in ``attention_mask``.
        return_hf_format: If True, return the flat ``data_dict`` directly.
        use_varlen_attn: Variable-length attention mode. Requires exactly
            one instance and no image/speech inputs. In this mode no
            ``attention_mask`` is produced; ``cumulative_len``,
            ``indexes`` and ``max_seqlen`` are returned instead.

    Returns:
        ``data_dict`` when ``return_hf_format`` is True. Otherwise
        ``{'data': data_dict, 'data_samples': data_samples}``.
        ``data_samples`` is ``{'references': [[reply, ...], ...]}`` when the
        first instance has a ``split`` containing ``'test'`` and a
        ``target`` key; otherwise it is ``None``.

    Raises:
        AssertionError: If ``use_varlen_attn`` is set with a batch size other
            than 1 or together with image/speech inputs.
    """
    input_ids, labels = [], []
    has_image = any(inst.get('pixel_values') is not None for inst in instances)
    has_speech = any(inst.get('mel_feature') is not None for inst in instances)
    if use_varlen_attn:
        cumulative_len, indexes = [], []
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        # The whole message must be one parenthesized expression: a string
        # continued on the next physical line without parentheses is a
        # separate, discarded statement and truncates the message.
        assert not (has_image or has_speech), (
            'Currently, it is not configured to '
            'accommodate the use of varlen Attention in multimodal training')

    # These lists are only filled/consumed when the matching flag is set.
    # NOTE(review): if only some instances carry pixel_values / mel_feature,
    # the others contribute None (or raise KeyError) and the later
    # torch.stack / torch.cat fails -- mixed batches are presumably not
    # expected; confirm with the dataset.
    pixel_values, mel_features, mel_lens = [], [], []

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            indexes.append(torch.LongTensor(example['indexes']))
        if has_image:
            pixel_values.append(example['pixel_values'])
        if has_speech:
            mel_features.append(example['mel_feature'])
            mel_lens.append(example['mel_len'])

    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        # A single sample needs no padding; stack only adds the batch dim.
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        indexes = torch.stack(indexes, dim=0)
        # Largest gap between consecutive cumulative_len entries, i.e. the
        # longest packed sub-sequence (assuming cumulative_len holds
        # sub-sequence start offsets -- TODO confirm with the dataset).
        max_seqlen = (
            cumulative_len[0][1:] -  # noqa: W504
            cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'indexes': indexes,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            # NOTE(review): a real token whose id equals pad_index is also
            # masked out here.
            'attention_mask': input_ids.ne(pad_index),
            'labels': labels
        }

    if has_image:
        data_dict['pixel_values'] = torch.stack(pixel_values)

    if has_speech:
        # Features are concatenated (not padded) along dim 0; the consumer
        # presumably splits them back using ``repr_lens`` -- TODO confirm.
        data_dict['speech_repr'] = torch.cat(mel_features, dim=0)
        data_dict['repr_lens'] = mel_lens

    # NOTE(review): wav_id / split are taken from the first instance only,
    # which is faithful only for batch size 1 or homogeneous batches.
    first = instances[0]
    if 'wav_id' in first:
        data_dict['wav_id'] = first['wav_id']
    if 'split' in first:
        data_dict['split'] = first['split']

    if return_hf_format:
        return data_dict

    split = first.get('split')
    if split is not None and 'test' in split and 'target' in first:
        data_samples = {
            'references': [[reply['reply'] for reply in inst['target']]
                           for inst in instances]
        }
    else:
        data_samples = None
    return {'data': data_dict, 'data_samples': data_samples}
