Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A strategy writing whole images directly."""

from collections import defaultdict
from pathlib import Path
from typing import Any

Expand All @@ -13,13 +14,13 @@
from .write_strategy import WriteStrategy


# TODO bug: batch is over samples for whole images, if one batch does not cover
# all samples, it will write an incomplete image, then overwrite it with the next
# batch
class WriteImage(WriteStrategy):
"""
A strategy for writing image predictions (i.e. un-tiled predictions).

Predictions are cached until all samples for a given data_idx are collected,
then combined and written. This prevents overwrites when S_dim > batch_size.

Parameters
----------
write_func : WriteFunc
Expand All @@ -37,6 +38,8 @@ class WriteImage(WriteStrategy):
Extension added to prediction file paths.
write_func_kwargs : dict of {str: Any}
Extra kwargs to pass to `write_func`.
image_cache : dict of {int: list of ImageRegionData}
Cache for predictions across batches, keyed by data_idx.
"""

def __init__(
Expand All @@ -63,13 +66,18 @@ def __init__(
self.write_extension: str = write_extension
self.write_func_kwargs: dict[str, Any] = write_func_kwargs

self.image_cache: dict[int, list[ImageRegionData]] = defaultdict(list)

def write_batch(
self,
dirpath: Path,
predictions: list[ImageRegionData],
) -> None:
"""
Save full images.
Cache predictions and save full images.

Predictions are cached by data_idx until all samples (S dimension)
are collected, then combined and written.

Parameters
----------
Expand All @@ -80,22 +88,84 @@ def write_batch(
"""
assert predictions is not None

image_lst, sources = combine_samples(predictions)

for i, image in enumerate(image_lst):
source_path = Path(sources[i])

# Handle array sources by using data_idx from predictions
postfix = ""
if source_path.stem == "array":
# Get data_idx from the corresponding prediction
data_idx = predictions[i].region_spec["data_idx"]
postfix = f"_{data_idx}"

file_path = create_write_file_path(
dirpath=dirpath,
file_path=source_path,
write_extension=self.write_extension,
postfix=postfix,
)
self.write_func(file_path=file_path, img=image, **self.write_func_kwargs)
for pred in predictions:
data_idx = pred.region_spec["data_idx"]
self.image_cache[data_idx].append(pred)

self._write_complete_images(dirpath)

def _get_total_samples(self, prediction: ImageRegionData) -> int:
    """
    Return how many samples the original data holds along its S axis.

    The count is read from `prediction.data_shape` at the position where
    "S" appears in `prediction.axes`.

    Parameters
    ----------
    prediction : ImageRegionData
        A prediction containing metadata about the original data.

    Returns
    -------
    int
        Size of the S dimension, or 1 when `axes` contains no "S".
    """
    axes = prediction.axes

    # Without a sample axis, the data counts as a single sample.
    if "S" not in axes:
        return 1

    return prediction.data_shape[axes.index("S")]

def _get_complete_images(self) -> list[int]:
    """
    List the cached data indices whose samples have all arrived.

    For each cache entry, the expected sample count is derived from the
    first cached prediction and compared with the number of predictions
    cached so far.

    Returns
    -------
    list of int
        Data indices of complete images in the cache.

    Raises
    ------
    ValueError
        If an entry holds more predictions than its expected sample count.
    """
    ready: list[int] = []

    for data_idx, cached in self.image_cache.items():
        expected = self._get_total_samples(cached[0])
        n_cached = len(cached)

        # Over-collection signals inconsistent bookkeeping; fail loudly.
        if n_cached > expected:
            raise ValueError(
                f"More samples cached for data_idx {data_idx} than expected. "
                f"Expected {expected}, found "
                f"{n_cached}."
            )

        if n_cached == expected:
            ready.append(data_idx)

    return ready

def _write_complete_images(self, dirpath: Path) -> None:
    """
    Write every complete image in the cache and drop it from the cache.

    Each complete cache entry is removed, its predictions are merged with
    `combine_samples`, and every resulting image is passed to `write_func`
    together with a file path built by `create_write_file_path`.

    Parameters
    ----------
    dirpath : Path
        Path to directory to save predictions to.
    """
    for data_idx in self._get_complete_images():
        # Popping removes the entry, so it is not written again on a later call.
        samples = self.image_cache.pop(data_idx)
        images, sources = combine_samples(samples)

        for pos, img in enumerate(images):
            source_path = Path(sources[pos])

            # A source stem of "array" gets the data index appended to the file
            # name (presumably to keep in-memory inputs distinct; confirm).
            postfix = f"_{data_idx}" if source_path.stem == "array" else ""

            target = create_write_file_path(
                dirpath=dirpath,
                file_path=source_path,
                write_extension=self.write_extension,
                postfix=postfix,
            )
            self.write_func(file_path=target, img=img, **self.write_func_kwargs)
8 changes: 4 additions & 4 deletions tests/lightning/dataset_ng/test_smoke_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@
((16, 32, 32), "ZYX", None, True),
((16, 32, 32), "ZYX", None, False),
((5, 16, 32, 32), "SZYX", None, True),
# ((5, 16, 32, 32), "SZYX", None, False), # TODO: fails until #660 is fixed
((5, 16, 32, 32), "SZYX", None, False),
((3, 16, 32, 32), "CZYX", None, True),
((3, 16, 32, 32), "CZYX", None, False),
((3, 16, 32, 32), "CZYX", [1], True),
((3, 16, 32, 32), "CZYX", [1], False),
((3, 16, 32, 32), "CZYX", [0, 2], True),
((3, 16, 32, 32), "CZYX", [0, 2], False),
((5, 3, 16, 32, 32), "SCZYX", None, True),
# ((5, 3, 16, 32, 32), "SCZYX", None, False), # TODO: fails until #660 is fixed
((5, 3, 16, 32, 32), "SCZYX", None, False),
((5, 3, 16, 32, 32), "SCZYX", [1], True),
# ((5, 3, 16, 32, 32), "SCZYX", [1], False), # TODO: fails until #660 is fixed
((5, 3, 16, 32, 32), "SCZYX", [1], False),
((5, 3, 16, 32, 32), "SCZYX", [0, 2], True),
# ((5, 3, 16, 32, 32), "SCZYX", [0, 2], False) # TODO: fails until #660 is fixed
((5, 3, 16, 32, 32), "SCZYX", [0, 2], False),
],
)
def test_smoke_n2v_tiff(tmp_path, shape, axes, channels, tiled):
Expand Down