Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9021ff2
torch blur transform separable 1d convolution, slice and volume dimen…
starostka May 14, 2025
ee1e156
add Torch_Blur
Sllambias May 14, 2025
177f398
add: torch bias field
starostka May 15, 2025
36f91ea
change: narrow to 2d slice and 3d volume, ignore 2d w. channels
starostka May 15, 2025
c8a5cc3
add: torch gamma transform
starostka May 16, 2025
dba485b
add: torch motion ghosting
starostka May 16, 2025
3ba6e06
module loads
starostka May 16, 2025
d7fe86e
added: masking, noise, lowres sampling, ringing, spatial deformation
starostka May 22, 2025
0e32a53
Add Wrappers, format code and make minor label + device edits
Sllambias May 28, 2025
6e9ab38
same as before
Sllambias May 28, 2025
1a3c57e
fix notebook
Sllambias May 28, 2025
5c00e44
Merge branch 'main' into blur_transform
Sllambias May 30, 2025
68a8b54
fix formatting
Sllambias May 30, 2025
dce83a1
bump version and torchmetrics
Sllambias May 30, 2025
a35a612
.
Sllambias May 30, 2025
e57f10f
Remove faulty Dice implementation
Sllambias Jun 1, 2025
2294d60
formatting update
Sllambias Jun 2, 2025
7af64fb
3 x 1D separable gaussian kernels
starostka Jun 3, 2025
2979dab
permute coordinate component and spatial coordinates to match (x,y,z)…
starostka Jun 4, 2025
dbc5b97
formatting edits and add final xforms to notebook
Sllambias Jun 4, 2025
3d2bbb2
remove unused import
Sllambias Jun 4, 2025
eb6e2c3
visualization script for investigating pytorch implementations
starostka Jun 4, 2025
a9f4a94
update masking xform
Sllambias Jun 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "yucca"
version = "2.3.1"
version = "2.3.2"
authors = [
{ name="Sebastian Llambias", email="llambias@live.com" },
{ name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" },
Expand Down Expand Up @@ -35,7 +35,7 @@ dependencies = [
"SimpleITK>=2.3.1",
"tqdm>=4.66.2",
"timm>=0.9.8",
"torchmetrics==1.4.0.post0",
"torchmetrics>=1.4.0.post0",
"wandb>=0.16.3",
"weave>=0.39.0",
]
Expand Down
1 change: 0 additions & 1 deletion yucca/documentation/templates/functional_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
if __name__ == "__main__":

import re
import os
import numpy as np
Expand Down
1 change: 0 additions & 1 deletion yucca/documentation/templates/functional_training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
if __name__ == "__main__":

import lightning as L
from yucca.pipeline.configuration.configure_task import TaskConfig
from yucca.pipeline.configuration.configure_paths import get_path_config
Expand Down
468 changes: 468 additions & 0 deletions yucca/documentation/tests/transforms/GPUaugmentations.ipynb

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions yucca/functional/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,12 @@ def preprocess_case_for_training_with_label(

if final_target_size is not None:
images, label = pad_case_to_size(case=images, size=final_target_size, label=label)
image_properties["foreground_locations"], image_properties["label_cc_n"], image_properties["label_cc_sizes"] = (
analyze_label(
label=label, enable_connected_components_analysis=enable_cc_analysis, per_class=foreground_locs_per_label
)
(
image_properties["foreground_locations"],
image_properties["label_cc_n"],
image_properties["label_cc_sizes"],
) = analyze_label(
label=label, enable_connected_components_analysis=enable_cc_analysis, per_class=foreground_locs_per_label
)

first_existing_modality = list(set(range(len(images))).difference(missing_modality_idxs))[0]
Expand Down
8 changes: 8 additions & 0 deletions yucca/functional/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@
from yucca.functional.transforms.masking import mask_batch
from yucca.functional.transforms.spatial import spatial
from yucca.functional.transforms.skeleton import skeleton
from yucca.functional.transforms.torch.blur import torch_blur
from yucca.functional.transforms.torch.bias_field import torch_bias_field
from yucca.functional.transforms.torch.gamma import torch_gamma
from yucca.functional.transforms.torch.motion_ghosting import torch_motion_ghosting
from yucca.functional.transforms.torch.noise import torch_additive_noise, torch_multiplicative_noise
from yucca.functional.transforms.torch.ringing import torch_gibbs_ringing
from yucca.functional.transforms.torch.sampling import torch_simulate_lowres
from yucca.functional.transforms.torch.spatial import torch_spatial
1 change: 0 additions & 1 deletion yucca/functional/transforms/croppad.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def croppad(
label: np.ndarray = None,
**pad_kwargs,
):

if len(patch_size) == 3:
image, label = croppad_3D_case_from_3D(
image=image,
Expand Down
1 change: 0 additions & 1 deletion yucca/functional/transforms/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def spatial(
# Mapping the images to the distorted coordinates
for b in range(image.shape[0]):
for c in range(image.shape[1]):

img_min = image.min()
img_max = image.max()

Expand Down
9 changes: 9 additions & 0 deletions yucca/functional/transforms/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
from .croppad import torch_croppad
from .bias_field import torch_bias_field
from .blur import torch_blur
from .gamma import torch_gamma
from .motion_ghosting import torch_motion_ghosting
from .masking import torch_mask
from .noise import torch_additive_noise, torch_multiplicative_noise
from .ringing import torch_gibbs_ringing
from .sampling import torch_simulate_lowres
from .spatial import torch_spatial
37 changes: 37 additions & 0 deletions yucca/functional/transforms/torch/bias_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch


def torch_bias_field(image: torch.Tensor, clip_to_input_range: bool = False) -> torch.Tensor:
device = image.device
img_min = image.min()
img_max = image.max()

if len(image.shape) == 3:
assert image.ndim == 3, "Expected [H, W, D] tensor"
x, y, z = image.shape
X, Y, Z = torch.meshgrid(
torch.linspace(0, x - 1, x, device=device),
torch.linspace(0, y - 1, y, device=device),
torch.linspace(0, z - 1, z, device=device),
indexing="ij",
)
x0 = torch.randint(0, x, (1,), device=device)
y0 = torch.randint(0, y, (1,), device=device)
z0 = torch.randint(0, z, (1,), device=device)
G = 1 - ((X - x0) ** 2 / (x**2) + (Y - y0) ** 2 / (y**2) + (Z - z0) ** 2 / (z**2))
else:
assert image.ndim == 2, "Expected [H, W] tensor"

x, y = image.shape
X, Y = torch.meshgrid(
torch.linspace(0, x - 1, x, device=device), torch.linspace(0, y - 1, y, device=device), indexing="ij"
)
x0 = torch.randint(0, x, (1,), device=device)
y0 = torch.randint(0, y, (1,), device=device)
G = 1 - ((X - x0) ** 2 / (x**2) + (Y - y0) ** 2 / (y**2))
image = G * image

if clip_to_input_range:
image = torch.clamp(image, min=img_min, max=img_max)

return image
56 changes: 56 additions & 0 deletions yucca/functional/transforms/torch/blur.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
import torchvision.transforms.functional as TF


def torch_blur(image: torch.tensor, sigma: float, clip_to_input_range: bool = True) -> torch.tensor:
img_min = image.min()
img_max = image.max()

if image.ndim == 2:
image = blur_2D(image, sigma)
elif image.ndim == 3:
image = blur_3D(image, sigma)
else:
raise ValueError(f"Unsupported image shape for blur: {image.shape}")

if clip_to_input_range:
image = torch.clamp(image, min=img_min, max=img_max)
return image


def blur_2D(image: torch.tensor, sigma: float) -> torch.tensor:
assert image.ndim == 2, "Expected [H, W] tensor"

image = image.unsqueeze(0).unsqueeze(0)
image = TF.gaussian_blur(image, kernel_size=int(2 * round(3 * sigma) + 1), sigma=sigma)
return image.squeeze(0).squeeze(0)


def blur_3D(image: torch.Tensor, sigma: float) -> torch.Tensor:
assert image.ndim == 3, "Expected [D, H, W] tensor"

kernel_size = int(2 * round(3 * sigma) + 1)
if kernel_size % 2 == 0:
kernel_size += 1

# : filter single-channel, [out_channels, in_channels, kernel_size]
coords = torch.arange(kernel_size) - kernel_size // 2
kernel = torch.exp(-(coords.float() ** 2) / (2 * sigma**2))
kernel = kernel / kernel.sum()
kernel = kernel.view(1, 1, -1)

# the permutation ordering is moving the target spatial dimension slice to the last dimension since
# this is where conv1d applies the filter.
volume = image
for axis in range(3):
permute_order = list(range(3))
permute_order[axis], permute_order[-1] = permute_order[-1], permute_order[axis]
volume = volume.permute(permute_order).contiguous()

shape = volume.shape
volume = volume.view(-1, 1, shape[-1])
volume = torch.nn.functional.conv1d(volume, kernel.to(volume.device), padding=kernel.shape[-1] // 2)
volume = volume.view(shape)
volume = volume.permute(permute_order).contiguous()

return volume
12 changes: 6 additions & 6 deletions yucca/functional/transforms/torch/croppad.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def croppad_3D_case_from_3D(
target_label_shape,
**pad_kwargs,
):
image_out = torch.zeros(target_image_shape)
label_out = torch.zeros(target_label_shape)
image_out = torch.zeros(target_image_shape, device=image.device)
label_out = torch.zeros(target_label_shape, device=image.device)

# First we pad to ensure min size is met
to_pad = []
Expand Down Expand Up @@ -175,8 +175,8 @@ def croppad_2D_case_from_3D(
For 3D we want to first select a slice from the first dimension, i.e. volume[idx, :, :],
then pad or crop as necessary.
"""
image_out = torch.zeros(target_image_shape)
label_out = torch.zeros(target_label_shape)
image_out = torch.zeros(target_image_shape, device=image.device)
label_out = torch.zeros(target_label_shape, device=image.device)

# First we pad to ensure min size is met
to_pad = []
Expand Down Expand Up @@ -273,8 +273,8 @@ def croppad_2D_case_from_2D(
For 3D we want to first select a slice from the first dimension, i.e. volume[idx, :, :],
then pad or crop as necessary.
"""
image_out = torch.zeros(target_image_shape)
label_out = torch.zeros(target_label_shape)
image_out = torch.zeros(target_image_shape, device=image.device)
label_out = torch.zeros(target_label_shape, device=image.device)

# First we pad to ensure min size is met
to_pad = []
Expand Down
45 changes: 45 additions & 0 deletions yucca/functional/transforms/torch/gamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch


def torch_gamma(
image: torch.Tensor,
gamma_range=(0.5, 2),
invert_image=False,
epsilon=1e-7,
per_channel=False,
clip_to_input_range=False,
) -> torch.Tensor:
if invert_image:
image = -image

if not per_channel:
if torch.rand(1).item() < 0.5 and gamma_range[0] < 1:
gamma = torch.rand(1).item() * (1 - gamma_range[0]) + gamma_range[0]
else:
gamma = torch.rand(1).item() * (gamma_range[1] - max(gamma_range[0], 1)) + max(gamma_range[0], 1)

img_min = image.min()
img_max = image.max()
img_range = img_max - img_min

image = torch.pow(((image - img_min) / (img_range + epsilon)), gamma) * img_range + img_min

if clip_to_input_range:
image = torch.clamp(image, min=img_min, max=img_max)
else:
for c in range(image.shape[0]):
if torch.rand(1).item() < 0.5 and gamma_range[0] < 1:
gamma = torch.rand(1).item() * (1 - gamma_range[0]) + gamma_range[0]
else:
gamma = torch.rand(1).item() * (gamma_range[1] - max(gamma_range[0], 1)) + max(gamma_range[0], 1)

img_min = image[c].min()
img_max = image[c].max()
img_range = img_max - img_min

image[c] = torch.pow(((image[c] - img_min) / (img_range + epsilon)), gamma) * (img_range + epsilon) + img_min

if clip_to_input_range:
image[c] = torch.clamp(image[c], min=img_min, max=img_max)

return image
36 changes: 36 additions & 0 deletions yucca/functional/transforms/torch/masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch


def torch_mask(image: torch.Tensor, pixel_value: float, ratio: float, token_size: list[int]) -> torch.Tensor:
"""
We need to mask image over all channels thus input should be 4d tensor of shape (c, x, y, z) or 3d tensor of shape (c, x, y)
"""

input_shape = image.shape[1:] # spatial dims

if len(token_size) == 1:
token_size *= len(input_shape)
assert len(input_shape) == len(
token_size
), f"mask token size not compatible with input data — token: {token_size}, image shape: {input_shape}"

input_shape_tensor = image.new_tensor(input_shape, dtype=torch.int)
token_size_tensor = image.new_tensor(token_size, dtype=torch.int)
grid_dims = torch.ceil(input_shape_tensor / token_size_tensor).to(dtype=torch.int)
grid_size = torch.prod(grid_dims).item()

grid_flat = image.new_ones(grid_size)

grid_flat[: int(grid_size * ratio)] = 0
grid_flat = grid_flat[torch.randperm(grid_size, device=image.device)]

grid = grid_flat.view(*grid_dims)

for dim, size in enumerate(token_size):
grid = grid.repeat_interleave(size, dim=dim)

slices = tuple(slice(0, s) for s in input_shape)
mask = grid[slices]

image[:, mask == 0] = pixel_value
return image, mask
43 changes: 43 additions & 0 deletions yucca/functional/transforms/torch/motion_ghosting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch


def torch_motion_ghosting(
image: torch.Tensor,
alpha: float,
num_reps: int,
axis: int,
clip_to_input_range: bool = False,
) -> torch.Tensor:
img_min = image.min()
img_max = image.max()
m = min(0, img_min.item())
image = image + abs(m)

if image.ndim == 3:
image = torch.fft.fftn(image, dim=[0, 1, 2])

if axis == 0:
image[0:-1:num_reps, :, :] = alpha * image[0:-1:num_reps, :, :]
elif axis == 1:
image[:, 0:-1:num_reps, :] = alpha * image[:, 0:-1:num_reps, :]
else:
image[:, :, 0:-1:num_reps] = alpha * image[:, :, 0:-1:num_reps]

image = torch.abs(torch.fft.ifftn(image, dim=[0, 1, 2]))

elif image.ndim == 2:
image = torch.fft.fftn(image, dim=[0, 1])

if axis == 0:
image[0:-1:num_reps, :] = alpha * image[0:-1:num_reps, :]
else:
image[:, 0:-1:num_reps] = alpha * image[:, 0:-1:num_reps]

image = torch.abs(torch.fft.ifftn(image, dim=[0, 1]))

image = image - abs(m)

if clip_to_input_range:
image = torch.clamp(image, min=img_min, max=img_max)

return image
22 changes: 22 additions & 0 deletions yucca/functional/transforms/torch/noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch


def torch_additive_noise(image, mean, sigma, clip_to_input_range: bool = False):
# J = I+n
img_min = image.min()
img_max = image.max()
image = image + torch.normal(mean, sigma, image.shape, device=image.device)
if clip_to_input_range:
image = torch.clamp(image, min=img_min, max=img_max)
return image


def torch_multiplicative_noise(image, mean, sigma, clip_to_input_range: bool = False):
# J = I + I*n
img_min = image.min()
img_max = image.max()
gauss = torch.normal(mean, sigma, image.shape, device=image.device)
image = image + image * gauss
if clip_to_input_range:
image = torch.clamp(image, min=img_min, max=img_max)
return image
Loading