@@ -132,11 +132,7 @@ def init(
132
132
133
133
return cls (content = content , indices = indices )
134
134
135
- def __init__ (
136
- self ,
137
- content : List [B ],
138
- indices : Optional [List [Tuple [int , ...]]],
139
- ):
135
+ def __init__ (self , content : List [B ], indices : Optional [List [Tuple [int , ...]]]):
140
136
self ._content = content
141
137
self ._indices = indices
142
138
@@ -157,20 +153,19 @@ def __iter__(self) -> Iterator[B]:
157
153
yield from self ._content
158
154
159
155
def remove_by_indices (self , indices_to_remove : Set [tuple ]) -> "Batch" :
160
- content , new_indices = [], []
161
- for index , element in self . iter_with_indices ():
162
- if index in indices_to_remove :
163
- continue
164
- content . append ( element )
165
- new_indices . append ( index )
166
- return Batch (
167
- content = content ,
168
- indices = new_indices ,
169
- )
156
+ filtered_content = [
157
+ element
158
+ for index , element in zip ( self . _indices , self . _content )
159
+ if index not in indices_to_remove
160
+ ]
161
+ filtered_indices = [
162
+ index for index in self . _indices if index not in indices_to_remove
163
+ ]
164
+
165
+ return Batch ( content = filtered_content , indices = filtered_indices )
170
166
171
167
def iter_with_indices (self ) -> Iterator [Tuple [Tuple [int , ...], B ]]:
172
- for index , element in zip (self ._indices , self ._content ):
173
- yield index , element
168
+ return zip (self ._indices , self ._content )
174
169
175
170
def broadcast (self , n : int ) -> "Batch" :
176
171
if n <= 0 :
0 commit comments