diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py
index c4d0befd769..a964e03a720 100644
--- a/tensorflow_datasets/core/dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builder.py
@@ -890,18 +890,25 @@ def as_data_source(
         )
       chosen_format = file_format
     else:
-      chosen_format = suitable_formats.pop()
-      logging.info(
-          "Found random access formats: %s. Chose to use %s. Overriding file"
-          " format in the dataset info.",
-          ", ".join([f.name for f in suitable_formats]),
-          chosen_format,
-      )
+      if info.file_format in suitable_formats:
+        # Reuse the format already set in the dataset info when it is suitable.
+        chosen_format = info.file_format
+      else:
+        # Capture the names before pop() removes the chosen format.
+        found_formats = ", ".join([f.name for f in suitable_formats])
+        chosen_format = suitable_formats.pop()
+        logging.info(
+            "Found random access formats: %s. Chose to use %s. Overriding file"
+            " format in the dataset info.",
+            found_formats,
+            chosen_format,
+        )
 
-    # Change the dataset info to read from a random access format.
-    info.set_file_format(
-        chosen_format, override=True, override_if_initialized=True
-    )
+    if info.file_format != chosen_format:
+      # Change the dataset info to read from a random access format.
+      info.set_file_format(
+          chosen_format, override=True, override_if_initialized=True
+      )
 
     # Create a dataset for each of the given splits
     def build_single_data_source(split: str) -> Sequence[Any]:
diff --git a/tensorflow_datasets/core/dataset_builder_test.py b/tensorflow_datasets/core/dataset_builder_test.py
index d2d052c5c69..9c7d884f162 100644
--- a/tensorflow_datasets/core/dataset_builder_test.py
+++ b/tensorflow_datasets/core/dataset_builder_test.py
@@ -802,6 +802,31 @@ def test_load_as_data_source_alternative_file_format(self):
     assert len(data_source) == 10
     assert data_source[0]["x"] == 28
 
+  def test_load_as_data_source_with_multi_split_info(self):
+    data_dir = self.get_temp_dir()
+    builder = DummyDatasetWithConfigs(
+        data_dir=data_dir,
+        config="plus1",
+        file_format=file_adapters.FileFormat.ARRAY_RECORD,
+    )
+    builder.download_and_prepare()
+
+    # Make it a multi-split dataset.
+    multi_split_info = splits_lib.MultiSplitInfo(
+        name="train", split_infos=[builder.info.splits["train"]]
+    )
+    builder.info.set_splits(splits_lib.SplitDict([multi_split_info]))
+
+    self.assertIsInstance(
+        builder.info.splits["train"], splits_lib.MultiSplitInfo
+    )
+    self.assertIsNone(builder.info.splits["train"].filename_template)
+
+    data_source = builder.as_data_source(split="train")
+    self.assertIsNotNone(data_source)
+    self.assertLen(data_source, 20)
+    self.assertEqual(data_source[0]["x"], 7)
+
   @parameterized.named_parameters(
       *[
           {"file_format": file_format, "testcase_name": file_format.value}
diff --git a/tensorflow_datasets/core/dataset_info.py b/tensorflow_datasets/core/dataset_info.py
index 8fc52ab272f..f2eaf7ea0d5 100644
--- a/tensorflow_datasets/core/dataset_info.py
+++ b/tensorflow_datasets/core/dataset_info.py
@@ -538,6 +538,9 @@ def set_file_format(
       updated_split_infos = []
       for split_info in self.splits.values():
         if split_info.filename_template is None:
+          logging.warning(
+              "Split %s has no filename template, skipping.", split_info.name
+          )
           continue
         updated_split_info = split_info.replace(
             filename_template=split_info.filename_template.replace(