Commit 2831eca

Fix: 3D case for StructN2V + N2V2 with refac of median pixel manipulation (#767)
## Description

> [!NOTE]
> **tldr**: The function `median_manipulate_torch` assumed that patches are always 2D when applying the struct mask to calculate the median of a subpatch. This PR fixes that problem and refactors `median_manipulate_torch` so that smaller units of the code can be tested.

### Background - why do we need this PR?

N2V2 uses the median pixel in a subpatch to mask the central pixel. When applying StructN2V, the pixels excluded by the struct mask should also be excluded from the median calculation. In `median_manipulate_torch`, a struct mask is created to exclude these pixels from the median calculation, but it was only implemented for the 2D case. The lines linked below show where this assumption was made:

https://github.com/CAREamics/careamics/blob/fbcf24d5fe698033ff7fcbef992a437cf459ebd8/src/careamics/transforms/pixel_manipulation_torch.py#L345-L348

I also took the opportunity to refactor `median_manipulate_torch` into a few smaller functions so that these smaller units of code could be tested more thoroughly.

### Overview - what changed?

- Struct mask creation has been moved to the function `_create_struct_mask`, which can handle an arbitrary number of dimensions.
- Central pixel masking has been refactored to mirror struct masking, which makes the code clearer. The central pixel mask is now created in `_create_center_pixel_mask`.
- Most of the code in `median_manipulate_torch` created coordinates to extract the subpatches/ROIs from the patch in a vectorized way. This code has been refactored and moved to the function `_get_subpatch_coords`.

### Implementation - how did you implement the changes?

Subpatch coordinates are calculated in a similar way to the original implementation, but the new code uses torch broadcasting to add the subpatch centers to a meshgrid of coordinate offsets rather than iterating through the dimensions.

## Changes Made

### New features or files

- `_create_struct_mask`
- `_create_center_pixel_mask`
- `_get_subpatch_coords`

### Modified features or files

- `median_manipulate_torch`

## How has this been tested?

New tests have been added for the new functions. An additional parametrisation, the argument `apply_struct`, has been added to `test_median_manipulate_torch`. This only tests applying horizontal StructN2V, but for both 2D and 3D.

## Additional Notes and Examples

Sanity checked the output with this code:

```python
import numpy as np
import torch
import matplotlib.pyplot as plt

from careamics.transforms.pixel_manipulation_torch import median_manipulate_torch
from careamics.transforms.struct_mask_parameters import StructMaskParameters

shape = (2, 64, 64, 64)  # BZYX
array = torch.arange(np.prod(shape).item(), dtype=torch.float32).reshape(shape)
mask_pixel_percentage = 0.08

z = 16
b = 0
fig, axes = plt.subplots(2, 2, figsize=(8, 8), constrained_layout=True)
fig.suptitle(f"Batch {b} | z-slice {z}")
axes[0, 0].set_title("Mask")
axes[0, 1].set_title("Manipulated Patch")
for struct_axis in [0, 1]:
    manip_median, manip_median_mask = median_manipulate_torch(
        array,
        mask_pixel_percentage,
        struct_params=StructMaskParameters(axis=struct_axis, span=5),
        rng=torch.Generator(),
    )
    axes[struct_axis, 0].imshow(manip_median_mask[b, z])
    axes[struct_axis, 1].imshow(manip_median[b, z])
    axes[struct_axis, 0].set_ylabel(f"Struct Axis {struct_axis}")
```

<img width="799" height="811" alt="58ff62aa-3e73-4dcc-bd91-35cedbad49a2" src="https://github.com/user-attachments/assets/5e0b635f-da7f-4d7a-a66f-dc1af31f9379" />

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
1 parent cf912d5 commit 2831eca

2 files changed

Lines changed: 357 additions & 86 deletions

src/careamics/transforms/pixel_manipulation_torch.py

Lines changed: 170 additions & 80 deletions
@@ -291,99 +291,189 @@ def median_manipulate_torch(
     tuple[torch.Tensor, torch.Tensor, torch.Tensor]
         tuple containing the manipulated patch, the original patch and the mask.
     """
-    # get the coordinates of the future ROI centers
-    subpatch_center_coordinates = _get_stratified_coords_torch(
-        mask_pixel_percentage, batch.shape, rng
-    ).to(
-        device=batch.device
-    )  # (num_coordinates, batch + num_spatial_dims)
+    # -- Implementation summary
+    # 1. Generate coordinates that correspond to the pixels chosen for masking.
+    # 2. Subpatches are extracted, where the coordinate to mask is at the center.
+    # 3. The medians of these subpatches are calculated, but we do not want to include
+    # the original pixel in the calculation so we mask it. In the case of StructN2V,
+    # we do not include any pixels in the struct mask in the median calculation.

-    # Calculate the padding value for the input tensor
-    pad_value = subpatch_size // 2
+    if rng is None:
+        rng = torch.Generator(device=batch.device)

-    # Generate all offsets for the ROIs. Iteration starting from 1 to skip the batch
-    offsets = torch.meshgrid(
-        [
-            torch.arange(-pad_value, pad_value + 1, device=batch.device)
-            for _ in range(1, subpatch_center_coordinates.shape[1])
-        ],
-        indexing="ij",
+    # resulting center coord shape: (num_coordinates, batch + num_spatial_dims)
+    subpatch_center_coordinates = _get_stratified_coords_torch(
+        mask_pixel_percentage, batch.shape, rng
+    )
+    # pixel coordinates of all the subpatches
+    # shape: (num_coordinates, subpatch_size, subpatch_size, ...)
+    subpatch_coords = _get_subpatch_coords(
+        subpatch_center_coordinates, subpatch_size, batch.shape
     )
-    offsets = torch.stack(
-        [axis_offset.flatten() for axis_offset in offsets], dim=1
-    )  # (subpatch_size**2, num_spatial_dims)
-
-    # Create the list to assemble coordinates of the ROIs centers for each axis
-    coords_axes = []
-    # Create the list to assemble the span of coordinates defining the ROIs for each
-    # axis
-    coords_expands = []
-    for d in range(subpatch_center_coordinates.shape[1]):
-        coords_axes.append(subpatch_center_coordinates[:, d])
-        if d == 0:
-            # For batch dimension coordinates are not expanded (no offsets)
-            coords_expands.append(
-                subpatch_center_coordinates[:, d]
-                .unsqueeze(1)
-                .expand(-1, subpatch_size ** offsets.shape[1])
-            )  # (num_coordinates, subpatch_size**num_spacial_dims)
-        else:
-            # For spatial dimensions, coordinates are expanded with offsets, creating
-            # spans
-            coords_expands.append(
-                (
-                    subpatch_center_coordinates[:, d].unsqueeze(1) + offsets[:, d - 1]
-                ).clamp(0, batch.shape[d] - 1)
-            )  # (num_coordinates, subpatch_size**num_spacial_dims)
-
-    # create array of rois by indexing the batch with gathered coordinates
-    rois = batch[
-        tuple(coords_expands)
-    ]  # (num_coordinates, subpatch_size**num_spacial_dims)

-    if struct_params is not None:
-        # Create the structN2V mask
-        h, w = torch.meshgrid(
-            torch.arange(subpatch_size), torch.arange(subpatch_size), indexing="ij"
-        )
-        center_idx = subpatch_size // 2
-        halfspan = (struct_params.span - 1) // 2
-
-        # Determine the axis along which to apply the mask
-        if struct_params.axis == 0:
-            center_axis = h
-            span_axis = w
-        else:
-            center_axis = w
-            span_axis = h
-
-        # Create the mask
-        struct_mask = (
-            ~(
-                (center_axis == center_idx)
-                & (span_axis >= center_idx - halfspan)
-                & (span_axis <= center_idx + halfspan)
-            )
-        ).flatten()
-        rois_filtered = rois[:, struct_mask]
+    # this indexes and stacks all the subpatches along the first dimension
+    # subpatches shape: (num_coordinates, subpatch_size, subpatch_size, ...)
+    subpatches = batch[tuple(subpatch_coords)]
+
+    ndims = batch.ndim - 1
+    # subpatch mask to exclude values from median calculation
+    if struct_params is None:
+        subpatch_mask = _create_center_pixel_mask(ndims, subpatch_size, batch.device)
     else:
-        # Remove the center pixel value from the rois
-        center_idx = (subpatch_size ** offsets.shape[1]) // 2
-        rois_filtered = torch.cat(
-            [rois[:, :center_idx], rois[:, center_idx + 1 :]], dim=1
+        subpatch_mask = _create_struct_mask(
+            ndims, subpatch_size, struct_params, batch.device
         )
+    subpatches_masked = subpatches[:, subpatch_mask]

-    # compute the medians.
-    medians = rois_filtered.median(dim=1).values  # (num_coordinates,)
+    medians = subpatches_masked.median(dim=1).values  # (num_coordinates,)

     # Update the output tensor with medians
     output_batch = batch.clone()
-    output_batch[tuple(coords_axes)] = medians
-    mask = torch.where(output_batch != batch, 1, 0).to(torch.uint8)
+    output_batch[tuple(subpatch_center_coordinates.T)] = medians
+    mask = (batch != output_batch).to(torch.uint8)

     if struct_params is not None:
         output_batch = _apply_struct_mask_torch(
-            output_batch, subpatch_center_coordinates, struct_params
+            output_batch, subpatch_center_coordinates, struct_params, rng
         )

     return output_batch, mask
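
The refactored flow in `median_manipulate_torch` (pick centers, gather subpatches, mask, take the median, write back) can be traced on a toy 2D batch. The following is a minimal sketch under stated assumptions, not the CAREamics implementation: the single center is hand-picked instead of being drawn by `_get_stratified_coords_torch`, only the center-pixel mask is used, and all variable names are illustrative.

```python
import torch

# Toy batch: one 5x5 patch with distinct values, shape (B, Y, X).
batch = torch.arange(25, dtype=torch.float32).reshape(1, 5, 5)

# One hand-picked center (b, y, x); the real code samples these randomly.
centers = torch.tensor([[0, 2, 2]])
subpatch_size = 3
half = subpatch_size // 2

# Offsets of every subpatch pixel relative to its center, each of shape (S, S).
dy, dx = torch.meshgrid(
    torch.arange(-half, half + 1), torch.arange(-half, half + 1), indexing="ij"
)

# Broadcast centers (N, 1, 1) against offsets (S, S) to get (N, S, S) indices.
b = centers[:, 0, None, None].expand(-1, subpatch_size, subpatch_size)
y = (centers[:, 1, None, None] + dy).clamp(0, batch.shape[1] - 1)
x = (centers[:, 2, None, None] + dx).clamp(0, batch.shape[2] - 1)
subpatches = batch[b, y, x]  # (N, S, S)

# Boolean mask that is False at the center pixel and True elsewhere.
keep = torch.ones((subpatch_size, subpatch_size), dtype=torch.bool)
keep[half, half] = False

# Boolean indexing flattens the kept pixels to shape (N, S * S - 1).
medians = subpatches[:, keep].median(dim=1).values

# Write the medians back at the centers and record which pixels changed.
output = batch.clone()
output[tuple(centers.T)] = medians
mask = (batch != output).to(torch.uint8)
```

Because the center is excluded, the median is taken over an even number of values here; `torch.median` then returns the lower of the two middle values.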
+
+
+def _create_center_pixel_mask(
+    ndims: int, subpatch_size: int, device: torch.device
+) -> torch.Tensor:
+    """
+    Create a mask for the center pixel of a subpatch.
+
+    Parameters
+    ----------
+    ndims : int
+        The number of dimensions.
+    subpatch_size : int
+        The size of one dimension of the subpatch. The created mask must be the same
+        size as the subpatch. Cannot be an even number.
+    device : torch.device
+        Device to create the mask on, e.g. "cuda".
+
+    Returns
+    -------
+    torch.Tensor
+        Tensor of bools. False where pixels should be masked and True otherwise.
+    """
+    if subpatch_size % 2 == 0:
+        raise ValueError("`subpatch` size cannot be even.")
+    subpatch_shape = (subpatch_size,) * ndims
+    centre_idx = (subpatch_size // 2,) * ndims
+    cp_mask = torch.ones(subpatch_shape, dtype=torch.bool, device=device)
+    cp_mask[centre_idx] = False
+    return cp_mask
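
To illustrate that the N-dimensional center-pixel mask matches the removed approach of dropping flat index `(subpatch_size ** ndims) // 2` from a flattened ROI, here is a small stand-alone sketch. The function `center_pixel_mask` is a hypothetical CPU-only re-implementation written for this example; it is not imported from CAREamics.

```python
import torch


def center_pixel_mask(ndims: int, subpatch_size: int) -> torch.Tensor:
    """Hypothetical stand-alone version of the center-pixel mask (CPU only)."""
    if subpatch_size % 2 == 0:
        raise ValueError("subpatch_size must be odd")
    mask = torch.ones((subpatch_size,) * ndims, dtype=torch.bool)
    # A tuple index addresses exactly one element, whatever ndims is.
    mask[(subpatch_size // 2,) * ndims] = False
    return mask


mask_2d = center_pixel_mask(ndims=2, subpatch_size=3)
mask_3d = center_pixel_mask(ndims=3, subpatch_size=5)

# Position of the single masked pixel after flattening the mask.
flat_2d = (~mask_2d).flatten().nonzero().item()
flat_3d = (~mask_3d).flatten().nonzero().item()
```

For odd subpatch sizes the masked position in the flattened mask coincides with `(subpatch_size ** ndims) // 2`, so selecting with the boolean mask keeps the same pixels as the old concatenation around the center index.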
+
+
+def _create_struct_mask(
+    ndims: int,
+    subpatch_size: int,
+    struct_params: StructMaskParameters,
+    device: torch.device,
+) -> torch.Tensor:
+    """
+    Create the mask for StructN2V.
+
+    Parameters
+    ----------
+    ndims : int
+        The number of dimensions.
+    subpatch_size : int
+        The size of one dimension of the subpatch. The created mask must be the same
+        size as the subpatch. Cannot be an even number.
+    struct_params : StructMaskParameters
+        Parameters for the structN2V mask (axis and span).
+    device : torch.device
+        Device to create the mask on, e.g. "cuda".
+
+    Returns
+    -------
+    torch.Tensor
+        Tensor of bools. False where pixels should be masked and True otherwise.
+    """
+    if subpatch_size % 2 == 0:
+        raise ValueError("`subpatch` size cannot be even.")
+    center_idx = subpatch_size // 2
+    span_start = (subpatch_size - struct_params.span) // 2
+    span_end = subpatch_size - span_start  # symmetric
+    span_axis = ndims - 1 - struct_params.axis  # e.g. horizontal is the last axis
+
+    struct_mask = torch.ones((subpatch_size,) * ndims, dtype=torch.bool, device=device)
+    # indexes the center unless it is the axis on which the struct mask spans
+    struct_slice = (
+        center_idx if d != span_axis else slice(span_start, span_end)
+        for d in range(ndims)
+    )
+    struct_mask[*struct_slice] = False
+    return struct_mask
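
The axis mapping `span_axis = ndims - 1 - axis` is easiest to see on a concrete 3D subpatch. The sketch below is a hypothetical re-implementation that takes plain `axis` and `span` integers instead of a `StructMaskParameters` object and builds the index with `tuple(...)` rather than star-unpacking inside the subscript. It assumes an odd `span` no larger than `subpatch_size`; other values are not covered by this example.

```python
import torch


def struct_mask(ndims: int, subpatch_size: int, axis: int, span: int) -> torch.Tensor:
    """Hypothetical stand-alone struct mask; assumes odd span <= subpatch_size."""
    center = subpatch_size // 2
    start = (subpatch_size - span) // 2
    stop = subpatch_size - start
    span_axis = ndims - 1 - axis  # axis 0 maps to the last (x) dimension
    # Slice along the span axis, pin every other axis to the center index.
    index = tuple(
        slice(start, stop) if d == span_axis else center for d in range(ndims)
    )
    mask = torch.ones((subpatch_size,) * ndims, dtype=torch.bool)
    mask[index] = False
    return mask


horizontal = struct_mask(ndims=3, subpatch_size=5, axis=0, span=3)
vertical = struct_mask(ndims=3, subpatch_size=5, axis=1, span=3)

# (z, y, x) positions excluded from the median in each case.
excluded_horizontal = (~horizontal).nonzero().tolist()
excluded_vertical = (~vertical).nonzero().tolist()
```

In both cases the excluded pixels form a line of length `span` through the subpatch center, running along x for `axis=0` and along y for `axis=1`.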
+
+
+def _get_subpatch_coords(
+    subpatch_centers: torch.Tensor, subpatch_size: int, batch_shape: tuple[int, ...]
+) -> torch.Tensor:
+    """Get pixel coordinates for subpatches with centers at `subpatch_centers`.
+
+    The coordinates are returned in the shape `(D ,N, S, S)` or `(D, N, S, S, S)` for
+    2D and 3D patches respectively, where `D` is the number of dimension including the
+    batch dimension, `N` is the number of subpatches, and `S` is the
+    subpatch size. N is determined from the length of `subpatch_centres`.
+
+    If a subpatch would overlap the bounds of the patch, the coordinates are clipped.
+    This does result in some duplicated coordinates on the boundary of the patch.
+
+    Parameters
+    ----------
+    subpatch_centers : torch.Tensor
+        Coordinates of the center of a subpatch, including the batch dimension, i.e.
+        (b, (z), y, x). Has shape (N, D) for N different subpatch centers and D
+        dimensions.
+    subpatch_size : int
+        The size of one dimension of the subpatch. The created mask must be the same
+        size as the subpatch.
+    batch_shape : tuple[int, ...]
+        The shape of the batch that is being processed, i.e. (B ,(Z), Y, X).
+
+    Returns
+    -------
+    torch.Tensor
+        The coordinates of every pixel in each subpatch, stacked into the shape
+        `(D ,N, S, S)` or `(D, N, S, S, S)` for 2D and 3D patches respectively, where
+        `D` is the number of dimension including the batch dimension, `N` is the number
+        of subpatches, and `S` is the subpatch size.
+    """
+    device = subpatch_centers.device
+    ndims = len(batch_shape) - 1  # spatial dimensions
+
+    half_size = subpatch_size // 2
+    # pixel offset from the center of the subpatch, i.e. coords relative to the center
+    offsets = torch.meshgrid(
+        [torch.arange(-half_size, half_size + 1, device=device) for _ in range(ndims)],
+        indexing="ij",
+    )
+    # add zero offset for the batch dimension
+    subpatch_shape = (subpatch_size,) * ndims
+    offsets = torch.stack(
+        [torch.zeros(subpatch_shape, dtype=torch.int64, device=device), *offsets], dim=0
+    )
+
+    # now we need to add the offset to the subpatch_centers to get the subpatch coords
+    # subpatch_shape: (n_centres, ndims + 1)
+    # offset_shape: (ndims + 1, subpatch_size, subpatch_size, ...)
+    # we need to add singleton dims to broadcast the tensors
+    subpatch_centers = subpatch_centers[..., *(torch.newaxis for _ in range(ndims))]
+    offsets = offsets[torch.newaxis]
+
+    # resulting shape: (n_centres, ndims + 1, subpatch_size, subpatch_size, ...)
+    subpatch_coords = subpatch_centers + offsets
+    subpatch_coords = torch.swapaxes(subpatch_coords, 0, 1)
+    # clamp coordinates so they are not outside the bounds of the patch
+    broadcast_shape = (ndims + 1, *(1 for _ in range(ndims + 1)))
+    minimum = torch.zeros(broadcast_shape, dtype=torch.int64, device=device)
+    maximum = torch.tensor(batch_shape, device=device).reshape(broadcast_shape) - 1
+    subpatch_coords = subpatch_coords.clamp(minimum, maximum)
+    return subpatch_coords
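
The broadcasting step in `_get_subpatch_coords` can be checked on a small 3D batch. This is a sketch under stated assumptions rather than the library code: the two centers are hand-picked (one in a corner to trigger clamping), `reshape` and `None` are used to add singleton dimensions, and the variable names are illustrative.

```python
import math

import torch

batch_shape = (2, 4, 6, 6)  # (B, Z, Y, X)
batch = torch.arange(math.prod(batch_shape)).reshape(batch_shape)

# Two hand-picked centers (b, z, y, x); the first sits in a corner of the patch.
centers = torch.tensor([[0, 0, 0, 0], [1, 3, 5, 2]])  # (N, D) with D = 4
S = 3
ndims = len(batch_shape) - 1  # number of spatial dimensions
half = S // 2

# Relative offsets per spatial axis, plus an all-zero offset for the batch axis.
grids = torch.meshgrid(
    [torch.arange(-half, half + 1) for _ in range(ndims)], indexing="ij"
)
offsets = torch.stack([torch.zeros_like(grids[0]), *grids], dim=0)  # (D, S, S, S)

# (N, D, 1, 1, 1) + (1, D, S, S, S) broadcasts to (N, D, S, S, S).
coords = centers.reshape(*centers.shape, *([1] * ndims)) + offsets[None]
coords = coords.swapaxes(0, 1)  # (D, N, S, S, S)

# Clamp each coordinate axis to its own valid range [0, size - 1].
upper = torch.tensor(batch_shape).reshape(-1, *([1] * (ndims + 1))) - 1
coords = coords.clamp(torch.zeros_like(upper), upper)

# Indexing with one index tensor per batch axis gathers all subpatches at once.
subpatches = batch[tuple(coords)]  # (N, S, S, S)
```

Clamping means that subpatches touching the patch boundary repeat edge pixels instead of shrinking, which keeps every subpatch the same shape and lets the median be computed in a single vectorized call.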
