3 files changed: +43 -11 lines changed
Columns: original file line number | diff line number | diff line change
@@ -890,18 +890,22 @@ def as_data_source(
890890 )
891891 chosen_format = file_format
892892 else :
893- chosen_format = suitable_formats .pop ()
894- logging .info (
895- "Found random access formats: %s. Chose to use %s. Overriding file"
896- " format in the dataset info." ,
897- ", " .join ([f .name for f in suitable_formats ]),
898- chosen_format ,
899- )
893+ if info .file_format in suitable_formats :
894+ chosen_format = info .file_format
895+ else :
896+ chosen_format = suitable_formats .pop ()
897+ logging .info (
898+ "Found random access formats: %s. Chose to use %s. Overriding file"
899+ " format in the dataset info." ,
900+ ", " .join ([f .name for f in suitable_formats ]),
901+ chosen_format ,
902+ )
900903
901- # Change the dataset info to read from a random access format.
902- info .set_file_format (
903- chosen_format , override = True , override_if_initialized = True
904- )
904+ if info .file_format != chosen_format :
905+ # Change the dataset info to read from a random access format.
906+ info .set_file_format (
907+ chosen_format , override = True , override_if_initialized = True
908+ )
905909
906910 # Create a dataset for each of the given splits
907911 def build_single_data_source (split : str ) -> Sequence [Any ]:
Columns: original file line number | diff line number | diff line change
@@ -802,6 +802,31 @@ def test_load_as_data_source_alternative_file_format(self):
802802 assert len (data_source ) == 10
803803 assert data_source [0 ]["x" ] == 28
804804
805+ def test_load_as_data_source_with_multi_split_info (self ):
806+ data_dir = self .get_temp_dir ()
807+ builder = DummyDatasetWithConfigs (
808+ data_dir = data_dir ,
809+ config = "plus1" ,
810+ file_format = file_adapters .FileFormat .ARRAY_RECORD ,
811+ )
812+ builder .download_and_prepare ()
813+
814+ # Make it a multi-split dataset.
815+ multi_split_info = splits_lib .MultiSplitInfo (
816+ name = "train" , split_infos = [builder .info .splits ["train" ]]
817+ )
818+ builder .info .set_splits (splits_lib .SplitDict ([multi_split_info ]))
819+
820+ self .assertIsInstance (
821+ builder .info .splits ["train" ], splits_lib .MultiSplitInfo
822+ )
823+ self .assertIsNone (builder .info .splits ["train" ].filename_template )
824+
825+ data_source = builder .as_data_source (split = "train" )
826+ self .assertIsNotNone (data_source )
827+ self .assertLen (data_source , 20 )
828+ self .assertEqual (data_source [0 ]["x" ], 7 )
829+
805830 @parameterized .named_parameters (
806831 * [
807832 {"file_format" : file_format , "testcase_name" : file_format .value }
Columns: original file line number | diff line number | diff line change
@@ -538,6 +538,9 @@ def set_file_format(
538538 updated_split_infos = []
539539 for split_info in self .splits .values ():
540540 if split_info .filename_template is None :
541+ logging .warning (
542+ "Split %s has no filename template, skipping." , split_info .name
543+ )
541544 continue
542545 updated_split_info = split_info .replace (
543546 filename_template = split_info .filename_template .replace (
You can’t perform that action at this time.
0 commit comments