Skip to content

Commit 4df7a32

Browse files
Merge pull request #1115 from misrasaurabh1/codeflash/optimize-Batch.remove_by_indices-m7dwflwc
⚡️ Speed up method `Batch.remove_by_indices` by 38%
2 parents 9cd771d + 1b20bcc commit 4df7a32

File tree

1 file changed

+12
-17
lines changed
  • inference/core/workflows/execution_engine/entities

1 file changed

+12
-17
lines changed

inference/core/workflows/execution_engine/entities/base.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,7 @@ def init(
132132

133133
return cls(content=content, indices=indices)
134134

135-
def __init__(
136-
self,
137-
content: List[B],
138-
indices: Optional[List[Tuple[int, ...]]],
139-
):
135+
def __init__(self, content: List[B], indices: Optional[List[Tuple[int, ...]]]):
140136
self._content = content
141137
self._indices = indices
142138

@@ -157,20 +153,19 @@ def __iter__(self) -> Iterator[B]:
157153
yield from self._content
158154

159155
def remove_by_indices(self, indices_to_remove: Set[tuple]) -> "Batch":
160-
content, new_indices = [], []
161-
for index, element in self.iter_with_indices():
162-
if index in indices_to_remove:
163-
continue
164-
content.append(element)
165-
new_indices.append(index)
166-
return Batch(
167-
content=content,
168-
indices=new_indices,
169-
)
156+
filtered_content = [
157+
element
158+
for index, element in zip(self._indices, self._content)
159+
if index not in indices_to_remove
160+
]
161+
filtered_indices = [
162+
index for index in self._indices if index not in indices_to_remove
163+
]
164+
165+
return Batch(content=filtered_content, indices=filtered_indices)
170166

171167
def iter_with_indices(self) -> Iterator[Tuple[Tuple[int, ...], B]]:
172-
for index, element in zip(self._indices, self._content):
173-
yield index, element
168+
return zip(self._indices, self._content)
174169

175170
def broadcast(self, n: int) -> "Batch":
176171
if n <= 0:

0 commit comments

Comments
 (0)