Commit 0958b5d

Update LLaST implementation
1 parent eb72c6a commit 0958b5d

19 files changed: +1108 -17 lines changed

.pre-commit-config.yaml

+1 -1

@@ -4,7 +4,7 @@ repos:
     rev: 5.0.4
     hooks:
       - id: flake8
-        args: ["--exclude=xtuner/model/transformers_models/*"]
+        args: ["--exclude=xtuner/model/transformers_models/*,xtuner/evaluation/metrics/sacrebleu.py"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.11.5
     hooks:

xtuner/dataset/__init__.py

+2 -1

@@ -8,6 +8,7 @@
                           load_intern_repo_untokenized_dataset)
 from .internvl_dataset import InternVL_V1_5_Dataset
 from .json_dataset import load_json_file
+from .llast import LLaSTDataset
 from .llava import LLaVADataset
 from .modelscope import process_ms_dataset
 from .moss_sft import MOSSSFTDataset
@@ -25,5 +26,5 @@
     'load_intern_repo_tokenized_dataset',
     'load_intern_repo_untokenized_dataset', 'build_packed_dataset',
     'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset',
-    'load_json_file', 'InternVL_V1_5_Dataset'
+    'load_json_file', 'InternVL_V1_5_Dataset', 'LLaSTDataset'
 ]
xtuner/dataset/collate_fns/__init__.py

+4 -1

@@ -1,5 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .default_collate_fn import default_collate_fn
+from .llast_collate_fn import llast_audiomask_mel_collate_fn
 from .mmlu_collate_fn import mmlu_collate_fn
 
-__all__ = ['default_collate_fn', 'mmlu_collate_fn']
+__all__ = [
+    'default_collate_fn', 'mmlu_collate_fn', 'llast_audiomask_mel_collate_fn'
+]
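
For orientation, a minimal sketch (not part of this commit) of how the newly exported collate function could be wired into an xtuner-style dataloader config; the batch size, worker count, and the bare dataset entry are placeholders rather than values taken from the LLaST configs:

# Illustrative only: numeric values and the dataset entry below are placeholders.
from mmengine.dataset import DefaultSampler

from xtuner.dataset import LLaSTDataset
from xtuner.dataset.collate_fns import llast_audiomask_mel_collate_fn

train_dataloader = dict(
    batch_size=8,   # placeholder
    num_workers=4,  # placeholder
    dataset=dict(type=LLaSTDataset),  # required LLaSTDataset arguments omitted
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=llast_audiomask_mel_collate_fn))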
xtuner/dataset/collate_fns/llast_collate_fn.py

+60 -0

@@ -0,0 +1,60 @@
+# Copyright (c) LLaST. All rights reserved.
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Sequence
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from xtuner.utils import (DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX,
+                          LLAST_AUDIO_PADDING_TOKEN_INDEX)
+
+
+def llast_audiomask_mel_collate_fn(
+        instances: Sequence[Dict],
+        pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
+        return_hf_format: bool = False) -> Dict[str, torch.Tensor]:
+    """Add audio tokens and conduct padding operation."""
+    input_ids = []
+    labels = []
+    feats_lens = []
+    has_audio = any(inst.get('audio_tokens') is not None for inst in instances)
+
+    if has_audio:
+        audio_tokens = []
+    for example in instances:
+        input_ids.append(torch.tensor(example['input_ids']))
+        labels.append(torch.tensor(example['labels']))
+        if has_audio:
+            audio_tokens.append(example['audio_tokens'])
+            feats_lens.append(torch.tensor(example['audio_lens']))
+    if len(instances) > 1:
+        input_ids = pad_sequence(
+            input_ids, batch_first=True, padding_value=pad_index)
+        labels = pad_sequence(
+            labels, batch_first=True, padding_value=IGNORE_INDEX)
+        # padding audio tokens
+        padded_audio_tokens = pad_sequence(
+            audio_tokens,
+            batch_first=True,
+            padding_value=LLAST_AUDIO_PADDING_TOKEN_INDEX)
+
+    else:
+        input_ids = torch.stack(input_ids)
+        labels = torch.stack(labels)
+        padded_audio_tokens = torch.stack(audio_tokens)
+
+    data_dict = {
+        'input_ids': input_ids,
+        'attention_mask': input_ids.ne(pad_index),
+        'labels': labels
+    }
+
+    if has_audio:
+        audio_lens = torch.stack(feats_lens)
+        data_dict['audio_tokens'] = padded_audio_tokens
+        data_dict['audio_lens'] = audio_lens
+
+    if return_hf_format:
+        return data_dict
+    else:
+        return {'data': data_dict, 'data_samples': instances}
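
For context, a minimal usage sketch (not part of the commit) of the collate function defined above; the token ids, mel-feature shapes, and audio lengths are invented for illustration:

# Illustrative only: real instances would come from the LLaST dataset pipeline.
import torch

from xtuner.dataset.collate_fns import llast_audiomask_mel_collate_fn

instances = [
    dict(input_ids=[1, 2, 3], labels=[-100, 2, 3],
         audio_tokens=torch.randn(120, 80), audio_lens=120),
    dict(input_ids=[1, 2], labels=[-100, 2],
         audio_tokens=torch.randn(90, 80), audio_lens=90),
]

batch = llast_audiomask_mel_collate_fn(instances, return_hf_format=True)
# input_ids/labels are right-padded to the longest sequence in the batch,
# audio_tokens are padded along the time axis with
# LLAST_AUDIO_PADDING_TOKEN_INDEX, and audio_lens keeps the original feature
# lengths so the padded frames can be masked downstream.
print(batch['input_ids'].shape)     # torch.Size([2, 3])
print(batch['audio_tokens'].shape)  # torch.Size([2, 120, 80])
print(batch['audio_lens'])          # tensor([120, 90])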

xtuner/dataset/huggingface.py

+15 -4
@@ -65,8 +65,8 @@ def add_template_to_dataset(dataset, template_map_fn, map_num_proc):
 
 
 def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
-                     input_ids_with_output, remove_unused_columns,
-                     map_num_proc):
+                     with_audio_token, input_ids_with_output,
+                     remove_unused_columns, map_num_proc):
     assert (tokenizer is not None) and (max_length is not None), \
         f'({tokenizer}, {max_length})'
     if isinstance(tokenizer, dict) or isinstance(
@@ -78,6 +78,7 @@ def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
             tokenizer=tokenizer,
             max_length=max_length,
             with_image_token=with_image_token,
+            with_audio_token=with_audio_token,
             input_ids_with_output=input_ids_with_output),
         remove_columns=list(dataset.column_names)
         if remove_unused_columns else None,
@@ -112,6 +113,7 @@ def process(dataset,
             use_varlen_attn=False,
             input_ids_with_output=True,
             with_image_token=False,
+            with_audio_token=False,
             map_num_proc=32):
     """Post-process the dataset loaded from the Hugging Face Hub, or a local
     dataset.
@@ -153,6 +155,9 @@ def process(dataset,
         with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
             IMAGE_TOKEN_INDEX. Typically set it to True during the training
             of VLM.
+        with_audio_token: Whether to convert DEFAULT_AUDIO_TOKEN to
+            LLAST_AUDIO_TOKEN_INDEX. Typically set it to True during the
+            training of SLM.
         map_num_proc: Max number of processes when mapping the dataset.
     """
     if use_varlen_attn:
@@ -197,7 +202,8 @@ def process(dataset,
 
     if do_dataset_tokenization:
         dataset = tokenize_dataset(dataset, tokenizer, max_length,
-                                   with_image_token, input_ids_with_output,
+                                   with_image_token, with_audio_token,
+                                   input_ids_with_output,
                                    remove_unused_columns, map_num_proc)
 
     if input_ids_with_output:
@@ -213,7 +219,7 @@ def process(dataset,
                                shuffle_before_pack, map_num_proc)
 
     # add 'length'
-    dataset = dataset.map(get_lengths, num_proc=map_num_proc)
+    dataset = dataset.map(get_lengths, num_proc=1)
     setattr(dataset, 'length', dataset['length'])
 
     return dataset
@@ -234,6 +240,7 @@ def process_hf_dataset(dataset,
                        use_varlen_attn=False,
                        input_ids_with_output=True,
                        with_image_token=False,
+                       with_audio_token=False,
                        map_num_proc=32):
     """Post-process the dataset loaded from the Hugging Face Hub, or a local
     dataset.
@@ -275,6 +282,9 @@ def process_hf_dataset(dataset,
         with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
             IMAGE_TOKEN_INDEX. Typically set it to True during the training
             of VLM.
+        with_audio_token: Whether to convert DEFAULT_AUDIO_TOKEN to
+            LLAST_AUDIO_TOKEN_INDEX. Typically set it to True during the
+            training of SLM.
         map_num_proc: Max number of processes when mapping the dataset.
     """
     kwargs = dict(
@@ -293,6 +303,7 @@ def process_hf_dataset(dataset,
         use_varlen_attn=use_varlen_attn,
         input_ids_with_output=input_ids_with_output,
         with_image_token=with_image_token,
+        with_audio_token=with_audio_token,
         map_num_proc=map_num_proc)
     if not (dist.is_available() and dist.is_initialized()):
         return process(**kwargs)