Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/datasets/features/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.Str
`pa.StructArray`: Array in the Audio arrow storage type, that is
`pa.struct({"bytes": pa.binary(), "path": pa.string()})`
"""
if pa.types.is_string(storage.type):
if pa.types.is_string(storage.type) or pa.types.is_large_string(storage.type):
if pa.types.is_large_string(storage.type):
storage = array_cast(storage, pa.string())
Comment on lines +253 to +255
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the same code as in `Image.cast_storage`.

Suggested change
if pa.types.is_string(storage.type) or pa.types.is_large_string(storage.type):
if pa.types.is_large_string(storage.type):
storage = array_cast(storage, pa.string())
if pa.types.is_large_string(storage.type):
try:
storage = storage.cast(pa.string())
except pa.ArrowInvalid as e:
raise ValueError(
f"Failed to cast large_string to string for Image feature. "
f"This can happen if string values exceed 2GB. "
f"Original error: {e}"
) from e
if pa.types.is_string(storage.type):

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, that's helpful.

bytes_array = pa.array([None] * len(storage), type=pa.binary())
storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
elif pa.types.is_large_binary(storage.type):
Expand Down
35 changes: 35 additions & 0 deletions tests/features/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,41 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir)
assert samples.sample_rate == 16000
assert samples.data.shape == (2, 40124)

def test_cast_column_audio_from_csv_large_string(tmp_path):
    """Cast a CSV-loaded ``large_string`` path column to ``Audio(decode=False)``.

    Writes a short mono 16-bit sine-wave WAV plus a one-row CSV that points at
    it, loads the CSV, checks that the column is typed ``large_string``, and
    then verifies that casting to ``Audio`` keeps the original path.
    """
    import math
    import struct
    import wave

    from datasets import Audio, load_dataset

    # NOTE(review): a reviewer pointed out that an existing `audio_path`
    # fixture could replace this hand-built WAV file -- confirm against the
    # test fixtures before switching.
    wav_file = tmp_path / "example.wav"
    csv_file = tmp_path / "audio.csv"

    sample_rate = 16000
    seconds = 0.25
    tone_hz = 440.0
    n_frames = int(sample_rate * seconds)

    # Little-endian signed 16-bit samples of a pure tone.
    pcm = b"".join(
        struct.pack("<h", int(16000 * math.sin(2 * math.pi * tone_hz * k / sample_rate)))
        for k in range(n_frames)
    )

    with wave.open(str(wav_file), "w") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(sample_rate)
        writer.writeframes(pcm)

    # Single "audio" column whose only row is the path to the WAV file.
    csv_file.write_text(f"audio\n{wav_file}\n", encoding="utf-8")

    loaded = load_dataset("csv", data_files=str(csv_file), split="train")
    assert str(loaded.features["audio"]) == "Value('large_string')"

    casted = loaded.cast_column("audio", Audio(decode=False))

    audio_feature = casted.features["audio"]
    assert isinstance(audio_feature, Audio)
    first = casted[0]["audio"]
    assert first["path"] == str(wav_file)

@require_torchcodec
@pytest.mark.parametrize(
Expand Down
Loading