Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
9c27332
Make CLS pipeline
Sllambias Oct 30, 2024
9a37728
remove concept of task_type in LM
Sllambias Nov 5, 2024
2d2cb0d
remove print statement and fix imports
Sllambias Nov 5, 2024
060a4bd
updates
Sllambias Nov 5, 2024
57616ef
add ResNet18
Sllambias Nov 14, 2024
b589527
updates
Sllambias Nov 29, 2024
c775543
fix cls eval
Sllambias Nov 29, 2024
d8171e3
update toml
Sllambias Nov 29, 2024
0fd01e2
CLS updates
Sllambias Dec 10, 2024
4de49cc
logging
Sllambias Dec 16, 2024
ba9499a
add 30 and 31
Sllambias Dec 23, 2024
5611144
add cls setups
Sllambias Jan 2, 2025
413d780
new managers and option to incl segmentation in trianing
Sllambias Jan 8, 2025
8590b4f
bugfix
Sllambias Jan 8, 2025
37c9558
another one
Sllambias Jan 8, 2025
ba1ac57
add predict
Sllambias Jan 15, 2025
59950a7
.
Sllambias Jan 15, 2025
dcdb1d3
.
Sllambias Jan 15, 2025
ba95a8a
new vers
Sllambias Jan 22, 2025
fab0960
cleanup on isle1
Sllambias Jan 29, 2025
95e00b9
add support for covariates
Sllambias Feb 19, 2025
c7b4022
updates
Sllambias Feb 19, 2025
abb9e96
finish covariates implementation
Sllambias Feb 20, 2025
9141f59
upd
Sllambias Feb 21, 2025
4a46f58
add DenseNet
Sllambias Feb 24, 2025
4ad3b3c
add densenet manager
Sllambias Feb 24, 2025
2b3bd8b
fix typo
Sllambias Feb 24, 2025
9e9f7fe
.
Sllambias Feb 24, 2025
97d46f6
.
Sllambias Feb 24, 2025
9122120
,,,
Sllambias Feb 24, 2025
2dfeb78
add v10
Sllambias Feb 25, 2025
d11d7aa
add test covariate
Sllambias Mar 20, 2025
19e5194
fix typo
Sllambias Mar 20, 2025
03c5fd3
RE fixes
Sllambias Mar 20, 2025
8605213
cov predict
Sllambias Mar 20, 2025
56db633
.
Sllambias Mar 20, 2025
ae4d08f
wrap in arr
Sllambias Mar 20, 2025
4fbd0c1
wrap in torch instead
Sllambias Mar 20, 2025
88de858
add unsqueeze
Sllambias Mar 20, 2025
b229a1b
add resnet 34
Sllambias Mar 20, 2025
d5b52a2
add resnet50
Sllambias Mar 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "yucca"
version = "2.2.1"
version = "2.2.4"
authors = [
{ name="Sebastian Llambias", email="llambias@live.com" },
{ name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" },
Expand Down Expand Up @@ -29,6 +29,7 @@ dependencies = [
"numpy>=1.26.4",
"pandas>=2.2.1",
"python-dotenv==1.0.0",
"pytorchvideo==0.1.5",
"scikit_image>=0.22.0",
"scikit_learn>=1.4.1.post1",
"seaborn>=0.13.2",
Expand Down
9 changes: 9 additions & 0 deletions yucca/functional/evaluation/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,12 @@ def torch_get_tp_fp_tn_fn(confusion_matrix, ignore_label=0):
TN.append(tn.cpu().numpy())
FN.append(fn.cpu().numpy())
return TP, FP, TN, FN


def convert_confusion_matrix_to_dict(confusion_matrix):
d = {}
for true_label, row in enumerate(confusion_matrix):
d[str(true_label)] = {}
for predicted_label, value in enumerate(row):
d[str(true_label)][str(predicted_label)] = value
return d
17 changes: 6 additions & 11 deletions yucca/functional/evaluation/evaluate_folder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import os
import numpy as np
import nibabel as nib
import logging
Expand All @@ -11,6 +12,7 @@
from batchgenerators.utilities.file_and_folder_operations import join
from sklearn.metrics import confusion_matrix
from yucca.functional.evaluation.metrics import auroc
from yucca.functional.evaluation.confusion_matrix import convert_confusion_matrix_to_dict


def evaluate_folder_segm(
Expand Down Expand Up @@ -257,9 +259,6 @@ def evaluate_folder_cls(
prediction_probs = []
ground_truths = []

# Flag to check if we have prediction probabilities to calculate AUROC
use_probs = False

# load predictions and ground truths
for case in tqdm(subjects, desc="Evaluating"):
predpath = join(folder_with_predictions, case)
Expand All @@ -268,15 +267,9 @@ def evaluate_folder_cls(
pred: int = np.loadtxt(predpath)
gt: int = np.loadtxt(gtpath)

try:
if len(prediction_probs) == 0:
print("Prediction probabilities found. Will use them for evaluation.")
use_probs = True

if os.path.isfile(predpath.replace(".txt", ".npz")):
pred_probs = np.load(predpath.replace(".txt", ".npz"))["data"] # contains output probabilities
prediction_probs.append(pred_probs)
except FileNotFoundError:
pred_probs = None

predictions.append(pred)
ground_truths.append(gt)
Expand All @@ -287,6 +280,8 @@ def evaluate_folder_cls(
# calculate per-class metrics
cmat = confusion_matrix(ground_truths, predictions, labels=labels)

cmat_dict = convert_confusion_matrix_to_dict(cmat)
resultdict["confusion_matrix"] = cmat_dict
resultdict["per_class"] = {}

for label in labels:
Expand All @@ -303,7 +298,7 @@ def evaluate_folder_cls(
resultdict["per_class"][str(label)] = labeldict

# calculate AUROC
if use_probs:
if len(prediction_probs) > 0:
auroc_per_class: list[float] = auroc(ground_truths, prediction_probs)
for label, score in zip(labels, auroc_per_class):
resultdict["per_class"][str(label)]["AUROC"] = round(score, 4)
Expand Down
5 changes: 4 additions & 1 deletion yucca/functional/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def preprocess_case_for_inference(
target_size,
target_spacing,
target_orientation,
background_pixel_value: int = 0,
allow_missing_modalities: bool = False,
ext=".nii.gz",
keep_aspect_ratio: bool = True,
Expand All @@ -425,7 +426,9 @@ def preprocess_case_for_inference(
image_properties["uncropped_shape"] = np.array(images[0].shape)

if crop_to_nonzero:
nonzero_box = get_bbox_for_foreground(images[0], background_label=0)
if np.max(images[0]) <= background_pixel_value:
background_pixel_value = np.min(images[0])
nonzero_box = get_bbox_for_foreground(images[0], background_label=background_pixel_value)
for i in range(len(images)):
images[i] = crop_to_box(images[i], nonzero_box)
image_properties["nonzero_box"] = nonzero_box
Expand Down
2 changes: 2 additions & 0 deletions yucca/functional/testing/data/nifti.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import nibabel as nib
import numpy as np
import nibabel.orientations as nio
import logging
from yucca.functional.utils.nib_utils import get_nib_orientation, get_nib_spacing


Expand Down Expand Up @@ -48,4 +49,5 @@ def verify_orientation_is_LR_PA_IS(image: nib.Nifti1Image):
if np.all(nio.axcodes2ornt(orientation)[:, 0] == expected_orientation_code):
return True
else:
logging.info(f"Found orientation {orientation}")
return False
28 changes: 15 additions & 13 deletions yucca/functional/transforms/gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def augment_gamma(
invert_image=False,
epsilon=1e-7,
per_channel=False,
p_per_channel=None,
clip_to_input_range=False,
):
if invert_image:
Expand All @@ -27,19 +28,20 @@ def augment_gamma(
data_sample = np.clip(data_sample, a_min=img_min, a_max=img_max)
else:
for c in range(data_sample.shape[0]):
if np.random.random() < 0.5 and gamma_range[0] < 1:
gamma = np.random.uniform(gamma_range[0], 1)
else:
gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
img_min = data_sample[c].min()
img_max = data_sample[c].max()
img_range = img_max - img_min
data_sample[c] = (
np.power(((data_sample[c] - img_min) / float(img_range + epsilon)), gamma) * float(img_range + epsilon)
+ img_min
)
if clip_to_input_range:
data_sample[c] = np.clip(data_sample[c], a_min=img_min, a_max=img_max)
if np.random.uniform() < p_per_channel[c]:
if np.random.random() < 0.5 and gamma_range[0] < 1:
gamma = np.random.uniform(gamma_range[0], 1)
else:
gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
img_min = data_sample[c].min()
img_max = data_sample[c].max()
img_range = img_max - img_min
data_sample[c] = (
np.power(((data_sample[c] - img_min) / float(img_range + epsilon)), gamma) * float(img_range + epsilon)
+ img_min
)
if clip_to_input_range:
data_sample[c] = np.clip(data_sample[c], a_min=img_min, a_max=img_max)
if invert_image:
data_sample = -data_sample
return data_sample
8 changes: 6 additions & 2 deletions yucca/functional/transforms/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def spatial(
scale_factor,
clip_to_input_range,
label: Optional[np.ndarray] = None,
linear_interpolation_channel=None,
skip_label: bool = False,
do_crop: bool = True,
random_crop: bool = True,
Expand All @@ -40,6 +41,9 @@ def spatial(
cval = cval
assert isinstance(cval, (int, float)), f"got {cval} of type {type(cval)}"

if isinstance(order, (int, float)):
order = [order for _ in range(image.shape[1])]

coords = create_zero_centered_coordinate_matrix(patch_size)
image_canvas = np.zeros((image.shape[0], image.shape[1], *patch_size), dtype=np.float32)

Expand Down Expand Up @@ -92,7 +96,7 @@ def spatial(
image_canvas[b, c] = map_coordinates(
image[b, c].astype(float),
coords,
order=order,
order=order[c],
mode="constant",
cval=cval,
).astype(image.dtype)
Expand All @@ -106,7 +110,7 @@ def spatial(
dtype=np.float32,
)

# Mapping the labelmentations to the distorted coordinates
# Mapping the label to the distorted coordinates
for b in range(label.shape[0]):
for c in range(label.shape[1]):
label_canvas[b, c] = map_coordinates(label[b, c], coords, order=0, mode="constant", cval=0.0).astype(
Expand Down
4 changes: 3 additions & 1 deletion yucca/functional/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from yucca.functional.visualization.imshow import get_train_fig_with_inp_out_tar
from yucca.functional.visualization.imshow import get_segm_train_fig_with_inp_out_tar
from yucca.functional.visualization.imshow import get_cls_train_fig_with_inp_out_tar
from yucca.functional.visualization.imshow import get_ssl_train_fig_with_inp_out_tar
97 changes: 65 additions & 32 deletions yucca/functional/visualization/imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@
import matplotlib.pyplot as plt


def get_train_fig_with_inp_out_tar(input, output, target, fig_title, task_type: str = "segmentation"):
def get_segm_train_fig_with_inp_out_tar(input, output, target, fig_title):
# This needs to handle the following cases:
# Segmentation : {"input": (m,x,y(,z)), "target": (1,x,y(,z)), "output": (c,x,y(,z))}
# Self-supervised : {"input": (m,x,y(,z)), "target": (m,x,y(,z)), "output": (m,x,y(,z))}
# Classification : {"input": (m,x,y(,z)), "target": (1,x), "output": (c,x)}

channel_idx = np.random.randint(0, input.shape[0])

if len(input.shape) == 4: # 3D images.
# We need to select a slice to visualize.
if task_type == "segmentation" and len(target[0].nonzero()[0]) > 0:
# Select a foreground slice if any exist.
if len(input.shape) == 4:
if len(target[0].nonzero()[0]) > 0:
foreground_locations = target[0].nonzero()
slice_to_visualize = foreground_locations[0][np.random.randint(0, len(foreground_locations[0]))]
else:
Expand All @@ -27,30 +23,67 @@ def get_train_fig_with_inp_out_tar(input, output, target, fig_title, task_type:
output = output[:, slice_to_visualize]

image = input[channel_idx]
target = target[0]
output = output.argmax(0)

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), dpi=100, constrained_layout=True)
axes[0].imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99))
axes[0].set_title("input")
axes[1].imshow(target, cmap="gray")
axes[1].set_title("target")
axes[2].imshow(output, cmap="gray")
axes[2].set_title("output")
fig.suptitle(fig_title, fontsize=16)
return fig


def get_cls_train_fig_with_inp_out_tar(input, output, target, fig_title):
# This needs to handle the following case:
# Classification : {"input": (m,x,y(,z)), "target": (n_classes), "output": (n_classes)}

channel_idx = np.random.randint(0, input.shape[0])

slice_to_visualize = np.random.randint(0, input.shape[1])

if len(input.shape) == 4: # 3D images.
input = input[:, slice_to_visualize]

image = input[channel_idx]

output = output.argmax(0)

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), dpi=100, constrained_layout=True)
axes.imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99))
axes.set_title(f"Input: {fig_title}", fontsize=12)
fig.suptitle(f"Target: {target} | Output: {output}", fontsize=12)
return fig


def get_ssl_train_fig_with_inp_out_tar(input, output, target, fig_title):
# This needs to handle the following cases:
# Self-supervised : {"input": (m,x,y(,z)), "target": (m,x,y(,z)), "output": (m,x,y(,z))}

channel_idx = np.random.randint(0, input.shape[0])

if len(input.shape) == 4: # 3D images.
slice_to_visualize = np.random.randint(0, input.shape[1])
input = input[:, slice_to_visualize]
if len(target.shape) == 4:
target = target[:, slice_to_visualize]
if len(output.shape) == 4:
output = output[:, slice_to_visualize]

image = input[channel_idx]

target = target[channel_idx]
output = output[channel_idx]

if task_type in ["segmentation", "classification"]:
target = target[0]
output = output.argmax(0)
elif task_type == "self-supervised":
target = target[channel_idx]
output = output[channel_idx]
else:
logging.warn(
f"Unknown task type. Found {task_type} and expected one in ['classification', 'segmentation', 'self-supervised']"
)

if len(target.shape) == 1:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), dpi=100, constrained_layout=True)
axes[0].imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99))
axes[0].set_title("input")
fig.suptitle(f"{fig_title}. Target: {target} | Output: {output}", fontsize=16)
else:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), dpi=100, constrained_layout=True)
axes[0].imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99))
axes[0].set_title("input")
axes[1].imshow(target, cmap="gray")
axes[1].set_title("target")
axes[2].imshow(output, cmap="gray")
axes[2].set_title("output")
fig.suptitle(fig_title, fontsize=16)
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), dpi=100, constrained_layout=True)
axes[0].imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99))
axes[0].set_title("input")
axes[1].imshow(target, cmap="gray")
axes[1].set_title("target")
axes[2].imshow(output, cmap="gray")
axes[2].set_title("output")
fig.suptitle(fig_title, fontsize=16)
return fig
2 changes: 1 addition & 1 deletion yucca/modules/callbacks/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def create_logfile(self):
self.log_dir,
"training_log.txt",
)
with open(self.log_file, "w") as f:
with open(self.log_file, "a+") as f:
f.write("Starting model training")
logging.info("Starting model training \n" f'{"log file:":20} {self.log_file} \n')
f.write("\n")
Expand Down
Loading