1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import io
1617import logging
1718import os
1819import sys
1920from dataclasses import dataclass , field
2021from random import randint
2122from typing import Optional
2223
23- import datasets
2424import evaluate
2525import numpy as np
26+ import soundfile as sf
2627import transformers
27- from datasets import DatasetDict , load_dataset
28+ from datasets import Audio , DatasetDict , load_dataset
2829from transformers import AutoConfig , AutoFeatureExtractor , AutoModelForAudioClassification , HfArgumentParser
2930from transformers .trainer_utils import get_last_checkpoint
3031from transformers .utils import check_min_version , send_example_telemetry
@@ -50,6 +51,9 @@ def check_optimum_habana_min_version(*a, **b):
5051
# Pin the datasets version this example was written against.
require_version("datasets>=4.0.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

# Disable torchcodec decoding in datasets before any dataset ops
# NOTE(review): this executes after `datasets` has already been imported; it
# assumes the library reads this env var lazily (at decode time) rather than
# at import time — confirm, otherwise this line has no effect.
os.environ.setdefault("HF_DATASETS_DISABLE_TORCHCODEC", "1")
5357
5458def random_subsample (wav : np .ndarray , max_length : float , sample_rate : int = 16000 ):
5559 """Randomly sample chunks of `max_length` seconds from the input audio"""
@@ -280,14 +284,12 @@ def main():
280284 data_args .dataset_config_name ,
281285 split = data_args .train_split_name ,
282286 token = model_args .token ,
283- trust_remote_code = model_args .trust_remote_code ,
284287 )
285288 raw_datasets ["eval" ] = load_dataset (
286289 data_args .dataset_name ,
287290 data_args .dataset_config_name ,
288291 split = data_args .eval_split_name ,
289292 token = model_args .token ,
290- trust_remote_code = model_args .trust_remote_code ,
291293 )
292294
293295 if data_args .audio_column_name not in raw_datasets ["train" ].column_names :
@@ -315,52 +317,84 @@ def main():
315317 trust_remote_code = model_args .trust_remote_code ,
316318 )
317319
318- # `datasets` takes care of automatically loading and resampling the audio,
319- # so we just need to set the correct target sampling rate.
320+ # Make sure datasets does not auto-decode audio (we'll open via soundfile in prepare_dataset).
320321 raw_datasets = raw_datasets .cast_column (
321- data_args .audio_column_name , datasets .features .Audio (sampling_rate = feature_extractor .sampling_rate )
322+ data_args .audio_column_name ,
323+ Audio (sampling_rate = feature_extractor .sampling_rate , decode = False ),
322324 )
323325
324326 # Max input length
325327 max_length = int (round (feature_extractor .sampling_rate * data_args .max_length_seconds ))
326328
327329 model_input_name = feature_extractor .model_input_names [0 ]
328330
def load_and_validate_audio(
    sample, feature_extractor, subsample: bool = False, max_length: Optional[float] = None
) -> np.ndarray:
    """Decode one audio sample and return a mono float32 waveform.

    With ``decode=False`` on the Audio column, each sample arrives as a
    dict-like holding a "path" and/or raw "bytes" payload instead of a
    pre-decoded array, so we decode it here via soundfile.

    Args:
        sample: dict-like with optional "path" (str) and "bytes" entries.
        feature_extractor: supplies the required ``sampling_rate``.
        subsample: if True, randomly crop the waveform to ``max_length`` seconds.
        max_length: crop length in seconds; only used when ``subsample`` is True.

    Returns:
        1-D float32 numpy array at ``feature_extractor.sampling_rate``.

    Raises:
        RuntimeError: if neither path nor bytes yields audio, or if the decoded
            sample rate does not match the feature extractor's.
    """
    path = sample.get("path")
    wav, sr = None, None

    # Prefer reading from the on-disk path when available; fall back to the
    # in-memory bytes on any read failure (best-effort, hence broad except).
    if isinstance(path, str):
        try:
            wav, sr = sf.read(path, dtype="float32", always_2d=False)
        except Exception:
            wav, sr = None, None

    if wav is None:
        raw = sample.get("bytes")
        if not raw:
            raise RuntimeError(f"Cannot open audio sample: {sample}")
        wav, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)

    # Downmix multi-channel audio to mono by averaging channels.
    if wav.ndim > 1:
        wav = wav.mean(axis=1)

    # With decode=False, `datasets` does not resample for us, so enforce the
    # rate the feature extractor expects instead of silently mis-featurizing.
    if sr != feature_extractor.sampling_rate:
        raise RuntimeError(f"Expected {feature_extractor.sampling_rate} Hz, but got {sr} Hz for {path}")

    if subsample and max_length is not None:
        wav = random_subsample(wav, max_length=max_length, sample_rate=sr)

    return wav
362+
def train_transforms(batch):
    """Featurize one training batch.

    Each clip is randomly cropped to ``max_length_seconds`` before feature
    extraction, then padded/truncated to the model's max input length.
    """
    crops = [
        load_and_validate_audio(
            audio, feature_extractor, subsample=True, max_length=data_args.max_length_seconds
        )
        for audio in batch[data_args.audio_column_name]
    ]
    features = feature_extractor(
        crops,
        max_length=max_length,
        sampling_rate=feature_extractor.sampling_rate,
        padding="max_length",
        truncation=True,
    )
    output = {model_input_name: features.get(model_input_name)}
    output["labels"] = list(batch[data_args.label_column_name])
    return output
349380
def val_transforms(batch):
    """Featurize one evaluation batch.

    Full clips are decoded (no random crop) and padded/truncated to the
    model's max input length.
    """
    waveforms = [
        load_and_validate_audio(audio, feature_extractor, subsample=False)
        for audio in batch[data_args.audio_column_name]
    ]
    features = feature_extractor(
        waveforms,
        max_length=max_length,
        sampling_rate=feature_extractor.sampling_rate,
        padding="max_length",
        truncation=True,
    )
    output = {model_input_name: features.get(model_input_name)}
    output["labels"] = list(batch[data_args.label_column_name])
    return output
364398
365399 # Prepare label mappings.
366400 # We'll include these in the model's config to get human readable labels in the Inference API.
0 commit comments