
Commit 02dd726

melisande-c, jdeschamps, and pre-commit-ci[bot] authored
Feat(Next-Gen Dataset): Validation split function (not integrated) (#771)
## Description

> [!NOTE]
> **tldr**: A basic implementation of validation splitting using the new stratified patching strategy introduced in #710. More steps have to be taken to integrate it into the datamodule and dataset.

### Background - why do we need this PR?

It is often more convenient for users if a portion of the patches is set aside for validation, rather than having to provide their own validation data, which often has to be in a separate file. Patches used for validation must not overlap with the training patches. This is why we could not use the fully random patching strategy.

### Overview - what changed?

Added a `create_val_split` function that returns patching strategies for training and validation, created so that their patches never overlap. A new `FixedPatchingStrategy`, which always returns a chosen set of patches, has been added to serve as the validation patching strategy.

### Implementation - how did you implement the changes?

The `StratifiedPatchingStrategy` was designed so that certain patches can be excluded from sampling. All the `create_val_split` function has to do is randomly select the validation patches and exclude them from the stratified patching strategy used for training. A fixed patching strategy is then created for validation from the patches that were excluded from training.

Currently the validation patches are chosen completely at random, with every patch having the same probability of being chosen, but we can discuss different sampling methods in the future. If validation patches are selected too close together, they reduce the probability of the surrounding patches being selected for training each epoch, which may not be optimal.

## Changes Made

### New features or files

- `create_val_split` function
- `FixedPatchingStrategy`
- `get_included_grid_coords` method in `StratifiedPatchingStrategy`, which makes selecting validation patches easier.

## How has this been tested?

Added a test, `test_train_val_complementary`, which makes sure that the validation and training patches do not overlap. The `FixedRandomPatching` strategy has been added to the `test_all_strategies` tests.

## Related Issues

After fully integrating this into the data module (with future PRs), #416 will be resolved.

## Additional Notes and Examples

See the included demo notebook that produces this figure:

<img width="1211" height="1207" alt="7e409bc4-ae81-4fe2-a303-9a4378545578" src="https://github.com/user-attachments/assets/237eb62a-9a55-4e20-97d8-1ce1ccff50ac" />

### Future steps

To integrate this feature, we will have to change the initialisation of the dataset so that it does not create the patching strategies itself but takes them as arguments.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
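The non-overlap requirement can be illustrated with a small, library-free sketch. It assumes that validation patches sit on a grid with spacing `patch_size` (as the description states) and checks whether a training patch at an arbitrary pixel coordinate touches a grid cell reserved for validation. The helper names `touched_cells` and `overlaps_validation` are hypothetical and are not part of CAREamics; the actual `StratifiedPatchingStrategy` may enforce the exclusion differently.

```python
from itertools import product


def touched_cells(
    coord: tuple[int, ...], patch_size: tuple[int, ...]
) -> set[tuple[int, ...]]:
    """Return the grid cells (in units of `patch_size`) that a patch covers.

    `coord` is the pixel coordinate of the patch's first corner.
    """
    per_axis = [
        # first and last grid cell touched along this axis (inclusive)
        range(c // ps, (c + ps - 1) // ps + 1)
        for c, ps in zip(coord, patch_size)
    ]
    return set(product(*per_axis))


def overlaps_validation(
    coord: tuple[int, ...],
    patch_size: tuple[int, ...],
    val_cells: set[tuple[int, ...]],
) -> bool:
    """True if the patch covers any grid cell reserved for validation."""
    return not touched_cells(coord, patch_size).isdisjoint(val_cells)


# Hypothetical example: one validation cell at grid coordinate (1, 1).
val_cells = {(1, 1)}
patch_size = (64, 64)

# A patch at pixel (70, 10) straddles grid rows 1-2 and columns 0-1,
# so it covers the validation cell and must not be used for training.
print(overlaps_validation((70, 10), patch_size, val_cells))

# A grid-aligned patch at pixel (128, 64) only covers cell (2, 1).
print(overlaps_validation((128, 64), patch_size, val_cells))
```

Because a randomly placed patch can cover up to two cells per axis, excluding one grid cell also constrains where neighbouring training patches can be placed, which is the effect the description mentions for validation patches that are selected close together.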
1 parent c8a1b35 commit 02dd726

9 files changed

Lines changed: 478 additions & 4 deletions


Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from collections.abc import Sequence\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from numpy.typing import NDArray\n",
+    "\n",
+    "from careamics.dataset_ng.patching_strategies import (\n",
+    "    PatchingStrategy,\n",
+    "    StratifiedPatchingStrategy,\n",
+    ")\n",
+    "from careamics.dataset_ng.val_split import create_val_split"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def demo_selected_patches(\n",
+    "    patching_strategy: PatchingStrategy,\n",
+    "    data_shapes: Sequence[Sequence[int]],\n",
+    "    epochs: int,\n",
+    ") -> Sequence[NDArray[np.int_]]:\n",
+    "    \"\"\"Create a map where all the patches have been selected from.\n",
+    "\n",
+    "    Every time a patch is selected that area is incremented by 1.\n",
+    "    \"\"\"\n",
+    "    tracking_arrays = [np.zeros(shape, dtype=int) for shape in data_shapes]\n",
+    "    for _ in range(epochs):\n",
+    "        for index in range(patching_strategy.n_patches):\n",
+    "            patch_spec = patching_strategy.get_patch_spec(index)\n",
+    "            data_idx = patch_spec[\"data_idx\"]\n",
+    "            sample_idx = patch_spec[\"sample_idx\"]\n",
+    "            coord = patch_spec[\"coords\"]\n",
+    "            patch_size = patch_spec[\"patch_size\"]\n",
+    "\n",
+    "            patch_slice = [\n",
+    "                slice(c, c + ps) for c, ps in zip(coord, patch_size, strict=True)\n",
+    "            ]\n",
+    "            tracking_arrays[data_idx][sample_idx, ..., *patch_slice] += 1\n",
+    "    return tracking_arrays"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rng = np.random.default_rng(42)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_shapes = [(1, 1, 512, 620), (1, 1, 300, 335), (1, 1, 512, 512)]\n",
+    "patch_size = (64, 64)\n",
+    "\n",
+    "stratified_patching = StratifiedPatchingStrategy(data_shapes, patch_size, seed=42)\n",
+    "n_val_patches = int(np.ceil(stratified_patching.n_patches * 0.1))  # 10% of patches\n",
+    "print(\n",
+    "    f\"Selecting {n_val_patches} validation patches from \"\n",
+    "    f\"{stratified_patching.n_patches} total patches.\"\n",
+    ")\n",
+    "train_patching, val_patching = create_val_split(stratified_patching, n_val_patches, rng)\n",
+    "\n",
+    "train_1 = demo_selected_patches(train_patching, data_shapes, epochs=1)\n",
+    "train_200 = demo_selected_patches(train_patching, data_shapes, epochs=200)\n",
+    "val = demo_selected_patches(val_patching, data_shapes, epochs=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, axes = plt.subplots(3, len(data_shapes), figsize=(12, 12), constrained_layout=True)\n",
+    "for i in range(len(data_shapes)):\n",
+    "    axes[0, i].set_title(f\"Image {i}\")\n",
+    "    axes[0, i].imshow(train_1[i][0, 0])\n",
+    "    axes[1, i].imshow(train_200[i][0, 0])\n",
+    "    axes[2, i].imshow(val[i][0, 0])\n",
+    "axes[0, 0].set_ylabel(\"Train epochs 1\")\n",
+    "axes[1, 0].set_ylabel(\"Train epochs 200\")\n",
+    "axes[2, 0].set_ylabel(\"Validation\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "careamics",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

src/careamics/dataset_ng/patching_strategies/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 __all__ = [
+    "FixedPatchingStrategy",
     "FixedRandomPatchingStrategy",
     "PatchSpecs",
     "PatchingStrategy",
@@ -13,6 +14,7 @@
     "is_tile_specs",
 ]
 
+from .fixed_patching import FixedPatchingStrategy
 from .patching_strategy_factory import create_patching_strategy
 from .patching_strategy_protocol import (
     PatchingStrategy,
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+"""A module for a fixed coordinate patching strategy, useful for validation."""
+
+from collections.abc import Sequence
+
+from .patching_strategy_protocol import PatchSpecs
+
+
+class FixedPatchingStrategy:
+    """A simple patching strategy that returns patches from a fixed sequence.
+
+    This class implements the `PatchingStrategy` `Protocol`.
+    """
+
+    def __init__(self, fixed_patch_specs: Sequence[PatchSpecs]):
+        """A simple patching strategy that returns patches from a fixed list.
+
+        Parameters
+        ----------
+        fixed_patch_specs : Sequence[PatchSpecs]
+            A sequence of patch specifications.
+        """
+        self.fixed_patch_specs = fixed_patch_specs
+
+    @property
+    def n_patches(self):
+        """
+        The number of patches that this patching strategy will return.
+
+        It also determines the maximum index that can be given to `get_patch_spec`.
+        """
+        return len(self.fixed_patch_specs)
+
+    def get_patch_spec(self, index: int) -> PatchSpecs:
+        """Return the patch specs for a given index.
+
+        Parameters
+        ----------
+        index : int
+            A patch index.
+
+        Returns
+        -------
+        PatchSpecs
+            A dictionary that specifies a single patch in a series of `ImageStacks`.
+        """
+        if index >= self.n_patches:
+            raise IndexError(
+                f"Index {index} out of bounds for FixedPatchingStrategy with "
+                f"number of patches {self.n_patches}"
+            )
+        # simply index the pre-generated patches to get the correct patch
+        return self.fixed_patch_specs[index]
+
+    # Note: this is used by the FileIterSampler
+    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
+        """
+        Return all patch indices belonging to a specific `image_stack`.
+
+        Each `image_stack` corresponds to a given `data_idx`.
+
+        Parameters
+        ----------
+        data_idx : int
+            An index that corresponds to a given `image_stack`.
+
+        Returns
+        -------
+        sequence of int
+            A sequence of patch indices belonging to a particular `image_stack` that
+            can be used to index the `CAREamicsDataset`.
+        """
+        return [
+            i
+            for i, patch_spec in enumerate(self.fixed_patch_specs)
+            if patch_spec["data_idx"] == data_idx
+        ]
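To show how a fixed strategy of this kind behaves, here is a minimal, self-contained sketch that does not import CAREamics. `PatchSpecSketch` and `FixedStrategySketch` are hypothetical stand-ins; the dictionary keys are the ones read by the demo notebook (`data_idx`, `sample_idx`, `coords`, `patch_size`). Unlike the class in the diff, which only checks the upper bound, this sketch also rejects negative indices; that is an assumption about desirable behaviour, not something the PR does.

```python
from collections.abc import Sequence
from typing import TypedDict


class PatchSpecSketch(TypedDict):
    """Hypothetical stand-in for `PatchSpecs` (keys taken from the demo notebook)."""

    data_idx: int
    sample_idx: int
    coords: tuple[int, ...]
    patch_size: tuple[int, ...]


class FixedStrategySketch:
    """Returns the same pre-chosen patches every epoch."""

    def __init__(self, specs: Sequence[PatchSpecSketch]):
        self.specs = list(specs)

    @property
    def n_patches(self) -> int:
        return len(self.specs)

    def get_patch_spec(self, index: int) -> PatchSpecSketch:
        # Assumption: negative indices are rejected rather than wrapped around.
        if not 0 <= index < self.n_patches:
            raise IndexError(f"Index {index} out of bounds for {self.n_patches} patches")
        return self.specs[index]

    def get_patch_indices(self, data_idx: int) -> list[int]:
        # Collect the positions of all patches that belong to one image stack.
        return [i for i, spec in enumerate(self.specs) if spec["data_idx"] == data_idx]


specs: list[PatchSpecSketch] = [
    {"data_idx": 0, "sample_idx": 0, "coords": (0, 64), "patch_size": (64, 64)},
    {"data_idx": 1, "sample_idx": 0, "coords": (128, 0), "patch_size": (64, 64)},
    {"data_idx": 0, "sample_idx": 0, "coords": (192, 256), "patch_size": (64, 64)},
]
strategy = FixedStrategySketch(specs)
indices_image_0 = strategy.get_patch_indices(0)
print(indices_image_0)
```

Since the patch list is stored as given, the patch indices of one image are not necessarily contiguous, which is why `get_patch_indices` filters by `data_idx` instead of returning a range.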

src/careamics/dataset_ng/patching_strategies/stratified_patching.py

Lines changed: 38 additions & 4 deletions
@@ -96,8 +96,6 @@ def n_patches(self) -> int:
             sum([sample.n_patches for sample in image]) for image in self.image_patching
         )
 
-    # TODO: add method to return valid grid coords for removal
-
     def exclude_patches(
         self, data_idx: int, sample_idx: int, grid_coords: Sequence[tuple[int, ...]]
     ):
@@ -194,6 +192,27 @@ def get_patch_indices(self, data_idx: int) -> Sequence[int]:
         start = 0 if data_idx == 0 else self.cumulative_image_patches[data_idx - 1]
         return np.arange(start, self.cumulative_image_patches[data_idx]).tolist()
 
+    def get_included_grid_coords(self) -> dict[tuple[int, int], list[tuple[int, ...]]]:
+        """
+        Get all grid coordinates included in the patching strategy.
+
+        If a grid coordinate is not included, a patch can never be selected from the
+        region `[grid_coord*patch_size, (grid_coord+1)*patch_size]`.
+
+        Returns
+        -------
+        grid_coords : dict[tuple[int, int], list[tuple[int, ...]]]
+            The key of the returned dictionary corresponds to the
+            `(data_idx, sample_idx)` and the values are the corresponding grid coords.
+        """
+        included_grid_coords: dict[tuple[int, int], list[tuple[int, ...]]] = {}
+        for data_idx, image_patch_list in enumerate(self.image_patching):
+            for sample_idx, sample_patching in enumerate(image_patch_list):
+                included_grid_coords[(data_idx, sample_idx)] = (
+                    sample_patching.get_included_grid_coords()
+                )
+        return included_grid_coords
+
     def _calc_bins(self) -> tuple[NDArray[np.int_], NDArray[np.int_], NDArray[np.int_]]:
         """
         Calculate bins to determine which image and sample a patch index maps to.
@@ -301,7 +320,7 @@ def __init__(
         self.areas: dict[tuple[int, ...], int] = {}
         self.probs: dict[tuple[int, ...], float]
 
-        self.excluded_patches: list[tuple[int, ...]] = []
+        self.excluded_patches: set[tuple[int, ...]] = set()
         self.bin_size: int
         self.bins: list[list[tuple[int, ...]]]
         self.n_patches: int
@@ -413,7 +432,7 @@ def exclude_patches(self, grid_coords: Sequence[tuple[int, ...]]):
             that will be excluded from sampling. The grid starts at (0, 0) and has a
             spacing of the given `patch_size`.
         """
-        self.excluded_patches.extend(grid_coords)
+        self.excluded_patches.update(grid_coords)
         for grid_coord in grid_coords:
             d: tuple[Literal[0, 1], ...] = (0, 1)
             # exclude the patch from all the sampling regions that cover it
@@ -438,6 +457,21 @@ def exclude_patches(self, grid_coords: Sequence[tuple[int, ...]]):
             self._recalculate_sampling()
         )
 
+    def get_included_grid_coords(self) -> list[tuple[int, ...]]:
+        """
+        Get all the included grid coordinates in the patching strategy.
+
+        If a grid coordinate is not included, a patch can never be selected from the
+        region `[grid_coord*patch_size, (grid_coord+1)*patch_size]`.
+
+        Returns
+        -------
+        grid_coords : list[tuple[int, ...]]
+            The list of included grid coordinates.
+        """
+        grid_coords_all: set[tuple[int, ...]] = set(self.regions.keys())
+        return list(grid_coords_all.difference(self.excluded_patches))
+
     def _recalculate_sampling(self):
         """
         Recalculate how patches will be sampled.
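The switch of `excluded_patches` from a list to a set, together with the set difference in `get_included_grid_coords`, can be sketched in isolation as follows. The `regions` dictionary below is a hypothetical stand-in for `self.regions` (only its keys matter here), and the final `sorted` call is an addition for predictable output; the method in the diff returns the coordinates in set order.

```python
# Hypothetical stand-in for `self.regions`: keys are grid coordinates of a 2x3 grid.
regions = {(row, col): None for row in range(2) for col in range(3)}

# A set ignores repeated exclusions, whereas `list.extend` would store duplicates.
excluded: set[tuple[int, ...]] = set()
excluded.update([(0, 1), (1, 2)])
excluded.update([(0, 1)])  # excluding the same grid coordinate again has no effect

# Included coordinates are all grid coordinates minus the excluded ones.
included = sorted(set(regions.keys()).difference(excluded))
print(included)
```

Callers that need a reproducible ordering of the included coordinates may want to sort the result themselves, since a list built from a set carries no ordering guarantee.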
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+"""A module for selecting data to be set aside for validation."""
+
+import numpy as np
+
+from .patching_strategies import (
+    FixedPatchingStrategy,
+    PatchSpecs,
+    StratifiedPatchingStrategy,
+)
+
+
+def create_val_split(
+    stratified_patching: StratifiedPatchingStrategy,
+    n_val_patches: int,
+    rng: np.random.Generator,
+) -> tuple[StratifiedPatchingStrategy, FixedPatchingStrategy]:
+    """
+    Create patching strategies for training and validation.
+
+    The patches from the training patching strategy will never overlap with the patches
+    from the validation patching strategy.
+
+    Parameters
+    ----------
+    stratified_patching : StratifiedPatchingStrategy
+        The patching strategy to select and exclude validation patches from.
+    n_val_patches : int
+        The number of validation patches.
+    rng : np.random.Generator
+        A random number generator that makes the validation patch choice reproducible.
+    Returns
+    -------
+    training_patching_strategy : StratifiedPatchingStrategy
+        The patching strategy to be used for training. Patches will be sampled in a
+        stratified way, for each epoch. It excludes all the patches that should be used
+        for validation.
+    validation_patching_strategy : FixedPatchingStrategy
+        The patching strategy to be used for validation. It will return the same patches
+        every epoch.
+    """
+    patch_size = stratified_patching.patch_size
+
+    # validation patches have to lie on this grid
+    grid_coords = stratified_patching.get_included_grid_coords()
+    # sample_ids are (data_idx, sample_idx)
+    sample_ids = list(grid_coords.keys())
+    val_patch_specs: list[PatchSpecs] = []
+
+    # select validation patches
+    n_patches_per_image = np.array(
+        [
+            stratified_patching.image_patching[data_idx][sample_idx].n_patches
+            for data_idx, sample_idx in sample_ids
+        ]
+    )
+    n_selected_image_patches = np.zeros_like(n_patches_per_image)
+    for _ in range(n_val_patches):
+        probs = n_patches_per_image / n_patches_per_image.sum()
+        idx = rng.choice(np.arange(len(n_patches_per_image)), p=probs)
+        n_selected_image_patches[idx] += 1
+        n_patches_per_image[idx] -= 1
+
+    for idx, n_patches in enumerate(n_selected_image_patches):
+
+        data_idx, sample_idx = sample_ids[idx]
+        # randomly choose the validation patches in the image
+        coord_indices = rng.choice(
+            len(grid_coords[(data_idx, sample_idx)]), n_patches, replace=False
+        )
+        coords: list[tuple[int, ...]] = [
+            grid_coords[(data_idx, sample_idx)][coord_idx]
+            for coord_idx in coord_indices
+        ]
+        # exclude the chosen validation patches from training
+        stratified_patching.exclude_patches(data_idx, sample_idx, coords)
+
+        # collect the chosen validation patches to create the fixed patching strategy
+        patch_specs: list[PatchSpecs] = [
+            {
+                "data_idx": data_idx,
+                "sample_idx": sample_idx,
+                "coords": tuple(np.array(grid_coord) * np.array(patch_size)),
+                "patch_size": patch_size,
+            }
+            for grid_coord in coords
+        ]
+        val_patch_specs.extend(patch_specs)
+
+    val_patching_strategy = FixedPatchingStrategy(val_patch_specs)
+    return stratified_patching, val_patching_strategy
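The first loop in `create_val_split` decides how many validation patches each image contributes: it repeatedly picks an image with probability proportional to its remaining patch count and then decrements that count. A standalone sketch of this step is given below, using only NumPy. The function name `allocate_val_patches` and the patch counts are hypothetical, and the guard against requesting more patches than exist is an addition that the PR code does not contain. Drawing patches one at a time without replacement in this way follows a multivariate hypergeometric distribution, so `Generator.multivariate_hypergeometric` is shown as a possible one-call alternative, not as what the PR uses.

```python
import numpy as np


def allocate_val_patches(
    n_patches_per_image: list[int], n_val_patches: int, rng: np.random.Generator
) -> np.ndarray:
    """Decide how many validation patches to take from each image."""
    remaining = np.array(n_patches_per_image, dtype=int)
    # Assumption: fail early instead of dividing by zero once all patches are used.
    if n_val_patches > remaining.sum():
        raise ValueError("More validation patches requested than patches available.")

    selected = np.zeros_like(remaining)
    for _ in range(n_val_patches):
        # images with more remaining patches are proportionally more likely
        probs = remaining / remaining.sum()
        idx = rng.choice(len(remaining), p=probs)
        selected[idx] += 1
        remaining[idx] -= 1
    return selected


rng = np.random.default_rng(42)
counts = [80, 30, 64]  # hypothetical number of patches in three images

loop_alloc = allocate_val_patches(counts, 18, rng)
# Same distribution in a single call (available on `numpy.random.Generator`).
one_call_alloc = rng.multivariate_hypergeometric(counts, 18)

print(loop_alloc, one_call_alloc)
```

Either way, every individual patch has the same chance of ending up in the validation set, which matches the uniform selection described in the PR.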
