Skip to content

Commit 7f40717

Browse files
The TensorFlow Datasets Authors authored and committed:
Fix incompatibility of tfds.Builder.as_data_source with MultiSplitInfo.
PiperOrigin-RevId: 858633111
1 parent 982a819 commit 7f40717

File tree

3 files changed: 43 additions (+43) and 11 deletions (-11).

tensorflow_datasets/core/dataset_builder.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -890,18 +890,22 @@ def as_data_source(
890890
)
891891
chosen_format = file_format
892892
else:
893-
chosen_format = suitable_formats.pop()
894-
logging.info(
895-
"Found random access formats: %s. Chose to use %s. Overriding file"
896-
" format in the dataset info.",
897-
", ".join([f.name for f in suitable_formats]),
898-
chosen_format,
899-
)
893+
if info.file_format in suitable_formats:
894+
chosen_format = info.file_format
895+
else:
896+
chosen_format = suitable_formats.pop()
897+
logging.info(
898+
"Found random access formats: %s. Chose to use %s. Overriding file"
899+
" format in the dataset info.",
900+
", ".join([f.name for f in suitable_formats]),
901+
chosen_format,
902+
)
900903

901-
# Change the dataset info to read from a random access format.
902-
info.set_file_format(
903-
chosen_format, override=True, override_if_initialized=True
904-
)
904+
if info.file_format != chosen_format:
905+
# Change the dataset info to read from a random access format.
906+
info.set_file_format(
907+
chosen_format, override=True, override_if_initialized=True
908+
)
905909

906910
# Create a dataset for each of the given splits
907911
def build_single_data_source(split: str) -> Sequence[Any]:

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -802,6 +802,31 @@ def test_load_as_data_source_alternative_file_format(self):
802802
assert len(data_source) == 10
803803
assert data_source[0]["x"] == 28
804804

805+
def test_load_as_data_source_with_multi_split_info(self):
806+
data_dir = self.get_temp_dir()
807+
builder = DummyDatasetWithConfigs(
808+
data_dir=data_dir,
809+
config="plus1",
810+
file_format=file_adapters.FileFormat.ARRAY_RECORD,
811+
)
812+
builder.download_and_prepare()
813+
814+
# Make it a multi-split dataset.
815+
multi_split_info = splits_lib.MultiSplitInfo(
816+
name="train", split_infos=[builder.info.splits["train"]]
817+
)
818+
builder.info.set_splits(splits_lib.SplitDict([multi_split_info]))
819+
820+
self.assertIsInstance(
821+
builder.info.splits["train"], splits_lib.MultiSplitInfo
822+
)
823+
self.assertIsNone(builder.info.splits["train"].filename_template)
824+
825+
data_source = builder.as_data_source(split="train")
826+
self.assertIsNotNone(data_source)
827+
self.assertLen(data_source, 20)
828+
self.assertEqual(data_source[0]["x"], 7)
829+
805830
@parameterized.named_parameters(
806831
*[
807832
{"file_format": file_format, "testcase_name": file_format.value}

tensorflow_datasets/core/dataset_info.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -538,6 +538,9 @@ def set_file_format(
538538
updated_split_infos = []
539539
for split_info in self.splits.values():
540540
if split_info.filename_template is None:
541+
logging.warning(
542+
"Split %s has no filename template, skipping.", split_info.name
543+
)
541544
continue
542545
updated_split_info = split_info.replace(
543546
filename_template=split_info.filename_template.replace(

0 commit comments

Comments
 (0)