@@ -983,7 +983,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
983983 self .__dict__ [key ][sum_lens [i ] : sum_lens [i + 1 ]] = value
984984
985985 def cat_ (self , batches : BatchProtocol | Sequence [dict | BatchProtocol ]) -> None :
986- if isinstance (batches , BatchProtocol | dict ):
986+ if isinstance (batches , Batch | dict ):
987987 batches = [batches ]
988988 # check input format
989989 batch_list = []
@@ -1069,7 +1069,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
10691069 {
10701070 batch_key
10711071 for batch_key , obj in batch .items ()
1072- if not (isinstance (obj , BatchProtocol ) and len (obj .get_keys ()) == 0 )
1072+ if not (isinstance (obj , Batch ) and len (obj .get_keys ()) == 0 )
10731073 }
10741074 for batch in batches
10751075 ]
@@ -1080,7 +1080,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
10801080 if all (isinstance (element , torch .Tensor ) for element in value ):
10811081 self .__dict__ [shared_key ] = torch .stack (value , axis )
10821082 # third often
1083- elif all (isinstance (element , BatchProtocol | dict ) for element in value ):
1083+ elif all (isinstance (element , Batch | dict ) for element in value ):
10841084 self .__dict__ [shared_key ] = Batch .stack (value , axis )
10851085 else : # most often case is np.ndarray
10861086 try :
@@ -1114,7 +1114,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
11141114 value = batch .get (key )
11151115 # TODO: fix code/annotations s.t. the ignores can be removed
11161116 if (
1117- isinstance (value , BatchProtocol ) # type: ignore
1117+ isinstance (value , Batch ) # type: ignore
11181118 and len (value .get_keys ()) == 0 # type: ignore
11191119 ):
11201120 continue # type: ignore
@@ -1288,7 +1288,7 @@ def set_array_at_key(
12881288 ) from exception
12891289 else :
12901290 existing_entry = self [key ]
1291- if isinstance (existing_entry , BatchProtocol ):
1291+ if isinstance (existing_entry , Batch ):
12921292 raise ValueError (
12931293 f"Cannot set sequence at key { key } because it is a nested batch, "
12941294 f"can only set a subsequence of an array." ,
@@ -1312,7 +1312,7 @@ def hasnull(self) -> bool:
13121312
13131313 def is_any_true (boolean_batch : BatchProtocol ) -> bool :
13141314 for val in boolean_batch .values ():
1315- if isinstance (val , BatchProtocol ):
1315+ if isinstance (val , Batch ):
13161316 if is_any_true (val ):
13171317 return True
13181318 else :
@@ -1375,7 +1375,7 @@ def _apply_batch_values_func_recursively(
13751375 """
13761376 result = batch if inplace else deepcopy (batch )
13771377 for key , val in batch .__dict__ .items ():
1378- if isinstance (val , BatchProtocol ):
1378+ if isinstance (val , Batch ):
13791379 result [key ] = _apply_batch_values_func_recursively (val , values_transform , inplace = False )
13801380 else :
13811381 result [key ] = values_transform (val )
0 commit comments