diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index 3e57a210..d44e341e 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -1,6 +1,6 @@ from collections import defaultdict -from collections.abc import Iterable -from contextlib import suppress +from collections.abc import Generator, Iterable +from contextlib import contextmanager, suppress from pathlib import Path from typing import Any, Literal, cast @@ -8,6 +8,7 @@ import numpy as np import torch import torch.utils.data as torch_data +from lightning.pytorch.callbacks import BasePredictionWriter from loguru import logger from luxonis_ml.data import DatasetIterator, LuxonisDataset from luxonis_ml.typing import PathType @@ -145,14 +146,23 @@ def infer_from_loader( @type img_paths: list[Path] | None @param img_paths: The paths to the images. """ + if save_dir is not None: + save_dir = Path(save_dir) + writer = _VisualizationPredictionWriter(save_dir, img_paths) + with _temporary_callback(model.pl_trainer, writer): + model.pl_trainer.predict( + model.lightning_module, + loader, + return_predictions=False, + ) + return + predictions = model.pl_trainer.predict(model.lightning_module, loader) broken = False if predictions is None: # pragma: no cover return - counter = Counter() - for outputs in predictions: if broken: # pragma: no cover break @@ -161,30 +171,85 @@ def infer_from_loader( renders = process_visualizations(visualizations) batch_size = len(next(iter(renders.values()))) for i in range(batch_size): - if img_paths is not None: - idx = counter() for (node_name, viz_name), visualizations in renders.items(): viz = visualizations[i] - if save_dir is not None: - save_dir = Path(save_dir) - if img_paths is not None: - img_path = Path(img_paths[idx]) - name = f"{img_path.stem}_{node_name}_{viz_name}" - else: - name = f"{node_name}_{viz_name}_{counter()}" - name = name.replace("/", "-") - save_path = save_dir / f"{name}.png" - cv2.imwrite(str(save_path), viz) - else: - cv2.imshow(f"{node_name}/{viz_name}", viz) - - if not save_dir and window_closed(): # pragma: no cover + cv2.imshow(f"{node_name}/{viz_name}", viz) + + if window_closed(): # pragma: no cover broken = True break - if save_dir is None: # pragma: no cover - with suppress(cv2.error): # type: ignore - cv2.destroyAllWindows() + with suppress(cv2.error): # pragma: no cover + cv2.destroyAllWindows() + + +class _VisualizationPredictionWriter(BasePredictionWriter): + def __init__( + self, + save_dir: Path, + img_paths: list[PathType] | None = None, + ) -> None: + super().__init__(write_interval="batch") + self.save_dir = save_dir + self.img_paths = img_paths + self.counter = Counter() + + def write_on_batch_end( + self, + trainer: Any, + pl_module: Any, + prediction: Any, + batch_indices: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + del trainer, pl_module, batch_indices, batch, batch_idx, dataloader_idx + + assert isinstance(prediction, LuxonisOutput) + renders = process_visualizations(prediction.visualizations) + _save_renders_batch( + renders, + self.save_dir, + self.counter, + self.img_paths, + ) + + +@contextmanager +def _temporary_callback( + trainer: Any, callback: Any +) -> Generator[None, None, None]: + trainer.callbacks.append(callback) + try: + yield + finally: + with suppress(ValueError): + trainer.callbacks.remove(callback) + + +def _save_renders_batch( + renders: dict[tuple[str, str], list[np.ndarray]], + save_dir: Path, + counter: Counter, + img_paths: list[PathType] | None = None, +) -> None: + if not renders: + return + + batch_size = len(next(iter(renders.values()))) + for i in range(batch_size): + img_path: Path | None = None + if img_paths is not None: + img_path = Path(img_paths[counter()]) + for (node_name, viz_name), visualizations in renders.items(): + viz = visualizations[i] + if img_path is not None: + name = f"{img_path.stem}_{node_name}_{viz_name}" + else: + name = f"{node_name}_{viz_name}_{counter()}" + name = name.replace("/", "-") + cv2.imwrite(str(save_dir / f"{name}.png"), viz) def create_loader_from_directory( diff --git a/tests/unittests/test_utils/test_infer_utils.py b/tests/unittests/test_utils/test_infer_utils.py new file mode 100644 index 00000000..bf1c69ae --- /dev/null +++ b/tests/unittests/test_utils/test_infer_utils.py @@ -0,0 +1,89 @@ +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +from luxonis_ml.typing import PathType + +from luxonis_train.core.utils.infer_utils import infer_from_loader +from luxonis_train.lightning import LuxonisOutput + + +class _MockTrainer: + def __init__(self, prediction: LuxonisOutput, callbacks: list[object]): + self._prediction = prediction + self.callbacks = callbacks + + def predict( + self, + lightning_module: object, + loader: list[object], + return_predictions: bool = True, + ) -> list[LuxonisOutput] | None: + if return_predictions: + return [self._prediction] + + for batch_idx, _ in enumerate(loader): + for callback in list(self.callbacks): + write_on_batch_end = getattr( + callback, "write_on_batch_end", None + ) + if write_on_batch_end is None: + continue + write_on_batch_end( + self, + lightning_module, + self._prediction, + None, + None, + batch_idx, + 0, + ) + return None + + +def test_infer_from_loader_temporary_callback_does_not_leak( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +): + saved_paths: list[str] = [] + monkeypatch.setattr( + "luxonis_train.core.utils.infer_utils.cv2.imwrite", + lambda path, _: saved_paths.append(Path(path).name) or True, + ) + + prediction = LuxonisOutput( + outputs={}, + losses={}, + visualizations={ + "DiscSubNetHead": { + "SegmentationVisualizer": torch.zeros((1, 3, 4, 4)) + } + }, + ) + existing_callback = object() + trainer = _MockTrainer(prediction, [existing_callback]) + model: Any = SimpleNamespace( + pl_trainer=trainer, + lightning_module=object(), + ) + loader: Any = [object()] + img_paths: list[PathType] = [tmp_path / "first.png"] + + infer_from_loader(model, loader, tmp_path, img_paths) + + assert trainer.callbacks == [existing_callback] + assert saved_paths == ["first_DiscSubNetHead_SegmentationVisualizer.png"] + + trainer.predict(model.lightning_module, loader, return_predictions=False) + + assert trainer.callbacks == [existing_callback] + assert saved_paths == ["first_DiscSubNetHead_SegmentationVisualizer.png"] + + infer_from_loader(model, loader, tmp_path, img_paths) + + assert trainer.callbacks == [existing_callback] + assert saved_paths == [ + "first_DiscSubNetHead_SegmentationVisualizer.png", + "first_DiscSubNetHead_SegmentationVisualizer.png", + ]