Feature(NG Dataset): Stratified Patching Strategy #710

Merged: jdeschamps merged 26 commits into main from mc/feat/stratif-patching on Feb 4, 2026

@melisande-c melisande-c commented Jan 30, 2026

Description

Note

tldr: Introducing a stratified patching strategy which has an additional feature that allows regions to be excluded from sampling.

Background - why do we need this PR?

There are two possible applications for this strategy:

  1. Validation and train splitting when there are only one or a few images.
  2. A more efficient way to apply a background mask.

It also ensures a more consistent coverage of the data per epoch than the random patching.

Overview - what changed?

A new StratifiedPatchingStrategy class, added without integration into the CAREamics Dataset. Internally this class depends on an _ImageStratifiedPatching class, which in turn depends on a _SamplingRegion class.

As mentioned, patch regions can also be excluded from sampling. The number of patches is calculated as ceil(prod(shape / patch_size)) - n_excluded_patches.

The StratifiedPatchingStrategy is constructed such that the mean expected number of times a pixel is selected per epoch is approximately equal to 1. The main difference from the random patching strategy, whose mean expected value is slightly greater than 1, is where the ceiling is taken: random patching calculates the number of patches as prod(ceil(shape / patch_size)), with the ceiling taken before the product.
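
To make the difference between the two patch counts concrete, here is a minimal sketch that evaluates both formulas. The function names are hypothetical and not part of CAREamics, and the coverage estimate assumes every sampled patch lies fully inside the image.

```python
import math


def n_patches_stratified(shape, patch_size, n_excluded=0):
    """ceil(prod(shape / patch_size)) - n_excluded: ceiling after the product."""
    ratio = math.prod(s / p for s, p in zip(shape, patch_size))
    return math.ceil(ratio) - n_excluded


def n_patches_random(shape, patch_size):
    """prod(ceil(shape / patch_size)): ceiling before the product."""
    return math.prod(math.ceil(s / p) for s, p in zip(shape, patch_size))


def mean_pixel_coverage(n_patches, shape, patch_size):
    """Mean expected number of times a pixel is selected in one epoch.

    Assumes every patch lies fully inside the image.
    """
    return n_patches * math.prod(patch_size) / math.prod(shape)


shape, patch_size = (1000, 1000), (64, 64)
for name, n in [
    ("stratified", n_patches_stratified(shape, patch_size)),
    ("random", n_patches_random(shape, patch_size)),
]:
    print(name, n, round(mean_pixel_coverage(n, shape, patch_size), 3))
```

For a 1000 x 1000 image with 64 x 64 patches this gives 245 patches for the stratified count and 256 for the random count, so the stratified coverage stays closer to 1.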

A demo notebook has also been included.

Implementation - how did you implement the changes?

Apologies it got a bit complicated but I will try my best to explain. I have also tried to include explanatory comments in the code.

Structure overview

Sampling regions which have an area of double the patch size are created so that they lie on a grid with a spacing equal to the patch size. This ensures that all possible patch coordinates can be selected. These sampling regions are represented by the _SamplingRegion class. The sampling region itself contains subregions, which make it easier to exclude patch regions.

For each sample in each image stack an _ImageStratifiedPatching instance is created. It stores the _SamplingRegions in a dict[tuple[int, ...], _SamplingRegion], where each key is the grid coordinate that the top-left corner of the sampling region lies on. Most of the patching logic happens in this class.

The StratifiedPatchingStrategy stores a list[list[_ImageStratifiedPatching]] where the elements of the outer list represent each image stack and there is an _ImageStratifiedPatching instance for each sample in each image stack.

Bin packing sampling regions

Not all the sampling regions will have the same area, unless the image is evenly divisible by the patch size and no patch regions have been excluded. To ensure that the smaller regions have an appropriate probability of being sampled from, they are combined together into bins. The bin packing is done using the best-fit-decreasing algorithm, courtesy of Claude.

Each bin is the same size and each dataset index corresponds to a bin. This means that each possible patch coordinate in each bin has the same probability of being selected.

An implementation detail is that a bin might not be filled up perfectly, and it is also possible to get empty bins. This leftover probability is used to select from all the regions in the sample.
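
For readers unfamiliar with best-fit-decreasing, the following is a textbook sketch with a fixed bin capacity. It is not the PR's `_region_bin_packing`; the function name, the `capacity` argument and the example areas are assumptions made for illustration only.

```python
def best_fit_decreasing(areas, capacity):
    """Pack items into bins of a fixed capacity using best-fit-decreasing.

    `areas` maps a grid coordinate to the area of its sampling region.
    Returns a list of bins, each a list of grid coordinates.
    """
    bins = []       # grid coordinates held by each bin
    remaining = []  # free space left in each bin

    # Largest regions first; ties keep their original order.
    for coord, area in sorted(areas.items(), key=lambda kv: kv[1], reverse=True):
        if area > capacity:
            raise ValueError(f"Region {coord} does not fit in an empty bin.")
        # Best fit: among bins with enough room, take the one with least space left.
        candidates = [i for i, free in enumerate(remaining) if free >= area]
        if candidates:
            best = min(candidates, key=lambda i: remaining[i])
            bins[best].append(coord)
            remaining[best] -= area
        else:
            bins.append([coord])
            remaining.append(capacity - area)
    return bins


example = {(0, 0): 16, (0, 1): 8, (1, 0): 8, (1, 1): 4}
print(best_fit_decreasing(example, capacity=16))
```

In this example the two half-size edge regions share one bin, so together they are sampled as often as the full-size region.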

Sampling a patch coordinate

When a patch spec is sampled for a particular index:

  1. The StratifiedPatchingStrategy maps the index to the correct image stack and sample, similarly to how it is done in the random patching strategy.
  2. The index is then given to the relevant _ImageStratifiedPatching; it maps to a particular bin, which contains a set of _SamplingRegions.
  3. A _SamplingRegion is chosen from the bin with a probability equal to the ratio of the region's area to the bin size.
  4. A patch coordinate is sampled from the valid patch coordinates contained within the sampling region.
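
Steps 3 and 4 can be sketched as follows. `Region` is a hypothetical stand-in for `_SamplingRegion` (a plain box of valid patch coordinates, with no subregions), and the area-weighted fallback for the leftover probability of an under-filled bin is an assumption, not necessarily what the PR does.

```python
import math
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Region:
    """Hypothetical stand-in for `_SamplingRegion`."""

    origin: tuple[int, ...]  # first valid patch coordinate on each axis
    extent: tuple[int, ...]  # number of valid patch coordinates on each axis

    @property
    def area(self) -> int:
        return math.prod(self.extent)


def sample_patch_coord(bin_regions, bin_size, all_regions, rng):
    """Sample one patch coordinate from the bin that a dataset index maps to."""
    # Step 3: choose a region with probability area / bin_size.
    threshold = rng.random() * bin_size
    cumulative = 0
    chosen = None
    for region in bin_regions:
        cumulative += region.area
        if threshold < cumulative:
            chosen = region
            break
    if chosen is None:
        # The bin is not full (or is empty): spend the leftover probability
        # on all regions of the sample (area-weighted here, an assumption).
        weights = [region.area for region in all_regions]
        chosen = rng.choices(all_regions, weights=weights)[0]
    # Step 4: sample uniformly among the region's valid patch coordinates.
    return tuple(o + rng.randrange(e) for o, e in zip(chosen.origin, chosen.extent))


rng = random.Random(0)
full = Region(origin=(0, 0), extent=(4, 4))
edge = Region(origin=(8, 8), extent=(2, 2))
print(sample_patch_coord([full], bin_size=16, all_regions=[full, edge], rng=rng))
```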

Excluding patch regions

The _SamplingRegion contains subregions; in the 2D case these are the 4 regions illustrated in the diagram below:

 ┌───┬──────────────────────────┐
 │   1                          │
 ├─1─┼─────(patch_size - 1)─────┤
 │   │                          │
 ⋮              ⋮
 └───┴──────────────────────────┘

If I want to exclude the bottom-left quadrant I exclude the large bottom left region, if I want to exclude the top-left quadrant I exclude both the bottom-left region and the top-left region.

To exclude a patch region from being sampled I have to exclude the correct quadrant from the 4 sampling regions that cover it.

Clipping sampling regions

Sampling regions can also be clipped. This is necessary at the edges of images whose size is not evenly divisible by the patch size, and is handled by the method _SamplingRegion.clip.

Changes Made

New features or files

  • All new classes have been added to src/careamics/dataset_ng/patching_strategies/stratified_patching.py
    • StratifiedPatchingStrategy
    • _ImageStratifiedPatching
    • _SamplingRegion

How has this been tested?

Stratified patching has been added to the test_all_patching_strategies file. This tests:

  • test_all_get_patch_spec
  • test_patches_cover_50percent
  • test_get_patch_indices

One new test has been added for the StratifiedPatchingStrategy that ensures that the excluded patches are never sampled.

Additional Notes and Examples

The included notebook produces the following figures and also demonstrates that the mean expected number of times a pixel is selected per epoch is approximately 1.

[Four figures produced by the demo notebook]

Please ensure your PR meets the following requirements:

  • Code builds and passes tests locally, including doctests
  • New tests have been added (for bug fixes/features)
  • Pre-commit passes
  • PR to the documentation exists (for bug fixes / features)

Comment on lines +23 to +27
If the same index is used twice to sample a patch with the method `get_patch_spec`
there will be a high probability that it will come from the same sampling region,
but not necessarily 100%. Smaller sampling regions may be binned together into a
single index. The mean of all the expected values that each pixel will be selected
in a patch per epoch is 1.
Member Author

Depends what you mean. The strategy is seeded once at initialisation, so if you sample the same index twice with the same instance you will get different patches. If you make a new instance with the same seed you will get the same patch.
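
The seeding behaviour described in this reply can be illustrated with a toy sampler. `SeededSampler` is a hypothetical stand-in, not the CAREamics class; it only mimics "seeded once at initialisation".

```python
import random


class SeededSampler:
    """Hypothetical stand-in for a strategy that is seeded once at initialisation."""

    def __init__(self, seed):
        self.rng = random.Random(seed)

    def get_patch_coord(self, index):
        # A real strategy would use `index` to select a bin; here it is ignored
        # because only the behaviour of the random state matters.
        return (self.rng.randrange(64), self.rng.randrange(64))


a = SeededSampler(seed=42)
b = SeededSampler(seed=42)

# Same instance, same index: the random state advances between calls,
# so the two draws are independent.
print(a.get_patch_coord(0), a.get_patch_coord(0))

# A fresh instance with the same seed replays the same sequence.
print(b.get_patch_coord(0), b.get_patch_coord(0))
```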

Comment thread src/careamics/dataset_ng/patching_strategies/stratified_patching.py
# Note: this is used by the FileIterSampler
def get_patch_indices(self, data_idx: int) -> Sequence[int]:
    """
    Get the patch indices will return patches for a specific `image_stack`.
Member

Confusing sentence

Suggested change
Get the patch indices will return patches for a specific `image_stack`.
Return patches for a specific `image_stack`.

Comment thread src/careamics/dataset_ng/patching_strategies/stratified_patching.py Outdated
Comment on lines +178 to +184
patches_per_sample = [
    [sample.n_patches for sample in image] for image in self.image_patching
]
patches_per_image = [sum(samples) for samples in patches_per_sample]
cumulative_image_patches = np.cumsum(patches_per_image)

start = 0 if data_idx == 0 else cumulative_image_patches[data_idx - 1]
Member

Isn't cumulative_image_patches already calculated and stored in self.cumulative_image_patches?

Member Author

ah yep I will remove this calculation
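
A sketch of what the lookup might reduce to once the cumulative sums are reused rather than recomputed. The standalone function and its arguments are hypothetical; in the PR this logic lives in a method and reads `self.cumulative_image_patches`.

```python
from itertools import accumulate


def patch_indices_for_image(cumulative_image_patches, data_idx):
    """Return the dataset indices whose patches come from image stack `data_idx`."""
    start = 0 if data_idx == 0 else cumulative_image_patches[data_idx - 1]
    stop = cumulative_image_patches[data_idx]
    return range(start, stop)


# Computed once, e.g. at initialisation: 4, 6 and 2 patches in three image stacks.
cumulative = list(accumulate([4, 6, 2]))
print(list(patch_indices_for_image(cumulative, 1)))
```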

Comment thread src/careamics/dataset_ng/patching_strategies/stratified_patching.py Outdated
if target_n_bins == 1:
    return max_size, 1

# Edge case: if we want as many bins as values, bin size = max value
Member

Also what is "values" in this sentence? number of sampling regions?

    areas: dict[tuple[int, ...], int], target_n_bins: int
) -> tuple[int, int]:
    """Find the minimum bin size that will result in `target_n_bins` or less.

Member

Can you add a description of the parameters?

areas are the areas of the sampling regions?

for grid_coord in grid_coords:
    d: tuple[Literal[0, 1], ...] = (0, 1)
    # exclude the patch from all the sampling regions that cover it
    # These are the grid coords at: subtract (0, 0), (0, 1), (1, 0) (for 2D)
Member

Not (1,1)?

Member Author

No, because the final quadrant is only defined by the next sampling region. I do, however, iterate over (1, 1) in the loop below anyway.

# These are the grid coords at: subtract (0, 0), (0, 1), (1, 0) (for 2D)
for d_idx in itertools.product(*[d for _ in range(self.ndims)]):
    # q is the ID of the orthant to remove
    q: tuple[Literal[0, 1], ...] = tuple(0 if i == 1 else 1 for i in d_idx)
Member

Also pretty hard to read, some examples in the comments would help
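
As an illustration of the kind of example that could go in the comments, the loop can be unrolled for one excluded patch. This is a standalone sketch: the `q` expression is copied from the snippet above, while computing the region coordinate as `grid_coord - d_idx` is inferred from the "subtract" comment and is an assumption. Which corner each orthant ID refers to is not stated in the snippet, so it is left uninterpreted.

```python
import itertools


def regions_covering_patch(grid_coord):
    """List (sampling-region grid coordinate, orthant ID to remove) pairs."""
    ndims = len(grid_coord)
    pairs = []
    for d_idx in itertools.product((0, 1), repeat=ndims):
        # Assumed: the covering region sits at the patch's grid coordinate minus d_idx.
        region_coord = tuple(g - d for g, d in zip(grid_coord, d_idx))
        # q flips each component of d_idx, as in the reviewed code.
        q = tuple(0 if i == 1 else 1 for i in d_idx)
        pairs.append((region_coord, q))
    return pairs


for region_coord, q in regions_covering_patch((2, 3)):
    print(f"region {region_coord}: remove orthant {q}")
```

For the patch at grid coordinate (2, 3) this visits regions (2, 3), (2, 2), (1, 3) and (1, 2), removing orthants (1, 1), (1, 0), (0, 1) and (0, 0) respectively.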

Comment on lines +91 to +92
Excluded patches must lie on a grid which starts at (0, 0) and has a spacing of
the given `patch_size`.
Member

Should this raise an error?

Member Author

Might be a good idea; currently an invalid index is simply ignored.
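
If an error were preferred over silently ignoring invalid indices, the check could look roughly like this. The function is hypothetical and not part of the PR; it assumes excluded patches are given as grid coordinates and that the grid has ceil(shape / patch_size) cells per axis.

```python
import math


def validate_excluded_patches(excluded, shape, patch_size):
    """Raise if any excluded patch does not lie on the patch grid of the image."""
    grid_shape = tuple(math.ceil(s / p) for s, p in zip(shape, patch_size))
    for coord in excluded:
        if len(coord) != len(grid_shape):
            raise ValueError(
                f"Excluded patch {coord} has {len(coord)} dimensions, "
                f"expected {len(grid_shape)}."
            )
        if any(not 0 <= c < g for c, g in zip(coord, grid_shape)):
            raise ValueError(
                f"Excluded patch {coord} lies outside the grid of shape {grid_shape}."
            )


# A 100 x 100 image with 64 x 64 patches has a 2 x 2 grid.
validate_excluded_patches([(0, 0), (1, 1)], shape=(100, 100), patch_size=(64, 64))
```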

@jdeschamps jdeschamps self-requested a review February 3, 2026 10:10
Comment thread src/careamics/dataset_ng/patching_strategies/stratified_patching.py Outdated
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
Comment thread src/careamics/dataset_ng/patching_strategies/stratified_patching.py Outdated
Comment thread src/careamics/dataset_ng/patching_strategies/stratified_patching.py Outdated
@jdeschamps jdeschamps self-requested a review February 4, 2026 09:45
melisande-c and others added 2 commits February 4, 2026 10:48
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
@jdeschamps jdeschamps merged commit b2e0379 into main Feb 4, 2026
13 checks passed
@jdeschamps jdeschamps deleted the mc/feat/stratif-patching branch February 4, 2026 10:16
melisande-c added a commit that referenced this pull request Feb 12, 2026
… dataset (#745)

## Description

> [!NOTE]  
> **tldr**: This PR makes it possible to use the new stratified patching
strategy in the next-generation dataset.

### Background - why do we need this PR?

A stratified patching strategy was introduced in #710 with 2
applications in mind: validation data splitting, and more efficient
masking. This PR allows it to be used by the `CAREamicsDataset`.

### Overview - what changed?

A new `StratifiedPatchingConfig` pydantic class, and some minor changes
to the `NGDataConfig`.

### Implementation - how did you implement the changes?

Simply added a new case to the match-case block in the
`create_patching_strategy` factory. In the `NGDataConfig` validation,
"stratified" is now a valid patching strategy choice for validation.


## Changes Made

### New features or files

- `StratifiedPatchingConfig`

### Modified features or files

- `NGDataConfig`
- `create_patching_strategy`

## How has this been tested?

Currently working on notebooks to benchmark. Denoising results so far
look good but stratified patching seems to be a little slower. I will
include results here shortly.

EDIT: performance benchmarking

Could be a bit more rigorous but the set-up is:
- Machine: M3 Mac with 16GB Memory
- Data: SEM (long scan time as GT for PSNR)
- Training: 10 epochs (5 repeats)

Full train time:
- Stratified: 40.9 ± 0.9 s
- Random: 40.0 ± 1.4 s

PSNR:
- Stratified: 18.9 ± 0.2
- Random: 18.5 ± 0.4

## Related Issues

- Resolves #715 


## Additional Notes and Examples

I haven't changed any of the convenience functions, so currently the
easiest way to test with the new configuration is to do something like:

```python
from careamics.config.ng_factories import create_n2v_configuration
from careamics.config.data.patching_strategies import StratifiedPatchingConfig

config = create_n2v_configuration(
    experiment_name="example",
    data_type="array",
    axes="YX",
    patch_size=(64, 64),
    batch_size=32,
    num_epochs=50,
    use_n2v2=False,
    train_dataloader_params={"num_workers": 0}
)
config.data_config.patching = StratifiedPatchingConfig(patch_size=(64, 64))
```

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
melisande-c added a commit that referenced this pull request Feb 23, 2026
)

## Description

> [!NOTE]  
> **tldr**: A basic implementation of validation splitting using the new
stratified patching strategy introduced in #710. More steps have to be
taken to integrate it into the datamodule and dataset.


### Background - why do we need this PR?

It is often more convenient for users if a portion of patches are kept
aside for validation rather than having to provide their own validation
data, which often has to be in a separate file. Patches used for
validation should not overlap with the training patches. This is why we
could not use the fully random patching strategy.

### Overview - what changed?

Added a `create_val_split` function that returns patching strategies for
training and validation that are created so that their patches will
never overlap.

A new `FixedPatchingStrategy` that always returns a chosen set of
patches has been added for the validation patching strategy.

### Implementation - how did you implement the changes?

The `StratifiedPatchingStrategy` was created so that certain patches
could be excluded from sampling. All the `create_val_split` function has
to do is randomly select the validation patches and exclude them from
the stratified patching strategy used for training. Then a fixed
patching strategy can be created for validation with the patches that
were excluded from training.

Currently the validation patches are chosen completely randomly with
every patch having the same probability of being chosen, but we can
maybe discuss different sampling methods in the future. If validation
patches are selected too close together they reduce the probability of
surrounding patches being selected for training each epoch which may not
be optimal.
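
The selection step can be sketched independently of the patching classes. This is not the PR's `create_val_split`; `split_grid_coords` and its arguments are hypothetical, and it assumes only whole patches are candidates (hence the floor division).

```python
import itertools
import random


def split_grid_coords(shape, patch_size, n_val, seed=None):
    """Randomly pick `n_val` grid coordinates for validation; the rest are for training."""
    rng = random.Random(seed)
    # Assumption: only whole patches are considered.
    grid_shape = tuple(s // p for s, p in zip(shape, patch_size))
    all_coords = list(itertools.product(*(range(g) for g in grid_shape)))
    # Every patch has the same probability of being chosen for validation.
    val_coords = rng.sample(all_coords, n_val)
    val_set = set(val_coords)
    train_coords = [c for c in all_coords if c not in val_set]
    return train_coords, val_coords


train, val = split_grid_coords((256, 256), (64, 64), n_val=4, seed=0)
assert set(train).isdisjoint(val)
print(len(train), len(val))
```

The validation coordinates would then be excluded from the stratified strategy used for training and handed to the fixed strategy used for validation.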

## Changes Made

### New features or files

- `create_val_split` function
- `FixedPatchingStrategy`
- `get_included_grid_coords` method in `StratifiedPatchingStrategy`
  -> this makes selecting validation patches easier.


## How has this been tested?

Added a test `test_train_val_complementary` which makes sure that the
validation and training patches do not overlap.

`FixedRandomPatching` strategy has been added to the `test_all_strategies`
tests.

## Related Issues

After fully integrating this into the data module (with future PRs) #416
will be resolved.

## Additional Notes and Examples

See the included demo notebook that produces this figure:

<img width="1211" height="1207"
alt="7e409bc4-ae81-4fe2-a303-9a4378545578"
src="https://github.com/user-attachments/assets/237eb62a-9a55-4e20-97d8-1ce1ccff50ac"
/>

### Future steps

To integrate this feature we will have to change the initialisation of
the dataset so that it does not create the patching strategies itself
but takes them as arguments.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
jdeschamps added a commit that referenced this pull request Mar 18, 2026
## Disclaimer

<!-- Please disclose your use of AI by checking the correct checkbox. If
you are
an AI agent implementing this PR, please check "I am an AI agent".-->
- [ ] I am an AI agent.
- [ ] I have used AI and I thoroughly reviewed every line.
- [x] I have not used AI extensively.


## Description

> [!NOTE]  
> **tldr**: This PR improves the performance of the stratified patching
strategy that was introduced in #710.

### Background - why do we need this PR?

The first iteration of the patching strategy was not optimised.

### Overview - what changed?

All changes are contained within
`careamics/dataset_ng/patching_strategies/stratified_patching.py`

### Implementation - how did you implement the changes?

Time savings on the initialisation and patch exclusion were found by
improving the `_region_bin_packing` (which Claude originally wrote).
This bin packing is a slight variation on the normal bin packing
problem. We want to find an efficient bin packing but the bin capacity
is allowed to expand (although we want to keep this to a minimum),
instead of the usual problem where the bin capacity is fixed but more
bins can be added. I think this is what confused Claude. Basically I had
to remove a stupid binary search on the bin size and just increase the
capacity as necessary when none of the bins could fit the next sampling
region.
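
The variant described above can be sketched as follows: the number of bins is fixed and the capacity grows only when nothing fits. This is a loose illustration under assumptions, not the PR's `_region_bin_packing`; the function name, the starting capacity and the choice of the least-loaded bin on overflow are all hypothetical.

```python
import math


def pack_with_expanding_capacity(areas, n_bins):
    """Pack regions into a fixed number of bins, growing the capacity when needed."""
    # Assumed starting capacity: no smaller than the largest region or the mean load.
    capacity = max(max(areas.values()), math.ceil(sum(areas.values()) / n_bins))
    bins = [[] for _ in range(n_bins)]
    loads = [0] * n_bins

    for key, area in sorted(areas.items(), key=lambda kv: kv[1], reverse=True):
        fits = [i for i in range(n_bins) if loads[i] + area <= capacity]
        if fits:
            # Best fit: the fullest bin that still has room.
            target = max(fits, key=lambda i: loads[i])
        else:
            # Nothing fits: use the emptiest bin and expand the capacity just enough.
            target = min(range(n_bins), key=lambda i: loads[i])
            capacity = loads[target] + area
        bins[target].append(key)
        loads[target] += area
    return bins, capacity


print(pack_with_expanding_capacity({"a": 10, "b": 6, "c": 6, "d": 4}, n_bins=2))
```

Here the last region does not fit in either bin at capacity 13, so the capacity is raised to 14 instead of opening a third bin.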

Time savings for the generation of patch specs were made by refactoring
the way some values are stored. The main one being the region bins now
contain indices instead of dictionary keys and the probabilities are
stored as numpy arrays that can be directly accessed.

## Changes Made

### New features or files

<!-- List new features or files added. -->
- Notebook to demonstrate patching strategy performance
  - `src/careamics/dataset_ng/demos/stratified_patching_perf.ipynb`

### Modified features or files

<!-- List important modified features or files. -->
- `careamics/dataset_ng/patching_strategies/stratified_patching.py`

## How has this been tested?

- All tests pass
- additional test added for the bin packing algorithm 
- Performance tests using included notebook

## Additional Notes and Examples

### Results from notebook

#### Initialisation

The results below show that the random patching strategy has almost no
initialisation overhead and performs with constant time relative to the
number of patches.

In comparison, the stratified patching strategy has some initialisation
overhead, which appears to grow linearly with the number of patches. However,
for almost 10,000 patches this is only about 0.35s.

(Patches were increased by increasing the number of samples or by
increasing the image size to see if it had any effect.)

<img width="576" height="432" alt="23b837ac-aba0-4689-83ae-ca3e1eb3733a"
src="https://github.com/user-attachments/assets/3b026761-6a19-426e-99aa-365b25529a18"
/>

#### Patch specs generation

The results show that the stratified patching strategy is about 3x
slower than the random patching strategy at generating patch specs. For
30,000 patches (roughly 30 epochs) the stratified patching will add
about 1s of overhead whereas the random patching will add about 0.3s of
overhead.

For 200 epochs of ~1000 patches the random patching strategy added an
overhead of ~2.2s whereas the stratified patching strategy added an
overhead of ~6.0s.

#### Patch distribution
<img width="611" height="808" alt="dd49cba7-0ae6-478e-80c6-3aed9173585f"
src="https://github.com/user-attachments/assets/22789309-d5f0-4c7d-9609-d6f3e213b8a8"
/>

**Tested on**: MacBook Air, M3, 2024, 16GB Memory.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>