Skip to content

Commit cface66

Browse files
committed
Cleanliness: _save_renders inside InferenceSaveWriter
1 parent a262b41 commit cface66

1 file changed

Lines changed: 20 additions & 24 deletions

File tree

luxonis_train/core/utils/infer_utils.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -186,29 +186,6 @@ def infer_from_loader(
186186
cv2.destroyAllWindows()
187187

188188

189-
def save_renders(
190-
renders: dict[tuple[str, str], list[np.ndarray]],
191-
save_dir: Path,
192-
counter: Counter,
193-
img_paths: list[PathType] | None = None,
194-
) -> None:
195-
"""Persist a rendered batch to disk."""
196-
batch_size = len(next(iter(renders.values())))
197-
198-
for i in range(batch_size):
199-
if img_paths is not None:
200-
idx = counter()
201-
img_path = Path(img_paths[idx])
202-
for (node_name, viz_name), visualizations in renders.items():
203-
viz = visualizations[i]
204-
if img_paths is not None:
205-
name = f"{img_path.stem}_{node_name}_{viz_name}"
206-
else:
207-
name = f"{node_name}_{viz_name}_{counter()}"
208-
name = name.replace("/", "-")
209-
cv2.imwrite(str(save_dir / f"{name}.png"), viz)
210-
211-
212189
class InferenceSaveWriter(BasePredictionWriter):
213190
"""Writes rendered inference batches as soon as they are
214191
predicted."""
@@ -239,7 +216,7 @@ def write_on_batch_end(
239216
if not renders:
240217
return
241218

242-
save_renders(renders, self.save_dir, self.counter, self.img_paths)
219+
self._save_renders(renders)
243220

244221
def write_on_epoch_end(
245222
self,
@@ -250,6 +227,25 @@ def write_on_epoch_end(
250227
) -> None:
251228
del trainer, pl_module, predictions, batch_indices
252229

230+
def _save_renders(
231+
self, renders: dict[tuple[str, str], list[np.ndarray]]
232+
) -> None:
233+
"""Persist a rendered batch to disk."""
234+
batch_size = len(next(iter(renders.values())))
235+
236+
for i in range(batch_size):
237+
if self.img_paths is not None:
238+
idx = self.counter()
239+
img_path = Path(self.img_paths[idx])
240+
for (node_name, viz_name), visualizations in renders.items():
241+
viz = visualizations[i]
242+
if self.img_paths is not None:
243+
name = f"{img_path.stem}_{node_name}_{viz_name}"
244+
else:
245+
name = f"{node_name}_{viz_name}_{self.counter()}"
246+
name = name.replace("/", "-")
247+
cv2.imwrite(str(self.save_dir / f"{name}.png"), viz)
248+
253249

254250
def create_loader_from_directory(
255251
img_paths: Iterable[PathType],

0 commit comments

Comments
 (0)