Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
4906c4e
feat: add sampling region class for stratified patching
melisande-c Jan 29, 2026
707005d
feat: WIP stratified patching for single image stack - __init__ only
melisande-c Jan 29, 2026
b390c52
feat: sample coords after bin packing regions (bin packing courtesy o…
melisande-c Jan 29, 2026
01c056f
feat: allow removal of patches
melisande-c Jan 29, 2026
50f0400
feat: add StratifiedPatching class
melisande-c Jan 29, 2026
d2c15dd
test: add stratified patching to sanity check tests; update patch cov…
melisande-c Jan 29, 2026
60d9ca3
refac: rename remove_patch to exclude_patch
melisande-c Jan 29, 2026
6b983b0
pref: store precommputed patch bins as attributes
melisande-c Jan 29, 2026
3dd7354
fix: calculate sample index correctly when a sample has zerp patches
melisande-c Jan 29, 2026
da93225
perf: exclude multiple coords at once so rebinning only has to be don…
melisande-c Jan 29, 2026
3a7b005
test: stratified patching exclude patches
melisande-c Jan 29, 2026
9d59c13
docs: docstrings + comments
melisande-c Jan 30, 2026
f7a8c37
refac(sampling recalc): return data from method instead of modifying …
melisande-c Jan 30, 2026
768a3a1
feat: change number of patches to total - removed; test: revert chang…
melisande-c Jan 30, 2026
481feea
feat: change no. of patches to be closer area of image divided by pat…
melisande-c Jan 30, 2026
d8e5a45
feat: add demo notebook
melisande-c Jan 30, 2026
4da930a
docs: fix comment
melisande-c Jan 30, 2026
b57b46a
Merge branch 'main' into mc/feat/stratif-patching
melisande-c Jan 30, 2026
99c123a
Merge branch 'main' into mc/feat/stratif-patching
jdeschamps Feb 2, 2026
28e911f
docs: grammar and typo suggested fixes from PR
melisande-c Feb 3, 2026
598b5b4
fix: duplicated calc; docs: grammar fix
melisande-c Feb 3, 2026
16003cd
docs: additional comments and docs from PR comments
melisande-c Feb 3, 2026
abee581
docs: add comment structure overview from PR description
melisande-c Feb 3, 2026
0ad9919
Merge branch 'main' into mc/feat/stratif-patching
melisande-c Feb 3, 2026
6d0c2a9
Apply suggestions from code review
melisande-c Feb 4, 2026
58acc2b
Merge branch 'main' into mc/feat/stratif-patching
jdeschamps Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 274 additions & 0 deletions src/careamics/dataset_ng/demos/stratified_patching.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0",
"metadata": {},
"outputs": [],
"source": [
"from collections.abc import Sequence\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from numpy.typing import NDArray\n",
"\n",
"from careamics.dataset_ng.patching_strategies import (\n",
" PatchingStrategy,\n",
" RandomPatchingStrategy,\n",
" StratifiedPatchingStrategy,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1",
"metadata": {},
"source": [
"# Demoing the Stratified Patching Strategy"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"def demo_selected_patches(\n",
" patching_strategy: PatchingStrategy,\n",
" data_shapes: Sequence[Sequence[int]],\n",
" epochs: int,\n",
") -> Sequence[NDArray[np.int_]]:\n",
" \"\"\"Create a map where all the patches have been selected from.\n",
"\n",
" Every time a patch is selected that area is incremented by 1.\n",
" \"\"\"\n",
" tracking_arrays = [np.zeros(shape, dtype=int) for shape in data_shapes]\n",
" for _ in range(epochs):\n",
" for index in range(patching_strategy.n_patches):\n",
" patch_spec = patching_strategy.get_patch_spec(index)\n",
" data_idx = patch_spec[\"data_idx\"]\n",
" sample_idx = patch_spec[\"sample_idx\"]\n",
" coord = patch_spec[\"coords\"]\n",
" patch_size = patch_spec[\"patch_size\"]\n",
"\n",
" patch_slice = [\n",
" slice(c, c + ps) for c, ps in zip(coord, patch_size, strict=True)\n",
" ]\n",
" tracking_arrays[data_idx][sample_idx, ..., *patch_slice] += 1\n",
" return tracking_arrays"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"seed = 42"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"## Comparing the Stratified Patching Strategy to the Random Strategy"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"data_shapes = [(1, 1, 512, 620)]\n",
"patch_size = (64, 64)\n",
"\n",
"stratified_patching = StratifiedPatchingStrategy(data_shapes, patch_size, seed=42)\n",
"random_patching = RandomPatchingStrategy(data_shapes, patch_size, seed=42)\n",
"\n",
"epochs = 1\n",
"stratified_selected = demo_selected_patches(stratified_patching, data_shapes, epochs)\n",
"random_selected = demo_selected_patches(random_patching, data_shapes, epochs)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
"fig.suptitle(f\"Epochs: {epochs}\")\n",
"axes[0].imshow(stratified_selected[0][0, 0])\n",
"axes[0].set_title(\"Stratified Patching\")\n",
"axes[1].imshow(random_selected[0][0, 0])\n",
"axes[1].set_title(\"Random Patching\")\n",
"fig.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"epochs = 200\n",
"stratified_selected = demo_selected_patches(stratified_patching, data_shapes, epochs)\n",
"random_selected = demo_selected_patches(random_patching, data_shapes, epochs)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
"fig.suptitle(f\"Epochs: {epochs}\")\n",
"axes[0].imshow(stratified_selected[0][0, 0])\n",
"axes[0].set_title(\"Stratified Patching\")\n",
"axes[1].imshow(random_selected[0][0, 0])\n",
"axes[1].set_title(\"Random Patching\")\n",
"fig.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"mean = np.mean(random_selected[0]/epochs)\n",
"std = np.std(random_selected[0]/epochs)\n",
"print(\"--- Random Strategy ---\")\n",
"print(\"Expected value that a pixel is selected in an epoch\")\n",
"print(f\"Mean: {mean:.3f}, StdDev: {std:.3f}\")\n",
"print(\"\\n\")\n",
"\n",
"mean = np.mean(stratified_selected[0]/epochs)\n",
"std = np.std(stratified_selected[0]/epochs)\n",
"print(\"--- Stratified Strategy ---\")\n",
"print(\"Expected value that a pixel is selected in an epoch\")\n",
"print(f\"Mean: {mean:.3f}, StdDev: {std:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"centre_slice = [slice(ps, -ps) for ps in patch_size]\n",
"\n",
"mean = np.mean(random_selected[0][..., *centre_slice]/epochs)\n",
"std = np.std(random_selected[0][..., *centre_slice]/epochs)\n",
"print(\"--- Random Strategy ---\")\n",
"print(\"Expected value that a central pixel is selected in an epoch\")\n",
"print(f\"Mean: {mean:.3f}, StdDev: {std:.3f}\")\n",
"print(\"\\n\")\n",
"\n",
"mean = np.mean(stratified_selected[0][..., *centre_slice]/epochs)\n",
"std = np.std(stratified_selected[0][..., *centre_slice]/epochs)\n",
"print(\"--- Stratified Strategy ---\")\n",
"print(\"Expected value that a central pixel is selected in an epoch\")\n",
"print(f\"Mean: {mean:.3f}, StdDev: {std:.3f}\")"
]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
"source": [
"## Demo patch exclusion\n",
"\n",
"Excluded patches have to lie on the grid which has a grid point on (0, 0) and has a \n",
"spacing equal to the chosen patch size"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"# chose patches to exclude and make mask\n",
"\n",
"exclude_patches = [(3, 2), (5, 6), (4, 6), (2, 8)]\n",
"exlc_mask = np.zeros(data_shapes[0], dtype=bool)\n",
"for grid_coord in exclude_patches:\n",
" patch_slice = [\n",
" slice(c * ps, (c + 1) * ps)\n",
" for c, ps in zip(grid_coord, patch_size, strict=True)\n",
" ]\n",
" exlc_mask[..., *patch_slice] = True\n",
"plt.imshow(exlc_mask[0, 0])\n",
"plt.title(\"Excluded patches map\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"# exclude patches\n",
"\n",
"stratified_patching.exclude_patches(\n",
" data_idx=0, sample_idx=0, grid_coords=exclude_patches\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"# plot results\n",
"\n",
"stratified_1 = demo_selected_patches(stratified_patching, data_shapes, epochs=1)\n",
"stratified_200 = demo_selected_patches(stratified_patching, data_shapes, epochs=200)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
"fig.suptitle(\"Stratified Patching\")\n",
"axes[0].imshow(stratified_1[0][0, 0])\n",
"axes[0].set_title(\"Epochs: 1\")\n",
"axes[1].imshow(stratified_200[0][0, 0])\n",
"axes[1].set_title(\"Epochs: 200\")\n",
"fig.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {},
"outputs": [],
"source": [
"mean = np.mean(stratified_200[0][~exlc_mask]/200)\n",
"std = np.std(stratified_200[0][~exlc_mask]/200)\n",
"print(\"--- Stratified Strategy ---\")\n",
"print(\"Expected value that an included pixel is selected in an epoch\")\n",
"print(f\"Mean: {mean:.3f}, StdDev: {std:.3f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "careamics",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions src/careamics/dataset_ng/patching_strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"RandomPatchingStrategy",
"RegionSpecs",
"SequentialPatchingStrategy",
"StratifiedPatchingStrategy",
"TileSpecs",
"TilingStrategy",
"WholeSamplePatchingStrategy",
Expand All @@ -22,5 +23,6 @@
)
from .random_patching import FixedRandomPatchingStrategy, RandomPatchingStrategy
from .sequential_patching import SequentialPatchingStrategy
from .stratified_patching import StratifiedPatchingStrategy
from .tiling_strategy import TilingStrategy
from .whole_sample import WholeSamplePatchingStrategy
Loading