Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9021ff2
torch blur transform separable 1d convolution, slice and volume dimen…
starostka May 14, 2025
ee1e156
add Torch_Blur
Sllambias May 14, 2025
177f398
add: torch bias field
starostka May 15, 2025
36f91ea
change: narrow to 2d slice and 3d volume, ignore 2d w. channels
starostka May 15, 2025
c8a5cc3
add: torch gamma transform
starostka May 16, 2025
dba485b
add: torch motion ghosting
starostka May 16, 2025
3ba6e06
module loads
starostka May 16, 2025
d7fe86e
added: masking, noise, lowres sampling, ringing, spatial deformation
starostka May 22, 2025
0e32a53
Add Wrappers, format code and make minor label + device edits
Sllambias May 28, 2025
6e9ab38
same as before
Sllambias May 28, 2025
1a3c57e
fix notebook
Sllambias May 28, 2025
5c00e44
Merge branch 'main' into blur_transform
Sllambias May 30, 2025
68a8b54
fix formatting
Sllambias May 30, 2025
dce83a1
bump version and torchmetrics
Sllambias May 30, 2025
a35a612
.
Sllambias May 30, 2025
e57f10f
Remove faulty Dice implementation
Sllambias Jun 1, 2025
2294d60
formatting update
Sllambias Jun 2, 2025
7af64fb
3 x 1D separable gaussian kernels
starostka Jun 3, 2025
2979dab
permute coordinate component and spatial coordinates to match (x,y,z)…
starostka Jun 4, 2025
dbc5b97
formatting edits and add final xforms to notebook
Sllambias Jun 4, 2025
3d2bbb2
remove unused import
Sllambias Jun 4, 2025
eb6e2c3
visualization script for investigating pytorch implementations
starostka Jun 4, 2025
a9f4a94
update masking xform
Sllambias Jun 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "yucca"
version = "2.3.1"
version = "2.3.2"
authors = [
{ name="Sebastian Llambias", email="llambias@live.com" },
{ name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" },
Expand Down Expand Up @@ -35,7 +35,7 @@ dependencies = [
"SimpleITK>=2.3.1",
"tqdm>=4.66.2",
"timm>=0.9.8",
"torchmetrics==1.4.0.post0",
"torchmetrics>=1.4.0.post0",
"wandb>=0.16.3",
"weave>=0.39.0",
]
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 0 additions & 1 deletion yucca/documentation/templates/functional_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
if __name__ == "__main__":

import re
import os
import numpy as np
Expand Down
1 change: 0 additions & 1 deletion yucca/documentation/templates/functional_training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
if __name__ == "__main__":

import lightning as L
from yucca.pipeline.configuration.configure_task import TaskConfig
from yucca.pipeline.configuration.configure_paths import get_path_config
Expand Down
468 changes: 468 additions & 0 deletions yucca/documentation/tests/transforms/GPUaugmentations.ipynb

Large diffs are not rendered by default.

230 changes: 230 additions & 0 deletions yucca/documentation/tests/transforms/viz_torch_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
#! /usr/bin/env python3

# How to run:
# python viz_torch_transforms --datasets-dir /path/to/datasets --output /path/to/output.png
# python viz_torch_transforms --datasets-dir /path/to/datasets --output /path/to/output.png --headless
# Defaults: datasets dir is ~/datasets and output is ./viz_torch_transforms.png

import os
import glob
import nibabel as nib
import matplotlib.pyplot as plt
import argparse
import torch
import numpy as np

from yucca.functional.transforms.torch import (
torch_bias_field,
torch_blur,
torch_gamma,
torch_motion_ghosting,
torch_mask,
torch_additive_noise,
torch_multiplicative_noise,
torch_gibbs_ringing,
torch_simulate_lowres,
torch_spatial,
)

torch.manual_seed(1234)


def show_high_res(data, title):
"""Show a high-resolution view of the data with slice navigation."""
print(f"\nOpening high-resolution view for: {title}")
print(f"Data shape: {data.shape}")

fig, ax = plt.subplots(figsize=(10, 10))
fig.canvas.manager.set_window_title(f"High Resolution: {title}")
current_slice = data.shape[2] // 2

# Show initial image and colorbar
im = ax.imshow(data[:, :, current_slice], cmap="gray")
ax.set_title(f"{title}\nSlice {current_slice} of {data.shape[2]}", pad=20)
# cbar = plt.colorbar(im, ax=ax, label='Intensity')

def update_display():
im.set_data(data[:, :, current_slice])
ax.set_title(f"{title}\nSlice {current_slice} of {data.shape[2]}", pad=20)
# Optionally, update colorbar limits if data range changes
# im.set_clim(vmin=data.min(), vmax=data.max())
fig.canvas.draw_idle()

def on_key(event):
nonlocal current_slice
print(f"Key pressed: {event.key}")
if event.key == "up" and current_slice < data.shape[2] - 1:
current_slice += 1
print(f"Moving to slice {current_slice}")
update_display()
elif event.key == "down" and current_slice > 0:
current_slice -= 1
print(f"Moving to slice {current_slice}")
update_display()

fig.canvas.mpl_connect("key_press_event", on_key)
plt.show()


def on_click(event):
"""Handle click events on the main figure."""
print(f"\nClick event detected at: {event.xdata}, {event.ydata}")

if event.inaxes is None:
print("Click was outside axes")
return

ax = event.inaxes
print(f"Clicked on axes: {ax}")

for idx, (title, data) in enumerate(transforms.items()):
if ax == axes[idx // n_cols, idx % n_cols]:
print(f"Found matching subplot: {title}")
show_high_res(data.numpy(), title)
return

print("No matching subplot found")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Apply and visualize image transformations")
parser.add_argument("--headless", action="store_true", help="Run without displaying plots")
parser.add_argument(
"--datasets-dir",
type=str,
default=os.path.expanduser("~/datasets"),
help="Path to the datasets directory (default: ~/datasets)",
)
parser.add_argument(
"--output",
type=str,
default="viz_torch_transforms.png",
help="Output image file path (default: viz_torch_transforms.png)",
)
args = parser.parse_args()

nii_files = glob.glob(os.path.join(args.datasets_dir, "**", "*.nii.gz"), recursive=True)
nii_files.sort()

single_sample = nii_files[1]
im = nib.load(single_sample)
data = im.get_fdata()

print("\nImage Analysis:")
print("Data type:", data.dtype)
print("Shape:", data.shape)
print("Value range:", data.min(), "to", data.max())
print("Mean value:", data.mean())
print("Std value:", data.std())
print("Unique values count:", len(np.unique(data)))
print("First few unique values:", np.unique(data)[:10])

# Convert to float32 and scale to [0,1] while preserving relative intensities
imarr = data.astype(np.float32)
data_min = imarr.min()
data_max = imarr.max()
data_range = data_max - data_min

# Scale to [0,1] while preserving relative intensities
imarr = (imarr - data_min) / data_range
imarr = torch.from_numpy(imarr)
# Assume batch and channel dimensions are present
# imarr = imarr.unsqueeze(0).unsqueeze(0)

print("\nNormalized Image Analysis:")
print("Value range:", imarr.min().item(), "to", imarr.max().item())
print("Mean value:", imarr.mean().item())
print("Std value:", imarr.std().item())

transforms = {
"Original": imarr,
"Bias Field": torch_bias_field(imarr.clone(), clip_to_input_range=True),
"Blurred (σ=2.0)": torch_blur(imarr.clone(), sigma=2.0),
"Gamma (0.5–2.0)": torch_gamma(imarr.clone(), gamma_range=(0.5, 2.0), clip_to_input_range=True),
"Gamma (2.0)": torch_gamma(imarr, gamma_range=(2.0, 2.0), clip_to_input_range=True),
"Ghost (α=2, ax=0)": torch_motion_ghosting(imarr.clone(), alpha=2.0, num_reps=4, axis=0, clip_to_input_range=True),
"Ghost (α=2, ax=1)": torch_motion_ghosting(imarr.clone(), alpha=2.0, num_reps=4, axis=1, clip_to_input_range=True),
"Masked (r=0.3)": torch_mask(imarr.clone(), pixel_value=0, ratio=0.3, token_size=[16, 16, 16]),
"Add Noise (σ=0.05)": torch_additive_noise(imarr.clone(), mean=0, sigma=0.05, clip_to_input_range=True),
"Add Noise (σ=0.1)": torch_additive_noise(imarr.clone(), mean=0, sigma=0.1, clip_to_input_range=True),
"Add Noise (σ=0.2)": torch_additive_noise(imarr.clone(), mean=0, sigma=0.2, clip_to_input_range=True),
"Mult Noise (σ=0.05)": torch_multiplicative_noise(imarr.clone(), mean=0, sigma=0.05, clip_to_input_range=True),
"Mult Noise (σ=0.1)": torch_multiplicative_noise(imarr.clone(), mean=0, sigma=0.1, clip_to_input_range=True),
"Mult Noise (σ=0.2)": torch_multiplicative_noise(imarr.clone(), mean=0, sigma=0.2, clip_to_input_range=True),
"Gibbs (axes=0,1,2)": torch_gibbs_ringing(
imarr.clone(), num_sample=64, axes=[0, 1, 2], mode="rect", clip_to_input_range=True
),
"Low Res": torch_simulate_lowres(imarr.clone(), target_shape=(32, 32, 32), clip_to_input_range=True),
"Spatial": torch_spatial(
imarr.clone(),
patch_size=imarr.shape,
p_deform=1.0,
p_rot=1.0,
p_rot_per_axis=1.0,
p_scale=1.0,
alpha=10.0,
sigma=3.0,
x_rot=0.5,
y_rot=0,
z_rot=0,
scale_factor=1.0,
clip_to_input_range=True,
)[
0
], # Get only the image, not the label
}

# Get middle slice for visualization
slice_idx = imarr.shape[-1] // 2 # Changed to use last dimension

# Calculate grid dimensions
n_transforms = len(transforms)
n_cols = 6 # Show 6 images per row
n_rows = (n_transforms + n_cols - 1) // n_cols # Ceiling division

# Create a figure with subplots
fig_width = 2.2 * n_cols
fig_height = 2.2 * n_rows
plt.rcParams["figure.figsize"] = [fig_width, fig_height]
plt.rcParams["figure.dpi"] = 200 # Keep high DPI for quality
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)

# Plot each transformation
for idx, (title, transformed_data) in enumerate(transforms.items()):
row = idx // n_cols
col = idx % n_cols

# Plot transformed image
if isinstance(transformed_data, tuple):
transformed_data = transformed_data[0]
print(title, transformed_data.shape)
transformed_plot = transformed_data.numpy()
axes[row, col].imshow(transformed_plot[:, :, slice_idx], cmap="gray")
axes[row, col].set_title(title, pad=1, fontsize=7)
axes[row, col].axis("off")

# Hide any unused subplots
for idx in range(len(transforms), n_rows * n_cols):
row = idx // n_cols
col = idx % n_cols
axes[row, col].axis("off")
axes[row, col].set_visible(False)

# Use subplots_adjust for tight packing
plt.subplots_adjust(top=0.92, bottom=0.08, left=0.03, right=0.97, wspace=0.02, hspace=0.15)

# Set window title
fig.canvas.manager.set_window_title("Image Transformations")

# Save the figure
plt.savefig(args.output, dpi=200, bbox_inches="tight")
print(f"\nSaved visualization to: {args.output}")

# Connect the click event
print("\nConnecting click event handler...")
fig.canvas.mpl_connect("button_press_event", on_click)
print("Click event handler connected. Click any image to view in high resolution.")

# Only show the plot if not in headless mode
if not args.headless:
plt.show()
10 changes: 6 additions & 4 deletions yucca/functional/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,12 @@ def preprocess_case_for_training_with_label(

if final_target_size is not None:
images, label = pad_case_to_size(case=images, size=final_target_size, label=label)
image_properties["foreground_locations"], image_properties["label_cc_n"], image_properties["label_cc_sizes"] = (
analyze_label(
label=label, enable_connected_components_analysis=enable_cc_analysis, per_class=foreground_locs_per_label
)
(
image_properties["foreground_locations"],
image_properties["label_cc_n"],
image_properties["label_cc_sizes"],
) = analyze_label(
label=label, enable_connected_components_analysis=enable_cc_analysis, per_class=foreground_locs_per_label
)

first_existing_modality = list(set(range(len(images))).difference(missing_modality_idxs))[0]
Expand Down
8 changes: 8 additions & 0 deletions yucca/functional/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@
from yucca.functional.transforms.masking import mask_batch
from yucca.functional.transforms.spatial import spatial
from yucca.functional.transforms.skeleton import skeleton
from yucca.functional.transforms.torch.blur import torch_blur
from yucca.functional.transforms.torch.bias_field import torch_bias_field
from yucca.functional.transforms.torch.gamma import torch_gamma
from yucca.functional.transforms.torch.motion_ghosting import torch_motion_ghosting
from yucca.functional.transforms.torch.noise import torch_additive_noise, torch_multiplicative_noise
from yucca.functional.transforms.torch.ringing import torch_gibbs_ringing
from yucca.functional.transforms.torch.sampling import torch_simulate_lowres
from yucca.functional.transforms.torch.spatial import torch_spatial
1 change: 0 additions & 1 deletion yucca/functional/transforms/croppad.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def croppad(
label: np.ndarray = None,
**pad_kwargs,
):

if len(patch_size) == 3:
image, label = croppad_3D_case_from_3D(
image=image,
Expand Down
1 change: 0 additions & 1 deletion yucca/functional/transforms/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def spatial(
# Mapping the images to the distorted coordinates
for b in range(image.shape[0]):
for c in range(image.shape[1]):

img_min = image.min()
img_max = image.max()

Expand Down
9 changes: 9 additions & 0 deletions yucca/functional/transforms/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
from .croppad import torch_croppad
from .bias_field import torch_bias_field
from .blur import torch_blur
from .gamma import torch_gamma
from .motion_ghosting import torch_motion_ghosting
from .masking import torch_mask
from .noise import torch_additive_noise, torch_multiplicative_noise
from .ringing import torch_gibbs_ringing
from .sampling import torch_simulate_lowres
from .spatial import torch_spatial
37 changes: 37 additions & 0 deletions yucca/functional/transforms/torch/bias_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch


def torch_bias_field(image: torch.Tensor, clip_to_input_range: bool = False) -> torch.Tensor:
device = image.device
img_min = image.min()
img_max = image.max()

if len(image.shape) == 3:
assert image.ndim == 3, "Expected [H, W, D] tensor"
x, y, z = image.shape
X, Y, Z = torch.meshgrid(
torch.linspace(0, x - 1, x, device=device),
torch.linspace(0, y - 1, y, device=device),
torch.linspace(0, z - 1, z, device=device),
indexing="ij",
)
x0 = torch.randint(0, x, (1,), device=device)
y0 = torch.randint(0, y, (1,), device=device)
z0 = torch.randint(0, z, (1,), device=device)
G = 1 - ((X - x0) ** 2 / (x**2) + (Y - y0) ** 2 / (y**2) + (Z - z0) ** 2 / (z**2))
else:
assert image.ndim == 2, "Expected [H, W] tensor"

x, y = image.shape
X, Y = torch.meshgrid(
torch.linspace(0, x - 1, x, device=device), torch.linspace(0, y - 1, y, device=device), indexing="ij"
)
x0 = torch.randint(0, x, (1,), device=device)
y0 = torch.randint(0, y, (1,), device=device)
G = 1 - ((X - x0) ** 2 / (x**2) + (Y - y0) ** 2 / (y**2))
image = G * image

if clip_to_input_range:
image = torch.clamp(image, min=img_min, max=img_max)

return image
Loading
Loading