Skip to content

Commit fbd4cca

Browse files
author
Grzegorz Pluto-Prondzinski
authored
Migrate common_language dataset to Datasets v4.0.0 and switch audio decoding to SoundFile (#2287)
1 parent 3fc5773 commit fbd4cca

5 files changed

Lines changed: 69 additions & 39 deletions

File tree

examples/audio-classification/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ datasets[audio]>=4.0.0
22
evaluate
33
numba==0.60.0
44
librosa
5+
soundfile

examples/audio-classification/run_audio_classification.py

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,19 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import io
1617
import logging
1718
import os
1819
import sys
1920
from dataclasses import dataclass, field
2021
from random import randint
2122
from typing import Optional
2223

23-
import datasets
2424
import evaluate
2525
import numpy as np
26+
import soundfile as sf
2627
import transformers
27-
from datasets import DatasetDict, load_dataset
28+
from datasets import Audio, DatasetDict, load_dataset
2829
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForAudioClassification, HfArgumentParser
2930
from transformers.trainer_utils import get_last_checkpoint
3031
from transformers.utils import check_min_version, send_example_telemetry
@@ -50,6 +51,9 @@ def check_optimum_habana_min_version(*a, **b):
5051

5152
require_version("datasets>=4.0.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
5253

54+
# Disable torchcodec decoding in datasets before any dataset ops
55+
os.environ.setdefault("HF_DATASETS_DISABLE_TORCHCODEC", "1")
56+
5357

5458
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
5559
"""Randomly sample chunks of `max_length` seconds from the input audio"""
@@ -280,14 +284,12 @@ def main():
280284
data_args.dataset_config_name,
281285
split=data_args.train_split_name,
282286
token=model_args.token,
283-
trust_remote_code=model_args.trust_remote_code,
284287
)
285288
raw_datasets["eval"] = load_dataset(
286289
data_args.dataset_name,
287290
data_args.dataset_config_name,
288291
split=data_args.eval_split_name,
289292
token=model_args.token,
290-
trust_remote_code=model_args.trust_remote_code,
291293
)
292294

293295
if data_args.audio_column_name not in raw_datasets["train"].column_names:
@@ -315,52 +317,84 @@ def main():
315317
trust_remote_code=model_args.trust_remote_code,
316318
)
317319

318-
# `datasets` takes care of automatically loading and resampling the audio,
319-
# so we just need to set the correct target sampling rate.
320+
# Make sure datasets does not auto-decode audio (we'll open via soundfile in prepare_dataset).
320321
raw_datasets = raw_datasets.cast_column(
321-
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
322+
data_args.audio_column_name,
323+
Audio(sampling_rate=feature_extractor.sampling_rate, decode=False),
322324
)
323325

324326
# Max input length
325327
max_length = int(round(feature_extractor.sampling_rate * data_args.max_length_seconds))
326328

327329
model_input_name = feature_extractor.model_input_names[0]
328330

331+
def load_and_validate_audio(sample, feature_extractor, subsample: bool = False, max_length: Optional[float] = None):
    """
    Decode one audio sample with SoundFile and return a mono float32 waveform.

    Args:
        sample: Non-decoded audio sample as produced by `datasets` with
            `Audio(decode=False)`: a dict carrying "path" and/or "bytes".
        feature_extractor: Feature extractor whose `sampling_rate` the audio
            must already match — no resampling is performed here.
        subsample: If True, randomly crop the waveform to `max_length` seconds.
        max_length: Crop length in seconds; only used when `subsample` is True.

    Returns:
        A 1-D float32 numpy array with the (possibly subsampled) waveform.

    Raises:
        RuntimeError: If the sample cannot be decoded, or its sample rate does
            not match `feature_extractor.sampling_rate`.
    """
    wav, sr = None, None

    # Prefer the on-disk path when present; fall back to in-memory bytes below
    # if the path is missing or unreadable (e.g. samples shipped as raw bytes).
    path = sample.get("path")
    if isinstance(path, str):
        try:
            wav, sr = sf.read(path, dtype="float32", always_2d=False)
        except (RuntimeError, OSError):
            # sf.LibsndfileError subclasses RuntimeError; open failures raise
            # OSError. Anything else is a programming error and should propagate.
            wav, sr = None, None

    if wav is None:
        raw = sample.get("bytes")
        if not raw:
            raise RuntimeError(f"Cannot open audio sample: {sample}")
        wav, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)

    # Downmix multi-channel audio to mono by averaging the channels.
    if wav.ndim > 1:
        wav = wav.mean(axis=1)

    # SoundFile does not resample, so a rate mismatch must be rejected here.
    if sr != feature_extractor.sampling_rate:
        raise RuntimeError(f"Expected {feature_extractor.sampling_rate} Hz, but got {sr} Hz for {path}")

    if subsample and max_length is not None:
        wav = random_subsample(wav, max_length=max_length, sample_rate=sr)

    return wav
329363
def train_transforms(batch):
    """Decode, randomly subsample, and featurize one training batch."""
    subsampled_wavs = []
    for audio_sample in batch[data_args.audio_column_name]:
        wav = load_and_validate_audio(
            audio_sample, feature_extractor, subsample=True, max_length=data_args.max_length_seconds
        )
        subsampled_wavs.append(wav)
    features = feature_extractor(
        subsampled_wavs,
        max_length=max_length,
        sampling_rate=feature_extractor.sampling_rate,
        padding="max_length",
        truncation=True,
    )
    output_batch = {model_input_name: features.get(model_input_name)}
    output_batch["labels"] = list(batch[data_args.label_column_name])
    return output_batch

350381
def val_transforms(batch):
    """Decode and featurize one evaluation batch (no random subsampling)."""
    decoded_wavs = []
    for audio_sample in batch[data_args.audio_column_name]:
        decoded_wavs.append(load_and_validate_audio(audio_sample, feature_extractor, subsample=False))
    features = feature_extractor(
        decoded_wavs,
        max_length=max_length,
        sampling_rate=feature_extractor.sampling_rate,
        padding="max_length",
        truncation=True,
    )
    output_batch = {model_input_name: features.get(model_input_name)}
    output_batch["labels"] = list(batch[data_args.label_column_name])
    return output_batch

365399
# Prepare label mappings.
366400
# We'll include these in the model's config to get human readable labels in the Inference API.

tests/configs/examples/ast_finetuned_speech_commands_v2.json

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"gaudi2": {
3-
"common_language": {
3+
"regisss/common_language": {
44
"num_train_epochs": 10,
55
"eval_batch_size": 64,
66
"distribution": {
@@ -24,15 +24,14 @@
2424
"--dataloader_num_workers 1",
2525
"--ignore_mismatched_sizes=True",
2626
"--use_hpu_graphs_for_training",
27-
"--use_hpu_graphs_for_inference",
28-
"--trust_remote_code True"
27+
"--use_hpu_graphs_for_inference"
2928
]
3029
}
3130
}
3231
}
3332
},
3433
"gaudi3": {
35-
"common_language": {
34+
"regisss/common_language": {
3635
"num_train_epochs": 10,
3736
"eval_batch_size": 64,
3837
"distribution": {
@@ -56,8 +55,7 @@
5655
"--dataloader_num_workers 1",
5756
"--ignore_mismatched_sizes=True",
5857
"--use_hpu_graphs_for_training",
59-
"--use_hpu_graphs_for_inference",
60-
"--trust_remote_code True"
58+
"--use_hpu_graphs_for_inference"
6159
]
6260
}
6361
}

tests/configs/examples/wav2vec2_base.json

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"gaudi1": {
3-
"common_language": {
3+
"regisss/common_language": {
44
"num_train_epochs": 10,
55
"eval_batch_size": 64,
66
"distribution": {
@@ -23,15 +23,14 @@
2323
"--seed 0",
2424
"--dataloader_num_workers 1",
2525
"--use_hpu_graphs_for_training",
26-
"--use_hpu_graphs_for_inference",
27-
"--trust_remote_code True"
26+
"--use_hpu_graphs_for_inference"
2827
]
2928
}
3029
}
3130
}
3231
},
3332
"gaudi2": {
34-
"common_language": {
33+
"regisss/common_language": {
3534
"num_train_epochs": 5,
3635
"eval_batch_size": 64,
3736
"distribution": {
@@ -54,15 +53,14 @@
5453
"--seed 0",
5554
"--dataloader_num_workers 1",
5655
"--use_hpu_graphs_for_training",
57-
"--use_hpu_graphs_for_inference",
58-
"--trust_remote_code True"
56+
"--use_hpu_graphs_for_inference"
5957
]
6058
}
6159
}
6260
}
6361
},
6462
"gaudi3": {
65-
"common_language": {
63+
"regisss/common_language": {
6664
"num_train_epochs": 5,
6765
"eval_batch_size": 64,
6866
"distribution": {
@@ -85,8 +83,7 @@
8583
"--seed 0",
8684
"--dataloader_num_workers 1",
8785
"--use_hpu_graphs_for_training",
88-
"--use_hpu_graphs_for_inference",
89-
"--trust_remote_code True"
86+
"--use_hpu_graphs_for_inference"
9087
]
9188
}
9289
}

tests/test_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,7 @@ class MultiCardMaskedLanguageModelingExampleTester(
936936
class MultiCardAudioClassificationExampleTester(
937937
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_audio_classification", multi_card=True
938938
):
939-
TASK_NAME = "common_language"
939+
TASK_NAME = "regisss/common_language"
940940

941941

942942
class MultiCardSpeechRecognitionExampleTester(

0 commit comments

Comments (0)