Commit 2831eca
authored
Fix: 3D case for StructN2V + N2V2 with refac of median pixel manipulation (#767)
## Description
> [!NOTE]
> **tldr**: The function `median_manipulate_torch` made an assumption
that patches are always 2D when applying the struct mask to calculate
the median of a subpatch. This PR fixes that problem and refactors the
`median_manipulate_torch` function so that smaller units of the code can
be tested.
### Background - why do we need this PR?
N2V2 uses the median pixel in a subpatch to mask the central pixel. When
applying StructN2V, the pixels excluded by the struct mask should also
not be included in the median calculation. In `median_manipulate_torch`
the struct mask is created to exclude the pixels in the median
calculation but it was only implemented to consider the 2D case, the
line below shows where this assumption happened:
https://github.com/CAREamics/careamics/blob/fbcf24d5fe698033ff7fcbef992a437cf459ebd8/src/careamics/transforms/pixel_manipulation_torch.py#L345-L348
I also took the opportunity to refactor `median_manipulate_torch` into a
few smaller functions so that these smaller units of code could be
tested more thoroughly.
### Overview - what changed?
Struct mask creation has been moved to the function
`_create_struct_mask`, it can handle arbitrary number of dimension.
Central pixel masking has been refactored to mirror struct masking to
make the code clearer and the central pixel mask creation happens in
`_create_center_pixel_mask`.
Most of the code in `median_manipulate_torch` was to create coordinates
to extract the subpatches/rois from the patch in a vectorized way. This
code has been refactored and moved to the function
`_get_subpatch_coords`.
### Implementation - how did you implement the changes?
Subpatch coordinates are calculated in a similar way to the original
implementation but uses torch broacasting to add the subpatch center to
a meshgrid of coordinates rather than iterating through the dimensions.
<!-- How did you solve the issue technically? Explain why you chose this
approach and
provide code examples if applicable (e.g. change in the API for users).
-->
## Changes Made
### New features or files
- `_create_struct_mask`
- `_create_center_pixel_mask`
- `_get_subpatch_coords`
### Modified features or files
- `median_manipulate_torch`
## How has this been tested?
New tests for the new functions.
Added an additional parametrisation to `test_median_manipulate_torch`,
which is the argument `apply_struct`. This only tests applying
horizontal StructN2V, but for both 2D and 2D.
## Additional Notes and Examples
Sanity checked the output with this code:
```python
import numpy as np
import torch
import matplotlib.pyplot as plt
from careamics.transforms.pixel_manipulation_torch import median_manipulate_torch
from careamics.transforms.struct_mask_parameters import StructMaskParameters
shape = (2, 64, 64, 64) # BZYX
array = torch.arange(np.prod(shape).item(), dtype=torch.float32).reshape(shape)
mask_pixel_percentage = 0.08
z = 16
b = 0
fig, axes = plt.subplots(2, 2, figsize=(8, 8), constrained_layout=True)
fig.suptitle(f"Batch {b} | z-slice {z}")
axes[0, 0].set_title("Mask")
axes[0, 1].set_title("Manipulated Patch")
for struct_axis in [0, 1]:
manip_median, manip_median_mask = median_manipulate_torch(
array,
mask_pixel_percentage,
struct_params=StructMaskParameters(axis=struct_axis, span=5),
rng=torch.Generator(),
)
axes[struct_axis, 0].imshow(manip_median_mask[b, z])
axes[struct_axis, 1].imshow(manip_median[b, z])
axes[struct_axis, 0].set_ylabel(f"Struct Axis {struct_axis}")
```
<img width="799" height="811" alt="58ff62aa-3e73-4dcc-bd91-35cedbad49a2"
src="https://github.com/user-attachments/assets/5e0b635f-da7f-4d7a-a66f-dc1af31f9379"
/>
---
**Please ensure your PR meets the following requirements:**
- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)1 parent cf912d5 commit 2831eca
2 files changed
Lines changed: 357 additions & 86 deletions
File tree
- src/careamics/transforms
- tests/transforms
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
291 | 291 | | |
292 | 292 | | |
293 | 293 | | |
294 | | - | |
295 | | - | |
296 | | - | |
297 | | - | |
298 | | - | |
299 | | - | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
300 | 300 | | |
301 | | - | |
302 | | - | |
| 301 | + | |
| 302 | + | |
303 | 303 | | |
304 | | - | |
305 | | - | |
306 | | - | |
307 | | - | |
308 | | - | |
309 | | - | |
310 | | - | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
311 | 312 | | |
312 | | - | |
313 | | - | |
314 | | - | |
315 | | - | |
316 | | - | |
317 | | - | |
318 | | - | |
319 | | - | |
320 | | - | |
321 | | - | |
322 | | - | |
323 | | - | |
324 | | - | |
325 | | - | |
326 | | - | |
327 | | - | |
328 | | - | |
329 | | - | |
330 | | - | |
331 | | - | |
332 | | - | |
333 | | - | |
334 | | - | |
335 | | - | |
336 | | - | |
337 | | - | |
338 | | - | |
339 | | - | |
340 | | - | |
341 | | - | |
342 | | - | |
343 | 313 | | |
344 | | - | |
345 | | - | |
346 | | - | |
347 | | - | |
348 | | - | |
349 | | - | |
350 | | - | |
351 | | - | |
352 | | - | |
353 | | - | |
354 | | - | |
355 | | - | |
356 | | - | |
357 | | - | |
358 | | - | |
359 | | - | |
360 | | - | |
361 | | - | |
362 | | - | |
363 | | - | |
364 | | - | |
365 | | - | |
366 | | - | |
367 | | - | |
368 | | - | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
369 | 322 | | |
370 | | - | |
371 | | - | |
372 | | - | |
373 | | - | |
| 323 | + | |
| 324 | + | |
374 | 325 | | |
| 326 | + | |
375 | 327 | | |
376 | | - | |
377 | | - | |
| 328 | + | |
378 | 329 | | |
379 | 330 | | |
380 | 331 | | |
381 | | - | |
382 | | - | |
| 332 | + | |
| 333 | + | |
383 | 334 | | |
384 | 335 | | |
385 | 336 | | |
386 | | - | |
| 337 | + | |
387 | 338 | | |
388 | 339 | | |
389 | 340 | | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
0 commit comments