fix:prevent file overwriting when S_dim > batch_size in WriteImage#726
Merged
Conversation
jdeschamps
approved these changes
Feb 4, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Note
tldr:
Fixed file overwriting bug in
WriteImagestrategy when S dimension exceeds batch size by implementing cross-batch caching.Background - why do we need this PR?
The
WriteImageprediction writer strategy processes predictions batch-by-batch without maintaining state across batches. When the S dimension (sample count) exceeds the batch size, incomplete predictions are written and subsequently overwritten by later batches. In non-tiled prediction scenarios whereS_dim > batch_size, this causes silent data loss.Overview - what changed?
Implemented caching mechanism in
WriteImageto accumulate predictions across batches until all S samples for a givendata_idxare collected. So now the code does the following:data_idxacross multiplewrite_batch()callsdata_shapeandaxesmetadataImplementation - how did you implement the changes?
This mechanism is inspired by solution for
CachedTilesin #702 :to manage the caching, and:
Modified
write_batch()to cache instead of immediately writing. Also added:_get_total_samples(): Extracts expected S dimension fromdata_shape[axes.index("S")]_get_complete_images(): Identifiesdata_idxwherelen(cache) == total_samples_write_complete_images(): Executescombine_samples()and write operations for complete imagesChanges Made
In
src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py:image_cacheattribute for cross-batch managementwrite_batch()to cache predictions instead of immediate writing_get_total_samples()to extract S dimension_get_complete_images()to detect completion_write_complete_images()to write the complete images and then clear the cacheIn
tests/lightning/dataset_ng/test_smoke_prediction.py:How has this been tested?
Uncommented 4 previously failing test cases in
test_smoke_prediction.py:((5, 16, 32, 32), "SZYX", None, False)- S=5 with batch_size=4((5, 3, 16, 32, 32), "SCZYX", None, False)- S=5, C=3 with batch_size=4((5, 3, 16, 32, 32), "SCZYX", [1], False)- S=5 with channel subsetting((5, 3, 16, 32, 32), "SCZYX", [0, 2], False)- S=5 with channel subsettingTest results: All 36 tests passed (including the 4 previously failing cases)
Related Issues
S_dim > batch_size#660Breaking changes
None
Additional Notes and Examples
None
Please ensure your PR meets the following requirements: