Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 88 additions & 23 deletions luxonis_train/core/utils/infer_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from collections import defaultdict
from collections.abc import Iterable
from contextlib import suppress
from collections.abc import Generator, Iterable
from contextlib import contextmanager, suppress
from pathlib import Path
from typing import Any, Literal, cast

import cv2
import numpy as np
import torch
import torch.utils.data as torch_data
from lightning.pytorch.callbacks import BasePredictionWriter
from loguru import logger
from luxonis_ml.data import DatasetIterator, LuxonisDataset
from luxonis_ml.typing import PathType
Expand Down Expand Up @@ -145,14 +146,23 @@ def infer_from_loader(
@type img_paths: list[Path] | None
@param img_paths: The paths to the images.
"""
if save_dir is not None:
save_dir = Path(save_dir)
writer = _VisualizationPredictionWriter(save_dir, img_paths)
with _temporary_callback(model.pl_trainer, writer):
model.pl_trainer.predict(
model.lightning_module,
loader,
return_predictions=False,
)
return

predictions = model.pl_trainer.predict(model.lightning_module, loader)

broken = False
if predictions is None: # pragma: no cover
return

counter = Counter()

for outputs in predictions:
if broken: # pragma: no cover
break
Expand All @@ -161,30 +171,85 @@ def infer_from_loader(
renders = process_visualizations(visualizations)
batch_size = len(next(iter(renders.values())))
for i in range(batch_size):
if img_paths is not None:
idx = counter()
for (node_name, viz_name), visualizations in renders.items():
viz = visualizations[i]
if save_dir is not None:
save_dir = Path(save_dir)
if img_paths is not None:
img_path = Path(img_paths[idx])
name = f"{img_path.stem}_{node_name}_{viz_name}"
else:
name = f"{node_name}_{viz_name}_{counter()}"
name = name.replace("/", "-")
save_path = save_dir / f"{name}.png"
cv2.imwrite(str(save_path), viz)
else:
cv2.imshow(f"{node_name}/{viz_name}", viz)

if not save_dir and window_closed(): # pragma: no cover
cv2.imshow(f"{node_name}/{viz_name}", viz)

if window_closed(): # pragma: no cover
broken = True
break

if save_dir is None: # pragma: no cover
with suppress(cv2.error): # type: ignore
cv2.destroyAllWindows()
with suppress(cv2.error): # pragma: no cover
cv2.destroyAllWindows()


class _VisualizationPredictionWriter(BasePredictionWriter):
    """Lightning prediction writer that saves rendered visualizations.

    After every predicted batch, the model's visualizations are
    rendered and written as PNG files under C{save_dir} via
    C{_save_renders_batch}, so predictions never have to be
    accumulated in memory.
    """

    def __init__(
        self,
        save_dir: Path,
        img_paths: list[PathType] | None = None,
    ) -> None:
        """
        @type save_dir: Path
        @param save_dir: Directory where the rendered images are written.
        @type img_paths: list[PathType] | None
        @param img_paths: Optional source image paths used to derive the
            output file names.
        """
        super().__init__(write_interval="batch")
        self.save_dir = save_dir
        self.img_paths = img_paths
        # Project-local Counter: calling it yields an incrementing
        # index that persists across batches.
        self.counter = Counter()

    def write_on_batch_end(
        self,
        trainer: Any,
        pl_module: Any,
        prediction: Any,
        batch_indices: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Only the prediction itself is used; the rest of the Lightning
        # callback signature is discarded explicitly.
        del trainer, pl_module, batch_indices, batch, batch_idx, dataloader_idx

        assert isinstance(prediction, LuxonisOutput)
        _save_renders_batch(
            process_visualizations(prediction.visualizations),
            self.save_dir,
            self.counter,
            self.img_paths,
        )


@contextmanager
def _temporary_callback(
trainer: Any, callback: Any
) -> Generator[None, None, None]:
trainer.callbacks.append(callback)
try:
yield
finally:
with suppress(ValueError):
trainer.callbacks.remove(callback)


def _save_renders_batch(
    renders: dict[tuple[str, str], list[np.ndarray]],
    save_dir: Path,
    counter: Counter,
    img_paths: list[PathType] | None = None,
) -> None:
    """Write one PNG per visualization of a batch into C{save_dir}.

    @type renders: dict[tuple[str, str], list[np.ndarray]]
    @param renders: Rendered images keyed by C{(node_name, viz_name)};
        every value holds one image per batch element.
    @type save_dir: Path
    @param save_dir: Directory the PNG files are written to.
    @type counter: Counter
    @param counter: Callable counter providing a monotonically
        increasing index across calls.
    @type img_paths: list[PathType] | None
    @param img_paths: Optional source image paths; when given, output
        file names are derived from the corresponding image stem.
    """
    if not renders:
        return

    n_images = len(next(iter(renders.values())))
    for idx in range(n_images):
        stem: str | None = None
        if img_paths is not None:
            # One counter tick per image so names follow the inputs.
            stem = Path(img_paths[counter()]).stem
        for (node_name, viz_name), batch in renders.items():
            image = batch[idx]
            if stem is not None:
                file_name = f"{stem}_{node_name}_{viz_name}"
            else:
                # Without source paths, tick the counter per file to
                # keep names unique.
                file_name = f"{node_name}_{viz_name}_{counter()}"
            file_name = file_name.replace("/", "-")
            cv2.imwrite(str(save_dir / f"{file_name}.png"), image)


def create_loader_from_directory(
Expand Down
89 changes: 89 additions & 0 deletions tests/unittests/test_utils/test_infer_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from pathlib import Path
from types import SimpleNamespace
from typing import Any

import pytest
import torch
from luxonis_ml.typing import PathType

from luxonis_train.core.utils.infer_utils import infer_from_loader
from luxonis_train.lightning import LuxonisOutput


class _MockTrainer:
    """Minimal stand-in for a Lightning trainer used by infer-utils tests.

    Replays a fixed prediction and, when ``return_predictions`` is
    ``False``, drives any attached prediction-writer callbacks the way
    Lightning would.
    """

    def __init__(self, prediction: LuxonisOutput, callbacks: list[object]):
        self._prediction = prediction
        self.callbacks = callbacks

    def predict(
        self,
        lightning_module: object,
        loader: list[object],
        return_predictions: bool = True,
    ) -> list[LuxonisOutput] | None:
        if return_predictions:
            return [self._prediction]

        # Mimic Lightning's writer protocol: invoke every callback that
        # implements ``write_on_batch_end`` once per batch. Iterate over
        # a copy so callbacks may mutate the list safely.
        for batch_idx, _ in enumerate(loader):
            for callback in list(self.callbacks):
                hook = getattr(callback, "write_on_batch_end", None)
                if hook is not None:
                    hook(
                        self,
                        lightning_module,
                        self._prediction,
                        None,
                        None,
                        batch_idx,
                        0,
                    )
        return None


def test_infer_from_loader_temporary_callback_does_not_leak(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
):
    """The temporary writer callback must be detached after each call
    and must not fire during later plain ``predict`` invocations."""
    saved_paths: list[str] = []

    def fake_imwrite(path: str, _img: object) -> bool:
        saved_paths.append(Path(path).name)
        return True

    monkeypatch.setattr(
        "luxonis_train.core.utils.infer_utils.cv2.imwrite", fake_imwrite
    )

    prediction = LuxonisOutput(
        outputs={},
        losses={},
        visualizations={
            "DiscSubNetHead": {
                "SegmentationVisualizer": torch.zeros((1, 3, 4, 4))
            }
        },
    )
    pre_existing = object()
    trainer = _MockTrainer(prediction, [pre_existing])
    model: Any = SimpleNamespace(
        pl_trainer=trainer,
        lightning_module=object(),
    )
    loader: Any = [object()]
    img_paths: list[PathType] = [tmp_path / "first.png"]

    expected = "first_DiscSubNetHead_SegmentationVisualizer.png"

    infer_from_loader(model, loader, tmp_path, img_paths)
    assert trainer.callbacks == [pre_existing]
    assert saved_paths == [expected]

    # A direct predict call without the writer must not save anything.
    trainer.predict(model.lightning_module, loader, return_predictions=False)
    assert trainer.callbacks == [pre_existing]
    assert saved_paths == [expected]

    # A second inference run re-attaches a fresh writer and saves again.
    infer_from_loader(model, loader, tmp_path, img_paths)
    assert trainer.callbacks == [pre_existing]
    assert saved_paths == [expected, expected]
Loading