Commit 8a67cf6

Feature(Next-Gen): Patch filtering rework (#832)
## Disclaimer

- [ ] I am an AI agent.
- [ ] I have used AI and I thoroughly reviewed every line.
- [x] I have not used AI extensively.

## Description

> [!NOTE]
> **tldr**: This makes the patch filtering in the NG-Dataset more efficient by using the `StratifiedPatchingStrategy` to give background sampling regions in the data a reduced probability of being sampled in each epoch.

### Background - why do we need this PR?

The old patch filtering implementation used the `RandomPatchingStrategy` and kept resampling patches until a patch had signal. If there is a large amount of background in the data, this is inefficient. In the new implementation, the filtering process happens once at the start of training and there is no need to define a "patience".

### Overview - what changed?

Quite a lot...

- The attributes `patch_filter` and `coord_filter` have been removed from the `CareamicsDataset`, and it is no longer initialised with masks. All filtering functionality within the dataset has therefore been removed.
- Instead, the filtering process happens in `src/careamics/dataset_ng/filter_bg.py`, which has two functions: one for filtering with a filter and one for filtering with a mask.
- The `create_train_val_datasets` and `create_val_split_dataset` factory functions have been refactored slightly to reduce code duplication. The `train_dataset` is first created with the new `create_train_dataset`, within which the data is filtered.
- The `StratifiedPatchingStrategy` has a new method to set the probability that a sampling region will be sampled from during an epoch. This also requires the number of patches to be reduced if the probability of any region is reduced.
- `filtered_patch_prob` and `filter_ref_channel` have been added to the NG data config.
  - The reference channel was added for the case of multiple channels.

### Implementation - how did you implement the changes?
- The idea is that a filtering function can be applied to each sampling region in the `StratifiedPatchingStrategy`. If a region doesn't pass some threshold, the probability of it being sampled from in an epoch can be reduced.
- The patching strategy instance only needs to store the probability of each region, which doesn't take much memory.
- When calculating the total probabilities and the sampling bins, the adjusted areas of the regions can be used. The adjusted area is the area multiplied by the new probability.

## Changes Made

### New features or files

- `src/careamics/dataset_ng/filter_bg.py`
  - `filter_background`
  - `filter_background_with_mask`

### Modified features or files

- `src/careamics/dataset_ng/factory.py`
- `src/careamics/dataset_ng/patching_strategies/stratified_patching.py`

### Removed features or files

- Filtering from `CareamicsDataset`

## How has this been tested?

- Added unit tests checking that the number of patches is reduced in the `StratifiedPatchingStrategy` when the method `set_region_probs` is called.
- Added functional tests for filtering background at a few different layers:
  - The patching strategy layer
  - The `filter_background` and `filter_background_with_mask` layer
  - The `create_train_dataset` layer

Still missing is a test for the lightning module layer or an e2e test with filtering.

## Related Issues

- Resolves #741

## Breaking changes

- The data config has changed: `patch_filter_patience` has been removed.
- Masks are no longer passed to the dataset.

## Additional Notes and Examples

### Still TODO:

- Clean up patch filtering configs and filtering classes; we no longer need the filtering patience or the probability.
- I am considering removing the option to have different "coord filters" (we don't have any others) because a mask filter config is required to pass masks to `create_train_dataset`.
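The adjusted-area idea from the implementation notes can be sketched as follows. This is a minimal standard-library illustration under stated assumptions, not the CAREamics code: the function names are hypothetical, and the rule `ceil(total adjusted area / patch area)` for the reduced patch count is an assumption made for the example.

```python
import math
import random
from bisect import bisect_right
from itertools import accumulate


def sampling_bins(areas, probs):
    """Cumulative sampling bins built from adjusted areas (area * probability)."""
    if not areas or len(areas) != len(probs):
        raise ValueError("areas and probs must be non-empty and of equal length")
    if any(not 0.0 <= p <= 1.0 for p in probs):
        raise ValueError("probabilities must lie in [0, 1]")
    adjusted = [a * p for a, p in zip(areas, probs)]
    cumulative = list(accumulate(adjusted))
    total = cumulative[-1]
    if total <= 0:
        raise ValueError("at least one region needs a non-zero probability")
    # Normalise by the last cumulative value so the final bin edge is exactly 1.0.
    return [c / total for c in cumulative]


def n_patches(areas, probs, patch_area):
    """Hypothetical per-epoch patch count; it shrinks as probabilities drop."""
    adjusted_total = sum(a * p for a, p in zip(areas, probs))
    return math.ceil(adjusted_total / patch_area)


def sample_region(bins, rng):
    """Pick a region index with probability proportional to its adjusted area."""
    return bisect_right(bins, rng.random())


if __name__ == "__main__":
    areas = [64, 64, 64, 64]
    probs = [1.0, 1.0, 0.1, 0.1]  # last two regions flagged as background
    bins = sampling_bins(areas, probs)
    rng = random.Random(0)
    print(n_patches(areas, probs, patch_area=16), sample_region(bins, rng))
```

Only one float per region is stored, and a region with probability 0 gets a zero-width bin, so it is never selected.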
---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
1 parent 2770149 commit 8a67cf6

29 files changed

Lines changed: 951 additions & 437 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ repos:
     rev: v1.10.0
     hooks:
       - id: numpydoc-validation
-        exclude: "^docs/.*|^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*|^scripts/.*|^src/careamics/careamist_v2.py|^src/careamics/lightning/dataset_ng/data_module.py"
+        exclude: "^tests/.*|^docs/.*|^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*|^scripts/.*|^src/careamics/careamist_v2.py|^src/careamics/lightning/dataset_ng/data_module.py"
         # data_module: overloads and config params (val_percentage etc.) don't match signature; see https://github.com/numpy/numpydoc/issues/559

   # # jupyter linting and formatting

src/careamics/config/data/ng_data_config.py

Lines changed: 11 additions & 3 deletions
@@ -337,9 +337,15 @@ class NGDataConfig(BaseModel):
     """Coordinate filter to apply when using random patching. Only available if
     mode is `training`."""

-    patch_filter_patience: int = Field(default=5, ge=1)
-    """Number of consecutive patches not passing the filter before accepting the next
-    patch."""
+    filtered_patch_prob: float = Field(0.1, ge=0.0, le=1.0)
+    """The probability that each patch considered background will be selected each epoch
+    during training. Patches can be considered background by either using a
+    `patch_filter` or by supplying a mask during training. If neither is chosen this
+    parameter is ignored."""
+
+    # TODO: Move inside patch_filter
+    filter_ref_channel: int = 0
+    """The channel to use as reference for filtering."""

     augmentations: Sequence[Union[XYFlipConfig, XYRandomRotate90Config]] = Field(
         default=(
@@ -1062,6 +1068,8 @@ def convert_mode(
                     if new_mode == Mode.PREDICTING and new_dataloader_params is not None
                     else self.pred_dataloader_params
                 ),
+                "patch_filter": None,
+                "coord_filter": None,
             }
         )
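To illustrate how `filtered_patch_prob` and `filter_ref_channel` might interact, here is a small standard-library sketch. It is not the `filter_background` implementation: the function name, the nested-list data layout and the mean-intensity threshold are all assumptions chosen to keep the example self-contained.

```python
from statistics import fmean


def region_probabilities(
    regions, threshold, filtered_patch_prob=0.1, filter_ref_channel=0
):
    """Assign a per-epoch sampling probability to each region.

    `regions` is a list of regions; each region is a list of channels and
    each channel is a flat list of pixel values (hypothetical layout).
    A region counts as background when the mean of its reference channel
    falls below `threshold` (a stand-in for a real patch filter).
    """
    if not 0.0 <= filtered_patch_prob <= 1.0:
        raise ValueError("filtered_patch_prob must lie in [0, 1]")
    probs = []
    for region in regions:
        reference = region[filter_ref_channel]
        is_background = fmean(reference) < threshold
        probs.append(filtered_patch_prob if is_background else 1.0)
    return probs


if __name__ == "__main__":
    regions = [
        [[0, 0, 1, 0], [9, 9, 9, 9]],  # dim in channel 0, bright in channel 1
        [[5, 6, 7, 8], [0, 0, 0, 0]],  # bright in channel 0, dim in channel 1
    ]
    print(region_probabilities(regions, threshold=2))
    print(region_probabilities(regions, threshold=2, filter_ref_channel=1))
```

With multichannel data the choice of reference channel changes which regions are treated as background, which is the case the new config field is meant to cover.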

src/careamics/config/data/patch_filter/mask_filter_config.py

Lines changed: 5 additions & 2 deletions
@@ -13,5 +13,8 @@ class MaskFilterConfig(FilterConfig):
     name: Literal["mask"] = "mask"
     """Name of the filter."""

-    coverage: float = Field(0.5, ge=0.0, le=1.0)
-    """Percentage of masked pixels required to keep a patch."""
+    coverage: float | None = Field(None, ge=0.0, le=1.0)
+    """Minimum ratio of masked pixels required to keep a sampling region,
+    `default=None`. If `None` then `1/(ndims**2)` is used where `ndims` is the number of
+    spatial dimensions.
+    """
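The default described in the new `coverage` docstring works out to 0.25 for 2D data and 1/9 for 3D data. A minimal sketch of that rule follows; the helper names are hypothetical, the mask is given as a flat list of values, and treating the threshold as inclusive (`>=`) is an assumption based on the word "minimum".

```python
def resolve_coverage(coverage, ndims):
    """Return the coverage threshold, falling back to 1 / ndims**2 for None."""
    if coverage is None:
        return 1 / ndims**2
    if not 0.0 <= coverage <= 1.0:
        raise ValueError("coverage must lie in [0, 1]")
    return coverage


def keep_region(mask_values, ndims, coverage=None):
    """Keep a sampling region if enough of its pixels are masked."""
    if not mask_values:
        raise ValueError("mask_values must not be empty")
    ratio = sum(1 for value in mask_values if value) / len(mask_values)
    return ratio >= resolve_coverage(coverage, ndims)


if __name__ == "__main__":
    print(resolve_coverage(None, 2))
    print(keep_region([1, 0, 0, 0], ndims=2))
```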

src/careamics/dataset_ng/dataset.py

Lines changed: 2 additions & 67 deletions
@@ -15,7 +15,6 @@
 from .normalization import create_normalization
 from .normalization.statistics import resolve_normalization_config
 from .patch_extractor import PatchExtractor
-from .patch_filter import create_coord_filter, create_patch_filter
 from .patching_strategies import (
     PatchingStrategy,
     PatchSpecs,
@@ -150,8 +149,6 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
         Extractor for input patches.
     target_extractor : PatchExtractor or None, optional
         Extractor for target patches.
-    mask_extractor : PatchExtractor or None, optional
-        Extractor for mask (e.g. for coord filter).
     """

     def __init__(
@@ -160,7 +157,6 @@ def __init__(
         patching_strategy: PatchingStrategy,
         input_extractor: PatchExtractor[GenericImageStack],
         target_extractor: PatchExtractor[GenericImageStack] | None = None,
-        mask_extractor: PatchExtractor[GenericImageStack] | None = None,
     ) -> None:
         """Contructor.

@@ -174,8 +170,6 @@ def __init__(
             Extractor for input patches.
         target_extractor : PatchExtractor or None, optional
             Extractor for target patches.
-        mask_extractor : PatchExtractor or None, optional
-            Extractor for mask (e.g. for coord filter).
         """
         # Make sure all the image sizes are greater than the patch size for training
         data_shapes = [
@@ -197,18 +191,6 @@ def __init__(
         self.input_extractor = input_extractor
         self.target_extractor = target_extractor

-        self.patch_filter = (
-            create_patch_filter(self.config.patch_filter)
-            if self.config.patch_filter is not None
-            else None
-        )
-        self.coord_filter = (
-            create_coord_filter(self.config.coord_filter, mask=mask_extractor)
-            if self.config.coord_filter is not None and mask_extractor is not None
-            else None
-        )
-        self.patch_filter_patience = self.config.patch_filter_patience
-
         self.patching_strategy = patching_strategy

         resolve_normalization_config(
@@ -343,54 +325,6 @@ def _extract_patches(
         )
         return input_patch, target_patch

-    def _get_filtered_patch(
-        self, index: int
-    ) -> tuple[NDArray[Any], NDArray[Any] | None, PatchSpecs]:
-        """Extract a patch using filtering.
-
-        Parameters
-        ----------
-        index : int
-            Dataset index used to obtain the patch spec.
-
-        Returns
-        -------
-        tuple of (NDArray, NDArray or None, PatchSpecs)
-            Input patch, optional target patch, and patch spec.
-        """
-        should_filter = self.config.mode == Mode.TRAINING and (
-            self.patch_filter is not None or self.coord_filter is not None
-        )
-        empty_patch = True
-        patch_filter_patience = self.patch_filter_patience  # reset patience
-
-        while empty_patch and patch_filter_patience > 0:
-            # query patches
-            patch_spec = self.patching_strategy.get_patch_spec(index)
-
-            # filter patch based on coordinates if needed
-            if should_filter and self.coord_filter is not None:
-                if self.coord_filter.filter_out(patch_spec):
-                    patch_filter_patience -= 1
-
-                    # TODO should we raise an error rather than silently accept patches?
-                    # if patience runs out without ever finding coordinates
-                    # then we need to guard against an exist before defining
-                    # input_patch and target_patch
-                    if patch_filter_patience != 0:
-                        continue
-
-            input_patch, target_patch = self._extract_patches(patch_spec)
-
-            # filter patch based on values if needed
-            if should_filter and self.patch_filter is not None:
-                empty_patch = self.patch_filter.filter_out(input_patch)
-                patch_filter_patience -= 1  # decrease patience
-            else:
-                empty_patch = False
-
-        return input_patch, target_patch, patch_spec
-
     def __getitem__(
         self, index: int
     ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
@@ -406,7 +340,8 @@ def __getitem__(
         tuple of ImageRegionData
             (input_data,) or (input_data, target_data).
         """
-        input_patch, target_patch, patch_spec = self._get_filtered_patch(index)
+        patch_spec = self.patching_strategy.get_patch_spec(index)
+        input_patch, target_patch = self._extract_patches(patch_spec)

         # apply normalization
         input_patch, target_patch = self.normalization(input_patch, target_patch)
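For context on why the removed `_get_filtered_patch` loop was costly, the sketch below estimates its per-item work under a simplified model. The model is an assumption, not a measurement: only the value-based patch filter is active, and every draw is independently background with probability `background_fraction`. Under that model the loop draws until it finds signal or patience runs out, so the expected number of patch extractions is a truncated geometric sum.

```python
def patience_loop_cost(background_fraction, patience):
    """Return (expected draws per item, probability of returning background).

    Simplified model of a resample-until-signal loop with a patience limit;
    `background_fraction` is the chance that any single draw is background.
    """
    if not 0.0 <= background_fraction <= 1.0:
        raise ValueError("background_fraction must lie in [0, 1]")
    if patience < 1:
        raise ValueError("patience must be at least 1")
    # Draw k+1 happens only if the first k draws were all background.
    expected_draws = sum(background_fraction**k for k in range(patience))
    # Background is returned only when every allowed draw was background.
    p_background = background_fraction**patience
    return expected_draws, p_background


if __name__ == "__main__":
    print(patience_loop_cost(0.9, 5))
```

With 90% background and the old default patience of 5, this model gives roughly four extractions per returned patch, and more than half of the returned patches are still background. After this commit `__getitem__` performs exactly one extraction per item.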
