Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
AsDiscreted,
EnsureTyped,
EnsureChannelFirstd,
EnsureChannelFirst,
Compose,
CropForegroundd,
LoadImaged,
Expand Down Expand Up @@ -137,6 +138,59 @@ def wb_mask(bg_img, mask):
for image_name, label_name in zip(train_images_match, train_labels_match)]
# TODO: add check if data empty

# << Trying suggestion: https://github.com/Project-MONAI/MONAI/discussions/5948#discussioncomment-5330042
# Split train/val
train_files, val_files = data_dicts[:-1], data_dicts[-1:]
# Set deterministic training for reproducibility
set_determinism(seed=0)
# # Volume-level transforms for both image and segmentation
# train_transforms = Compose(
# [
# LoadImaged(keys=["image", "label"]),
# EnsureChannelFirstd(keys=["image", "label"]),
# ScaleIntensityd(keys="image"),
# # RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 1]),
# EnsureTyped(keys=["image", "label"]),
# ]
# )
# # 3D dataset with preprocessing transforms
# volume_ds = CacheDataset(data=train_files, transform=train_transforms)
n_samples = 5
sampler = Compose([
LoadImaged(keys=["image", "label"], image_only=True),
EnsureChannelFirstd(keys=["image", "label"]),
RandCropByPosNegLabeld(
keys=["image", "label"],
image_key="image",
label_key="label",
spatial_size=(200, 200, 1),
pos=1,
neg=0,
num_samples=n_samples,
),
])

ds = PatchDataset(data=train_files,
patch_func=sampler,
samples_per_image=n_samples,
transform=None)
check_loader = DataLoader(ds, batch_size=1)
check_data = first(check_loader)

print("First volume's shape: ", check_data["image"].shape, check_data["label"].shape)
i=0
for check_data in check_loader:
if 'image' in check_data:
print(f"{i}: {check_data['image'].size()}")
i += 1

# IGNORE CODE BELOW-- DEBUGGING






# Iterate across image/label 3D volume, fetch non-empty slice and output a single list of image/label pair
patch_data = []
for data_dict in data_dicts:
Expand Down