Skip to content

Commit ad31a88

Browse files
authored
Apply sampling_rate if specified (#910)
1 parent 56dbbc8 commit ad31a88

File tree

8 files changed

+151
-3
lines changed

8 files changed

+151
-3
lines changed
17.9 KB
Binary file not shown.
23.2 KB
Binary file not shown.
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
{
2+
"@context": {
3+
"@language": "en",
4+
"@vocab": "https://schema.org/",
5+
"citeAs": "cr:citeAs",
6+
"column": "cr:column",
7+
"conformsTo": "dct:conformsTo",
8+
"cr": "http://mlcommons.org/croissant/",
9+
"rai": "http://mlcommons.org/croissant/RAI/",
10+
"data": {
11+
"@id": "cr:data",
12+
"@type": "@json"
13+
},
14+
"dataType": {
15+
"@id": "cr:dataType",
16+
"@type": "@vocab"
17+
},
18+
"dct": "http://purl.org/dc/terms/",
19+
"examples": {
20+
"@id": "cr:examples",
21+
"@type": "@json"
22+
},
23+
"extract": "cr:extract",
24+
"field": "cr:field",
25+
"fileProperty": "cr:fileProperty",
26+
"fileObject": "cr:fileObject",
27+
"fileSet": "cr:fileSet",
28+
"format": "cr:format",
29+
"includes": "cr:includes",
30+
"isLiveDataset": "cr:isLiveDataset",
31+
"jsonPath": "cr:jsonPath",
32+
"key": "cr:key",
33+
"md5": "cr:md5",
34+
"parentField": "cr:parentField",
35+
"path": "cr:path",
36+
"recordSet": "cr:recordSet",
37+
"references": "cr:references",
38+
"regex": "cr:regex",
39+
"repeated": "cr:repeated",
40+
"replace": "cr:replace",
41+
"sc": "https://schema.org/",
42+
"samplingRate": "cr:samplingRate",
43+
"separator": "cr:separator",
44+
"source": "cr:source",
45+
"subField": "cr:subField",
46+
"transform": "cr:transform"
47+
},
48+
"@type": "sc:Dataset",
49+
"name": "audio_test",
50+
"description": "This is the basic test case for audio files",
51+
"conformsTo": "http://mlcommons.org/croissant/1.1",
52+
"url": "None",
53+
"distribution": [
54+
{
55+
"@type": "cr:FileSet",
56+
"@id": "files",
57+
"name": "files",
58+
"encodingFormat": "audio/mpeg",
59+
"includes": "data/*.mp3"
60+
}
61+
],
62+
"recordSet": [
63+
{
64+
"@type": "cr:RecordSet",
65+
"@id": "records",
66+
"name": "records",
67+
"description": "These are the records.",
68+
"field": [
69+
{
70+
"@type": "cr:Field",
71+
"@id": "records/audio",
72+
"name": "audio",
73+
"description": "These are the sounds.",
74+
"dataType": "sc:AudioObject",
75+
"source": {
76+
"fileSet": {
77+
"@id": "files"
78+
},
79+
"extract": {
80+
"fileProperty": "content"
81+
},
82+
"samplingRate": 22050
83+
}
84+
}
85+
]
86+
}
87+
]
88+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
{"audio": "(array([-2.8619270e-13, -1.7014803e-13, 2.7065091e-14, ...,\n -6.4091455e-06, -3.7976279e-06, 2.7510678e-06],\n shape=(25872,), dtype=float32), 22050)"}
2+
{"audio": "(array([5.8726583e-14, 1.3397688e-13, 2.2199205e-13, ..., 4.2678180e-04,\n 1.9029720e-04, 2.7079385e-04], shape=(32928,), dtype=float32), 22050)"}

python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def _apply_transform_fn(value: Any, transform: Transform, field: Field) -> Any:
6363
raise ValueError(f"`format` only applies to dates. Got {field.data_type}")
6464
elif transform.separator is not None:
6565
return value.split(transform.separator)
66+
elif transform.sampling_rate is not None:
67+
return deps.librosa.resample(y=value, target_sr=transform.sampling_rate)
6668
return value
6769

6870

@@ -96,8 +98,7 @@ def _cast_value(ctx: Context, value: Any, data_type: type | term.URIRef | None):
9698
else:
9799
raise ValueError(f"Type {type(value)} is not accepted for an image.")
98100
elif data_type == DataType.AUDIO_OBJECT:
99-
output = deps.librosa.load(io.BytesIO(value))
100-
return output
101+
return value
101102
elif data_type == DataType.BOUNDING_BOX: # pytype: disable=wrong-arg-types
102103
return bounding_box.parse(value)
103104
elif not isinstance(data_type, type):

python/mlcroissant/mlcroissant/_src/operation_graph/operations/read.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,27 @@ def _reading_method(
7575
return next(iter(reading_methods))
7676

7777

78+
def _get_sampling_rate(
    node: FileObject | FileSet, fields: tuple[Field, ...]
) -> int | None:
    """Returns the sampling rate to use for an audio file, if specified.

    Every field's `source.sampling_rate` is inspected in order. Falsy values
    (e.g. `None`) count as "not specified" and are skipped.

    Args:
        node: The FileObject/FileSet being read. Only used in the error message.
        fields: The fields whose sources may declare a sampling rate.

    Returns:
        The one sampling rate declared by the fields, or None if no field
        declares one.

    Raises:
        ValueError: As soon as a second, different sampling rate is encountered.
    """
    requested: set[int] = set()
    for field in fields:
        rate = field.source.sampling_rate
        if not rate:
            # This field does not ask for a specific sampling rate.
            continue
        requested.add(rate)
        if len(requested) <= 1:
            continue
        # Two distinct rates were requested for the same file: fail right away.
        raise ValueError(
            f"Cannot read {node=}. The fields use several sampling rates:"
            f" {requested}. Reading the same FileObject/FileSet using different"
            " sampling rate is not possible. You can change the original sampling rate"
            " of an audio using a Transform operation."
        )
    # At this point `requested` holds at most one element.
    for rate in requested:
        return rate
    return None
97+
98+
7899
def _should_append_line_numbers(fields: tuple[Field, ...]) -> bool:
79100
"""Checks whether at least one field requires listing the line numbers."""
80101
for field in fields:
@@ -162,8 +183,13 @@ def _read_file_content(
162183
encoding_format == EncodingFormat.MP3
163184
or encoding_format == EncodingFormat.JPG
164185
):
186+
sampling_rate = _get_sampling_rate(self.node, self.fields)
187+
if sampling_rate:
188+
out = deps.librosa.load(file, sr=sampling_rate)
189+
else:
190+
out = deps.librosa.load(file)
165191
return pd.DataFrame({
166-
FileProperty.content: [file.read()],
192+
FileProperty.content: [out],
167193
})
168194
raise ValueError(
169195
f"None of the provided encoding formats: {encoding_format} for file"

python/mlcroissant/mlcroissant/_src/operation_graph/operations/read_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pytest
1313

1414
from mlcroissant._src.core.path import Path
15+
from mlcroissant._src.operation_graph.operations.read import _get_sampling_rate
1516
from mlcroissant._src.operation_graph.operations.read import _read_arff_file
1617
from mlcroissant._src.operation_graph.operations.read import _reading_method
1718
from mlcroissant._src.operation_graph.operations.read import Read
@@ -46,6 +47,26 @@ def test_str_representation():
4647
assert str(operation) == "Read(file_object_name)"
4748

4849

50+
def test_get_sampling_rate():
    """A single field declaring a sampling rate yields exactly that rate."""
    expected_rate = 3000
    file_object = create_test_file_object()
    field = create_test_field(source=Source(sampling_rate=expected_rate))
    actual_rate = _get_sampling_rate(node=file_object, fields=(field,))
    assert actual_rate == expected_rate
54+
55+
56+
def test_get_sampling_rate_with_value_error():
    """Two fields with different sampling rates on the same node must fail."""
    file_object = create_test_file_object()
    # Fields are created in the same order as the rates appear in the message.
    conflicting_fields = tuple(
        create_test_field(source=Source(sampling_rate=rate))
        for rate in (2000, 3000)
    )
    expected_message = (
        r'Cannot read node=FileObject\(uuid="file_object_name"\). The fields use'
        " several sampling rates: {2000, 3000}"
    )
    with pytest.raises(ValueError, match=expected_message):
        _get_sampling_rate(node=file_object, fields=conflicting_fields)
68+
69+
4970
def test_reading_arff():
5071
filepath = io.StringIO(ARFF_CONTENT)
5172
actual_df = _read_arff_file(filepath)

python/mlcroissant/mlcroissant/_src/structure_graph/nodes/source.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ class Transform(Node):
127127
input_types=[SDO.Text],
128128
url=constants.ML_COMMONS_REPLACE,
129129
)
130+
sampling_rate: int | None = mlc_dataclasses.jsonld_field(
131+
default=None,
132+
input_types=[SDO.Integer],
133+
url=constants.ML_COMMONS_SAMPLING_RATE,
134+
)
130135
separator: str | None = mlc_dataclasses.jsonld_field(
131136
default=None,
132137
input_types=[SDO.Text],
@@ -218,6 +223,11 @@ class Source(Node):
218223
input_types=[SDO.Text],
219224
url=constants.ML_COMMONS_FORMAT,
220225
)
226+
sampling_rate: int | None = mlc_dataclasses.jsonld_field(
227+
default=None,
228+
input_types=[SDO.Integer],
229+
url=constants.ML_COMMONS_SAMPLING_RATE,
230+
)
221231
transforms: list[Transform] = mlc_dataclasses.jsonld_field(
222232
cardinality="MANY",
223233
default_factory=list,

0 commit comments

Comments
 (0)