diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py index a398b110da7..b22dc5aaf32 100644 --- a/src/datasets/features/audio.py +++ b/src/datasets/features/audio.py @@ -236,6 +236,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.Str The Arrow types that can be converted to the Audio pyarrow storage type are: - `pa.string()` - it must contain the "path" data + - `pa.large_string()` - it must contain the "path" data (will be cast to string if possible) - `pa.binary()` - it must contain the audio bytes - `pa.struct({"bytes": pa.binary()})` - `pa.struct({"path": pa.string()})` @@ -249,7 +250,9 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.Str `pa.StructArray`: Array in the Audio arrow storage type, that is `pa.struct({"bytes": pa.binary(), "path": pa.string()})` """ - if pa.types.is_string(storage.type): + if pa.types.is_string(storage.type) or pa.types.is_large_string(storage.type): + if pa.types.is_large_string(storage.type): + storage = array_cast(storage, pa.string()) bytes_array = pa.array([None] * len(storage), type=pa.binary()) storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) elif pa.types.is_large_binary(storage.type): diff --git a/src/datasets/features/pdf.py b/src/datasets/features/pdf.py index 756530554d4..b863e8770bf 100644 --- a/src/datasets/features/pdf.py +++ b/src/datasets/features/pdf.py @@ -186,6 +186,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr The Arrow types that can be converted to the Pdf pyarrow storage type are: - `pa.string()` - it must contain the "path" data + - `pa.large_string()` - it must contain the "path" data (will be cast to string if possible) - `pa.binary()` - it must contain the image bytes - `pa.struct({"bytes": pa.binary()})` - `pa.struct({"path": pa.string()})` @@ -200,6 +201,15 @@ def cast_storage(self, storage: Union[pa.StringArray, 
pa.StructArray, pa.ListArr `pa.StructArray`: Array in the Pdf arrow storage type, that is `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. """ + if pa.types.is_large_string(storage.type): + try: + storage = storage.cast(pa.string()) + except pa.ArrowInvalid as e: + raise ValueError( + "Failed to cast large_string to string for Pdf feature. " + "This can happen if the total string data exceeds 2GB. " + f"Original error: {e}" + ) from e if pa.types.is_string(storage.type): bytes_array = pa.array([None] * len(storage), type=pa.binary()) storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py index cf1c19551ca..2681a547578 100644 --- a/src/datasets/features/video.py +++ b/src/datasets/features/video.py @@ -241,6 +241,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr The Arrow types that can be converted to the Video pyarrow storage type are: - `pa.string()` - it must contain the "path" data + - `pa.large_string()` - it must contain the "path" data (will be cast to string if possible) - `pa.binary()` - it must contain the video bytes - `pa.struct({"bytes": pa.binary()})` - `pa.struct({"path": pa.string()})` @@ -255,6 +256,15 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr `pa.StructArray`: Array in the Video arrow storage type, that is `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. """ + if pa.types.is_large_string(storage.type): + try: + storage = storage.cast(pa.string()) + except pa.ArrowInvalid as e: + raise ValueError( + "Failed to cast large_string to string for Video feature. " + "This can happen if the total string data exceeds 2GB. 
" + f"Original error: {e}" + ) from e if pa.types.is_string(storage.type): bytes_array = pa.array([None] * len(storage), type=pa.binary()) storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) diff --git a/tests/features/test_audio.py b/tests/features/test_audio.py index a6dbca799fe..7caa510c7ab 100644 --- a/tests/features/test_audio.py +++ b/tests/features/test_audio.py @@ -501,6 +501,20 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir) assert samples.sample_rate == 16000 assert samples.data.shape == (2, 40124) +def test_cast_column_audio_from_csv_large_string(audio_file, tmp_path): + from datasets import Audio, load_dataset + + csv_path = tmp_path / "audio.csv" + csv_path.write_text(f"audio\n{audio_file}\n", encoding="utf-8") + + dset = load_dataset("csv", data_files=str(csv_path), split="train") + assert str(dset.features["audio"]) == "Value('large_string')" + + dset = dset.cast_column("audio", Audio(decode=False)) + + assert isinstance(dset.features["audio"], Audio) + item = dset[0]["audio"] + assert item["path"] == audio_file @require_torchcodec @pytest.mark.parametrize( diff --git a/tests/features/test_pdf.py b/tests/features/test_pdf.py index fe0b521c96c..7c319726b73 100644 --- a/tests/features/test_pdf.py +++ b/tests/features/test_pdf.py @@ -2,7 +2,7 @@ import pytest -from datasets import Dataset, Features, Pdf +from datasets import Dataset, Features, Pdf, load_dataset from ..utils import require_pdfplumber @@ -60,3 +60,18 @@ def test_dataset_with_pdf_feature(shared_datadir): item = dset[0] assert item.keys() == {"pdf"} assert isinstance(item["pdf"], pdfplumber.pdf.PDF) + +def test_cast_column_pdf_from_csv_large_string(shared_datadir, tmp_path): + pdf_path = str(shared_datadir / "test_pdf.pdf") + csv_path = tmp_path / "pdf.csv" + + csv_path.write_text(f"pdf\n{pdf_path}\n", encoding="utf-8") + + dset = load_dataset("csv", data_files=str(csv_path), split="train") + assert 
str(dset.features["pdf"]) == "Value('large_string')" + + dset = dset.cast_column("pdf", Pdf(decode=False)) + + assert isinstance(dset.features["pdf"], Pdf) + item = dset[0]["pdf"] + assert item["path"] == pdf_path diff --git a/tests/features/test_video.py b/tests/features/test_video.py index 131b01be6d2..29de50e9e8b 100644 --- a/tests/features/test_video.py +++ b/tests/features/test_video.py @@ -71,6 +71,21 @@ def test_dataset_with_video_feature(shared_datadir): assert item["video"].get_frame_at(0).data.shape == (3, 50, 66) assert isinstance(item["video"].get_frame_at(0).data, torch.Tensor) +def test_cast_column_video_from_csv_large_string(shared_datadir, tmp_path): + video_path = str(shared_datadir / "test_video_66x50.mov") + csv_path = tmp_path / "video.csv" + + csv_path.write_text(f"video\n{video_path}\n", encoding="utf-8") + + dset = load_dataset("csv", data_files=str(csv_path), split="train") + assert str(dset.features["video"]) == "Value('large_string')" + + dset = dset.cast_column("video", Video(decode=False)) + + assert isinstance(dset.features["video"], Video) + item = dset[0]["video"] + assert item["path"] == video_path + @require_torchcodec def test_dataset_with_video_map_and_formatted(shared_datadir):