Skip to content

Commit d4c0623

Browse files
starostkaSllambias
andauthored
PyTorch Transformations (#215)
* torch blur transform separable 1d convolution, slice and volume dimensions * add Torch_Blur * add: torch bias field * change: narrow to 2d slice and 3d volume, ignore 2d w. channels * add: torch gamma transform * add: torch motion ghosting * module loads * added: masking, noise, lowres sampling, ringing, spatial deformation * Add Wrappers, format code and make minor label + device edits * same as before * fix notebook * fix formatting * bump version and torchmetrics * . * Remove faulty Dice implementation * formatting update * 3 x 1D separable gaussian kernels * permute coordinate component and spatial coordinates to match (x,y,z) as expected by torch grid_sample * formatting edits and add final xforms to notebook * remove unused import * visualization script for investigating pytorch implementations * update masking xform --------- Co-authored-by: LlambiasMBP <llambias@live.com> Co-authored-by: zcr545 <snl@di.ku.dk>
1 parent 9807abf commit d4c0623

35 files changed

Lines changed: 1714 additions & 50 deletions

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "yucca"
3-
version = "2.3.1"
3+
version = "2.3.2"
44
authors = [
55
{ name="Sebastian Llambias", email="llambias@live.com" },
66
{ name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" },
@@ -35,7 +35,7 @@ dependencies = [
3535
"SimpleITK>=2.3.1",
3636
"tqdm>=4.66.2",
3737
"timm>=0.9.8",
38-
"torchmetrics==1.4.0.post0",
38+
"torchmetrics>=1.4.0.post0",
3939
"wandb>=0.16.3",
4040
"weave>=0.39.0",
4141
]
1.16 MB
Loading

yucca/documentation/templates/functional_preprocessing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
if __name__ == "__main__":
2-
32
import re
43
import os
54
import numpy as np

yucca/documentation/templates/functional_training.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
if __name__ == "__main__":
2-
32
import lightning as L
43
from yucca.pipeline.configuration.configure_task import TaskConfig
54
from yucca.pipeline.configuration.configure_paths import get_path_config

yucca/documentation/tests/transforms/GPUaugmentations.ipynb

Lines changed: 468 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
#! /usr/bin/env python3
2+
3+
# How to run:
4+
# python viz_torch_transforms --datasets-dir /path/to/datasets --output /path/to/output.png
5+
# python viz_torch_transforms --datasets-dir /path/to/datasets --output /path/to/output.png --headless
6+
# Defaults: datasets dir is ~/datasets and output is ./viz_torch_transforms.png
7+
8+
import os
9+
import glob
10+
import nibabel as nib
11+
import matplotlib.pyplot as plt
12+
import argparse
13+
import torch
14+
import numpy as np
15+
16+
from yucca.functional.transforms.torch import (
17+
torch_bias_field,
18+
torch_blur,
19+
torch_gamma,
20+
torch_motion_ghosting,
21+
torch_mask,
22+
torch_additive_noise,
23+
torch_multiplicative_noise,
24+
torch_gibbs_ringing,
25+
torch_simulate_lowres,
26+
torch_spatial,
27+
)
28+
29+
torch.manual_seed(1234)
30+
31+
32+
def show_high_res(data, title):
33+
"""Show a high-resolution view of the data with slice navigation."""
34+
print(f"\nOpening high-resolution view for: {title}")
35+
print(f"Data shape: {data.shape}")
36+
37+
fig, ax = plt.subplots(figsize=(10, 10))
38+
fig.canvas.manager.set_window_title(f"High Resolution: {title}")
39+
current_slice = data.shape[2] // 2
40+
41+
# Show initial image and colorbar
42+
im = ax.imshow(data[:, :, current_slice], cmap="gray")
43+
ax.set_title(f"{title}\nSlice {current_slice} of {data.shape[2]}", pad=20)
44+
# cbar = plt.colorbar(im, ax=ax, label='Intensity')
45+
46+
def update_display():
47+
im.set_data(data[:, :, current_slice])
48+
ax.set_title(f"{title}\nSlice {current_slice} of {data.shape[2]}", pad=20)
49+
# Optionally, update colorbar limits if data range changes
50+
# im.set_clim(vmin=data.min(), vmax=data.max())
51+
fig.canvas.draw_idle()
52+
53+
def on_key(event):
54+
nonlocal current_slice
55+
print(f"Key pressed: {event.key}")
56+
if event.key == "up" and current_slice < data.shape[2] - 1:
57+
current_slice += 1
58+
print(f"Moving to slice {current_slice}")
59+
update_display()
60+
elif event.key == "down" and current_slice > 0:
61+
current_slice -= 1
62+
print(f"Moving to slice {current_slice}")
63+
update_display()
64+
65+
fig.canvas.mpl_connect("key_press_event", on_key)
66+
plt.show()
67+
68+
69+
def on_click(event):
70+
"""Handle click events on the main figure."""
71+
print(f"\nClick event detected at: {event.xdata}, {event.ydata}")
72+
73+
if event.inaxes is None:
74+
print("Click was outside axes")
75+
return
76+
77+
ax = event.inaxes
78+
print(f"Clicked on axes: {ax}")
79+
80+
for idx, (title, data) in enumerate(transforms.items()):
81+
if ax == axes[idx // n_cols, idx % n_cols]:
82+
print(f"Found matching subplot: {title}")
83+
show_high_res(data.numpy(), title)
84+
return
85+
86+
print("No matching subplot found")
87+
88+
89+
if __name__ == "__main__":
90+
parser = argparse.ArgumentParser(description="Apply and visualize image transformations")
91+
parser.add_argument("--headless", action="store_true", help="Run without displaying plots")
92+
parser.add_argument(
93+
"--datasets-dir",
94+
type=str,
95+
default=os.path.expanduser("~/datasets"),
96+
help="Path to the datasets directory (default: ~/datasets)",
97+
)
98+
parser.add_argument(
99+
"--output",
100+
type=str,
101+
default="viz_torch_transforms.png",
102+
help="Output image file path (default: viz_torch_transforms.png)",
103+
)
104+
args = parser.parse_args()
105+
106+
nii_files = glob.glob(os.path.join(args.datasets_dir, "**", "*.nii.gz"), recursive=True)
107+
nii_files.sort()
108+
109+
single_sample = nii_files[1]
110+
im = nib.load(single_sample)
111+
data = im.get_fdata()
112+
113+
print("\nImage Analysis:")
114+
print("Data type:", data.dtype)
115+
print("Shape:", data.shape)
116+
print("Value range:", data.min(), "to", data.max())
117+
print("Mean value:", data.mean())
118+
print("Std value:", data.std())
119+
print("Unique values count:", len(np.unique(data)))
120+
print("First few unique values:", np.unique(data)[:10])
121+
122+
# Convert to float32 and scale to [0,1] while preserving relative intensities
123+
imarr = data.astype(np.float32)
124+
data_min = imarr.min()
125+
data_max = imarr.max()
126+
data_range = data_max - data_min
127+
128+
# Scale to [0,1] while preserving relative intensities
129+
imarr = (imarr - data_min) / data_range
130+
imarr = torch.from_numpy(imarr)
131+
# Assume batch and channel dimensions are present
132+
# imarr = imarr.unsqueeze(0).unsqueeze(0)
133+
134+
print("\nNormalized Image Analysis:")
135+
print("Value range:", imarr.min().item(), "to", imarr.max().item())
136+
print("Mean value:", imarr.mean().item())
137+
print("Std value:", imarr.std().item())
138+
139+
transforms = {
140+
"Original": imarr,
141+
"Bias Field": torch_bias_field(imarr.clone(), clip_to_input_range=True),
142+
"Blurred (σ=2.0)": torch_blur(imarr.clone(), sigma=2.0),
143+
"Gamma (0.5–2.0)": torch_gamma(imarr.clone(), gamma_range=(0.5, 2.0), clip_to_input_range=True),
144+
"Gamma (2.0)": torch_gamma(imarr, gamma_range=(2.0, 2.0), clip_to_input_range=True),
145+
"Ghost (α=2, ax=0)": torch_motion_ghosting(imarr.clone(), alpha=2.0, num_reps=4, axis=0, clip_to_input_range=True),
146+
"Ghost (α=2, ax=1)": torch_motion_ghosting(imarr.clone(), alpha=2.0, num_reps=4, axis=1, clip_to_input_range=True),
147+
"Masked (r=0.3)": torch_mask(imarr.clone(), pixel_value=0, ratio=0.3, token_size=[16, 16, 16]),
148+
"Add Noise (σ=0.05)": torch_additive_noise(imarr.clone(), mean=0, sigma=0.05, clip_to_input_range=True),
149+
"Add Noise (σ=0.1)": torch_additive_noise(imarr.clone(), mean=0, sigma=0.1, clip_to_input_range=True),
150+
"Add Noise (σ=0.2)": torch_additive_noise(imarr.clone(), mean=0, sigma=0.2, clip_to_input_range=True),
151+
"Mult Noise (σ=0.05)": torch_multiplicative_noise(imarr.clone(), mean=0, sigma=0.05, clip_to_input_range=True),
152+
"Mult Noise (σ=0.1)": torch_multiplicative_noise(imarr.clone(), mean=0, sigma=0.1, clip_to_input_range=True),
153+
"Mult Noise (σ=0.2)": torch_multiplicative_noise(imarr.clone(), mean=0, sigma=0.2, clip_to_input_range=True),
154+
"Gibbs (axes=0,1,2)": torch_gibbs_ringing(
155+
imarr.clone(), num_sample=64, axes=[0, 1, 2], mode="rect", clip_to_input_range=True
156+
),
157+
"Low Res": torch_simulate_lowres(imarr.clone(), target_shape=(32, 32, 32), clip_to_input_range=True),
158+
"Spatial": torch_spatial(
159+
imarr.clone(),
160+
patch_size=imarr.shape,
161+
p_deform=1.0,
162+
p_rot=1.0,
163+
p_rot_per_axis=1.0,
164+
p_scale=1.0,
165+
alpha=10.0,
166+
sigma=3.0,
167+
x_rot=0.5,
168+
y_rot=0,
169+
z_rot=0,
170+
scale_factor=1.0,
171+
clip_to_input_range=True,
172+
)[
173+
0
174+
], # Get only the image, not the label
175+
}
176+
177+
# Get middle slice for visualization
178+
slice_idx = imarr.shape[-1] // 2 # Changed to use last dimension
179+
180+
# Calculate grid dimensions
181+
n_transforms = len(transforms)
182+
n_cols = 6 # Show 6 images per row
183+
n_rows = (n_transforms + n_cols - 1) // n_cols # Ceiling division
184+
185+
# Create a figure with subplots
186+
fig_width = 2.2 * n_cols
187+
fig_height = 2.2 * n_rows
188+
plt.rcParams["figure.figsize"] = [fig_width, fig_height]
189+
plt.rcParams["figure.dpi"] = 200 # Keep high DPI for quality
190+
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
191+
192+
# Plot each transformation
193+
for idx, (title, transformed_data) in enumerate(transforms.items()):
194+
row = idx // n_cols
195+
col = idx % n_cols
196+
197+
# Plot transformed image
198+
if isinstance(transformed_data, tuple):
199+
transformed_data = transformed_data[0]
200+
print(title, transformed_data.shape)
201+
transformed_plot = transformed_data.numpy()
202+
axes[row, col].imshow(transformed_plot[:, :, slice_idx], cmap="gray")
203+
axes[row, col].set_title(title, pad=1, fontsize=7)
204+
axes[row, col].axis("off")
205+
206+
# Hide any unused subplots
207+
for idx in range(len(transforms), n_rows * n_cols):
208+
row = idx // n_cols
209+
col = idx % n_cols
210+
axes[row, col].axis("off")
211+
axes[row, col].set_visible(False)
212+
213+
# Use subplots_adjust for tight packing
214+
plt.subplots_adjust(top=0.92, bottom=0.08, left=0.03, right=0.97, wspace=0.02, hspace=0.15)
215+
216+
# Set window title
217+
fig.canvas.manager.set_window_title("Image Transformations")
218+
219+
# Save the figure
220+
plt.savefig(args.output, dpi=200, bbox_inches="tight")
221+
print(f"\nSaved visualization to: {args.output}")
222+
223+
# Connect the click event
224+
print("\nConnecting click event handler...")
225+
fig.canvas.mpl_connect("button_press_event", on_click)
226+
print("Click event handler connected. Click any image to view in high resolution.")
227+
228+
# Only show the plot if not in headless mode
229+
if not args.headless:
230+
plt.show()

yucca/functional/preprocessing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,12 @@ def preprocess_case_for_training_with_label(
296296

297297
if final_target_size is not None:
298298
images, label = pad_case_to_size(case=images, size=final_target_size, label=label)
299-
image_properties["foreground_locations"], image_properties["label_cc_n"], image_properties["label_cc_sizes"] = (
300-
analyze_label(
301-
label=label, enable_connected_components_analysis=enable_cc_analysis, per_class=foreground_locs_per_label
302-
)
299+
(
300+
image_properties["foreground_locations"],
301+
image_properties["label_cc_n"],
302+
image_properties["label_cc_sizes"],
303+
) = analyze_label(
304+
label=label, enable_connected_components_analysis=enable_cc_analysis, per_class=foreground_locs_per_label
303305
)
304306

305307
first_existing_modality = list(set(range(len(images))).difference(missing_modality_idxs))[0]

yucca/functional/transforms/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,11 @@
88
from yucca.functional.transforms.masking import mask_batch
99
from yucca.functional.transforms.spatial import spatial
1010
from yucca.functional.transforms.skeleton import skeleton
11+
from yucca.functional.transforms.torch.blur import torch_blur
12+
from yucca.functional.transforms.torch.bias_field import torch_bias_field
13+
from yucca.functional.transforms.torch.gamma import torch_gamma
14+
from yucca.functional.transforms.torch.motion_ghosting import torch_motion_ghosting
15+
from yucca.functional.transforms.torch.noise import torch_additive_noise, torch_multiplicative_noise
16+
from yucca.functional.transforms.torch.ringing import torch_gibbs_ringing
17+
from yucca.functional.transforms.torch.sampling import torch_simulate_lowres
18+
from yucca.functional.transforms.torch.spatial import torch_spatial

yucca/functional/transforms/croppad.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ def croppad(
1212
label: np.ndarray = None,
1313
**pad_kwargs,
1414
):
15-
1615
if len(patch_size) == 3:
1716
image, label = croppad_3D_case_from_3D(
1817
image=image,

yucca/functional/transforms/spatial.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def spatial(
8585
# Mapping the images to the distorted coordinates
8686
for b in range(image.shape[0]):
8787
for c in range(image.shape[1]):
88-
8988
img_min = image.min()
9089
img_max = image.max()
9190

0 commit comments

Comments
 (0)