1818from __future__ import annotations
1919
2020import abc
21- from collections .abc import Iterable
21+ from collections .abc import Iterable , Sequence
2222import dataclasses
2323import functools
2424import itertools
@@ -123,7 +123,7 @@ def __post_init__(self):
123123 def get_available_shards (
124124 self ,
125125 data_dir : epath .Path | None = None ,
126- file_format : file_adapters .FileFormat | None = None ,
126+ file_format : str | file_adapters .FileFormat | None = None ,
127127 strict_matching : bool = True ,
128128 ) -> list [epath .Path ]:
129129 """Returns the list of shards that are present in the data dir.
@@ -140,6 +140,7 @@ def get_available_shards(
140140 """
141141 if filename_template := self .filename_template :
142142 if file_format :
143+ file_format = file_adapters .FileFormat .from_value (file_format )
143144 filename_template = filename_template .replace (
144145 filetype_suffix = file_format .file_suffix
145146 )
@@ -250,7 +251,9 @@ def replace(self, **kwargs: Any) -> SplitInfo:
250251 """Returns a copy of the `SplitInfo` with updated attributes."""
251252 return dataclasses .replace (self , ** kwargs )
252253
253- def file_spec (self , file_format : file_adapters .FileFormat ) -> str :
254+ def file_spec (
255+ self , file_format : str | file_adapters .FileFormat
256+ ) -> str | None :
254257 """Returns the file spec of the split for the given file format.
255258
256259 A file spec is the full path with sharded notation, e.g.,
@@ -259,6 +262,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str:
259262 Args:
260263 file_format: the file format for which to create the file spec.
261264 """
265+ file_format = file_adapters .FileFormat .from_value (file_format )
262266 if filename_template := self .filename_template :
263267 if filename_template .filetype_suffix != file_format .file_suffix :
264268 raise ValueError (
@@ -268,9 +272,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str:
268272 return filename_template .sharded_filepaths_pattern (
269273 num_shards = self .num_shards
270274 )
271- raise ValueError (
272- f'Could not get filename template for split from split info: { self } .'
273- )
275+ return None
274276
275277
276278@dataclasses .dataclass (eq = False , frozen = True )
@@ -425,7 +427,7 @@ def __repr__(self) -> str:
425427if typing .TYPE_CHECKING :
426428 # For type checking, `tfds.Split` is an alias for `str` with additional
427429 # `.TRAIN`, `.TEST`,... attributes. All strings are valid split type.
428- Split = Union [ Split , str ]
430+ Split = Split | str
429431
430432
431433class SplitDict (utils .NonMutableDict [str , SplitInfo ]):
@@ -438,7 +440,7 @@ def __init__(
438440 # TODO(b/216470058): remove this parameter
439441 dataset_name : str | None = None , # deprecated, please don't use
440442 ):
441- super (SplitDict , self ).__init__ (
443+ super ().__init__ (
442444 {split_info .name : split_info for split_info in split_infos },
443445 error_msg = 'Split {key} already present' ,
444446 )
@@ -457,7 +459,7 @@ def __getitem__(self, key) -> SplitInfo | SubSplitInfo:
457459 )
458460 # 1st case: The key exists: `info.splits['train']`
459461 elif str (key ) in self .keys ():
460- return super (SplitDict , self ).__getitem__ (str (key ))
462+ return super ().__getitem__ (str (key ))
461463 # 2nd case: Uses instructions: `info.splits['train[50%]']`
462464 else :
463465 instructions = _make_file_instructions (
@@ -543,7 +545,7 @@ def _file_instructions_for_split(
543545
544546
545547def _make_file_instructions (
546- split_infos : list [SplitInfo ],
548+ split_infos : Sequence [SplitInfo ],
547549 instruction : SplitArg ,
548550) -> list [shard_utils .FileInstruction ]:
549551 """Returns file instructions by applying the given instruction on the given splits.
@@ -587,7 +589,7 @@ class AbstractSplit(abc.ABC):
587589 """
588590
589591 @classmethod
590- def from_spec (cls , spec : SplitArg ) -> ' AbstractSplit' :
592+ def from_spec (cls , spec : SplitArg ) -> AbstractSplit :
591593 """Creates a ReadInstruction instance out of a string spec.
592594
593595 Args:
@@ -632,7 +634,7 @@ def to_absolute(self, split_infos) -> list[_AbsoluteInstruction]:
632634 """
633635 raise NotImplementedError
634636
635- def __add__ (self , other : Union [ str , ' AbstractSplit' ] ) -> ' AbstractSplit' :
637+ def __add__ (self , other : str | AbstractSplit ) -> AbstractSplit :
636638 """Sum of 2 splits."""
637639 if not isinstance (other , (str , AbstractSplit )):
638640 raise TypeError (f'Adding split { self !r} with non-split value: { other !r} ' )
0 commit comments