fix: prevent file overwriting when S_dim > batch_size in WriteImage #726

Merged
jdeschamps merged 2 commits into main from ds/NGD-sdim-predictionwriter
Feb 4, 2026

Conversation

@diyasrivas
Member

Description

Note

tldr:
Fixed a file-overwriting bug in the WriteImage strategy that occurred when the S dimension exceeds the batch size, by implementing cross-batch caching.

Background - why do we need this PR?

The WriteImage prediction writer strategy processes predictions batch-by-batch without maintaining state across batches. When the S dimension (sample count) exceeds the batch size, incomplete predictions are written and subsequently overwritten by later batches. In non-tiled prediction scenarios where S_dim > batch_size, this causes silent data loss.
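
The failure mode described above can be reproduced with a toy example. This is a minimal sketch and not the CAREamics code: `write_stateless` and the `image_0.tiff` key are hypothetical, and a dictionary stands in for the file system. With S=5 and batch_size=4, the samples arrive as one batch of 4 and one batch of 1, and the second write replaces the first.

```python
def write_stateless(samples, batch_size):
    """Toy stateless writer: every batch is written to the same output path."""
    files = {}  # stands in for the file system (hypothetical)
    for start in range(0, len(samples), batch_size):
        batch = samples[start:start + batch_size]
        # No state is kept between batches, so this overwrites
        # whatever an earlier batch wrote for the same image.
        files["image_0.tiff"] = list(batch)
    return files


# S=5 samples of a single image, batch_size=4 -> batches of 4 and 1
files = write_stateless(samples=[0, 1, 2, 3, 4], batch_size=4)
print(files)  # {'image_0.tiff': [4]}
```

Only the last sample survives, and no error is raised, which is why the data loss is silent.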

Overview - what changed?

Implemented a caching mechanism in WriteImage that accumulates predictions across batches until all S samples for a given data_idx are collected. The code now does the following:

  1. Caches predictions by data_idx across multiple write_batch() calls
  2. Extracts expected sample count from data_shape and axes metadata
  3. Writes combined output only when cache contains all expected S samples
  4. Clears cache after successful write to prevent memory accumulation
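
The four steps above can be sketched with a small stand-alone class. This is an illustration under stated assumptions, not the actual WriteImage implementation: `CachingWriter`, the `(data_idx, sample)` tuples, and the `files` dictionary are hypothetical, and the expected sample count is passed in directly rather than read from metadata.

```python
from collections import defaultdict


class CachingWriter:
    """Toy writer that only writes an image once all of its samples arrived."""

    def __init__(self, total_samples):
        self.total_samples = total_samples  # expected S per image (assumed known)
        self.cache = defaultdict(list)      # data_idx -> samples seen so far
        self.files = {}                     # stands in for the file system

    def write_batch(self, predictions):
        # Step 1: cache predictions by data_idx across calls.
        for data_idx, sample in predictions:
            self.cache[data_idx].append(sample)
        # Steps 2-3: find images whose cache holds all expected samples.
        complete = [
            idx for idx, items in self.cache.items()
            if len(items) == self.total_samples
        ]
        # Step 4: write the combined output, then drop it from the cache.
        for idx in complete:
            self.files[f"image_{idx}.tiff"] = self.cache.pop(idx)


writer = CachingWriter(total_samples=5)
writer.write_batch([(0, s) for s in range(4)])  # first batch: 4 of 5 samples
files_after_first = dict(writer.files)          # nothing written yet
writer.write_batch([(0, 4)])                    # second batch completes image 0
```

After the second call, `writer.files` holds all five samples for image 0 and the cache is empty again.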

Implementation - how did you implement the changes?

This mechanism is inspired by the solution for CachedTiles in #702:

def __init__(self, ...):
    self.image_cache: dict[int, list[ImageRegionData]] = defaultdict(list)

to manage the caching, and:

def write_batch(self, dirpath, predictions):
    for pred in predictions:
        self.image_cache[pred.region_spec["data_idx"]].append(pred)
    self._write_complete_images(dirpath)

Modified write_batch() to cache instead of immediately writing. Also added:

  • _get_total_samples(): Extracts expected S dimension from data_shape[axes.index("S")]
  • _get_complete_images(): Identifies data_idx where len(cache) == total_samples
  • _write_complete_images(): Executes combine_samples() and write operations for complete images
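
As a rough illustration of the first two helpers, the lookup and the completeness check could look like the following. The function names mirror the ones listed above, but the signatures and bodies are assumptions made for this sketch. In particular, returning 1 when the axes contain no "S" is a guess and is not confirmed by the PR.

```python
def get_total_samples(data_shape, axes):
    """Return the size of the S axis for a given shape and axes string."""
    if "S" not in axes:
        return 1  # assumption: data without an S axis is a single sample
    return data_shape[axes.index("S")]


def get_complete_images(cache, expected):
    """Return the data_idx values whose cache holds every expected sample."""
    return [idx for idx, items in cache.items() if len(items) == expected[idx]]


# Shapes and axes taken from the test cases in this PR
n_szyx = get_total_samples((5, 16, 32, 32), "SZYX")       # 5
n_sczyx = get_total_samples((5, 3, 16, 32, 32), "SCZYX")  # 5

# Image 0 has all 5 samples cached, image 1 only 4
cache = {0: ["pred"] * 5, 1: ["pred"] * 4}
complete = get_complete_images(cache, expected={0: 5, 1: 5})  # [0]
```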

Changes Made

In src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py:

  • Added image_cache attribute for cross-batch management
  • Rewrote write_batch() to cache predictions instead of immediate writing
  • Added _get_total_samples() to extract S dimension
  • Added _get_complete_images() to detect completion
  • Added _write_complete_images() to write the complete images and then clear the cache

In tests/lightning/dataset_ng/test_smoke_prediction.py:

  • Uncommented the 4 test cases that were failing before

How has this been tested?

Uncommented 4 previously failing test cases in test_smoke_prediction.py:

  • Line 37: ((5, 16, 32, 32), "SZYX", None, False) - S=5 with batch_size=4
  • Line 45: ((5, 3, 16, 32, 32), "SCZYX", None, False) - S=5, C=3 with batch_size=4
  • Line 47: ((5, 3, 16, 32, 32), "SCZYX", [1], False) - S=5 with channel subsetting
  • Line 49: ((5, 3, 16, 32, 32), "SCZYX", [0, 2], False) - S=5 with channel subsetting

Test results: All 36 tests passed (including the 4 previously failing cases)

tests/lightning/dataset_ng/test_smoke_prediction.py::test_smoke_n2v_tiff[shape11-SZYX-None-False] PASSED
tests/lightning/dataset_ng/test_smoke_prediction.py::test_smoke_n2v_tiff[shape19-SCZYX-None-False] PASSED
tests/lightning/dataset_ng/test_smoke_prediction.py::test_smoke_n2v_tiff[shape21-SCZYX-channels21-False] PASSED
tests/lightning/dataset_ng/test_smoke_prediction.py::test_smoke_n2v_tiff[shape23-SCZYX-channels23-False] PASSED

======================== 36 passed, 153 warnings in 73.33s ========================

Related Issues

Breaking changes

None

Additional Notes and Examples

None

Please ensure your PR meets the following requirements:

  • Code builds and passes tests locally, including doctests
  • New tests have been added (for bug fixes/features)
  • Pre-commit passes
  • PR to the documentation exists (for bug fixes / features)

@diyasrivas diyasrivas requested a review from a team February 4, 2026 08:18
@jdeschamps jdeschamps self-assigned this Feb 4, 2026
@jdeschamps jdeschamps self-requested a review February 4, 2026 09:22
@jdeschamps jdeschamps merged commit 63118e4 into main Feb 4, 2026
12 checks passed
@jdeschamps jdeschamps deleted the ds/NGD-sdim-predictionwriter branch February 4, 2026 09:41

Development

Successfully merging this pull request may close these issues.

Bug: prediction writer callback for whole images breaks when S_dim > batch_size

2 participants