Skip to content

Commit c7a112d

Browse files
authored
Merge pull request #226 from KMarshallX/pre-release-2_0-0910
Pre release 2 0 0910
2 parents 0559645 + 2243130 commit c7a112d

File tree

11 files changed

+156
-34
lines changed

11 files changed

+156
-34
lines changed

config/adapt_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
adapt_parser.add_argument('--patch_size', nargs=3, type=int, default=(64, 64, 64),
3737
help='Expected size for training (x y z)')
3838
# optimizer type, available: [sgd, adam]
39-
adapt_parser.add_argument('--optimizer', type=str, default="adam", help='available: [sgd, adam]')
39+
adapt_parser.add_argument('--optimizer', type=str, default="adam", help='available: [sgd, adam, adamw]')
4040
# loss metric type, available: [bce, dice, tver]
4141
adapt_parser.add_argument('--loss_metric', type=str, default="tver", help="available: [bce, dice, tver]")
4242

config/angiboost_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
# expected size of the training patch
3232
angiboost_parser.add_argument('--osz', nargs=3, type=int, default=(64, 64, 64),
3333
help='Expected size of the training patch (x y z)')
34-
angiboost_parser.add_argument('--optimizer', type=str, default="adam", help="available: [sgd, adam]")
34+
angiboost_parser.add_argument('--optimizer', type=str, default="adam", help="available: [sgd, adam, adamw]")
3535
angiboost_parser.add_argument('--loss_metric', type=str, default="tver", help="available: [bce, dice, tver]")
3636

3737
# Optimizer tuning

config/boost_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
# expected size of the training patch
3131
boost_parser.add_argument('--patch_size', nargs=3, type=int, default=(64, 64, 64),
3232
help='Expected size of the training patch (x y z)')
33-
boost_parser.add_argument('--optimizer', type=str, default="adam", help="available: [sgd, adam]")
33+
boost_parser.add_argument('--optimizer', type=str, default="adam", help="available: [sgd, adam, adamw]")
3434
boost_parser.add_argument('--loss_metric', type=str, default="tver", help="available: [bce, dice, tver]")
3535

3636
# Optimizer tuning

config/train_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
# expected patch size for training
3030
train_parser.add_argument('--output_size', nargs=3, type=int, default=(64, 64, 64), help='Expected patch size for training (x y z)')
3131
# optimizer type, available: [sgd, adam]
32-
train_parser.add_argument('--optimizer', type=str, default="adam", help='available: [sgd, adam]')
32+
train_parser.add_argument('--optimizer', type=str, default="adam", help='available: [sgd, adam, adamw]')
3333
# loss metric type, available: [bce, dice, tver]
3434
train_parser.add_argument('--loss_metric', type=str, default="tver", help="available: [bce, dice, tver]")
3535

environment.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ dependencies:
1010
- pytorch-cuda=12.4
1111
- python=3.10
1212
- pip
13+
- jupyter
14+
- notebook
15+
- jupyterlab
16+
- ipython
17+
- ipykernel
18+
- ipywidgets
1319
- pip:
1420
- numpy==1.26.4
1521
- scipy==1.15.2
@@ -23,3 +29,4 @@ dependencies:
2329
- antspyx==0.4.2
2430
- connected-components-3d==3.13.0
2531
- osfclient==0.0.5
32+
- jupyterlmod==5.3.0

library/aug_utils.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -375,24 +375,29 @@ def __init__(self, mode: str = 'spatial'):
375375

376376
self.mode = mode
377377

378-
def _blur(self, std: float = 0.85) -> tio.RandomBlur:
379-
return tio.RandomBlur(std=std)
378+
def _blur(self, p: float = 1, std: float = 0.85) -> tio.RandomBlur:
379+
return tio.RandomBlur(std=std, p=p)
380380

381-
def _bias(self, coefficients: float = 0.15, order: int = 3) -> tio.RandomBiasField:
382-
return tio.RandomBiasField(coefficients=coefficients, order=order)
381+
def _bias(self, p: float = 1, coefficients: float = 0.15, order: int = 3) -> tio.RandomBiasField:
382+
return tio.RandomBiasField(coefficients=coefficients, order=order, p=p)
383383

384-
def _noise(self, mean: float = 0, std: float = 0.008) -> tio.RandomNoise:
385-
return tio.RandomNoise(mean=mean, std=std)
384+
def _noise(self, p: float = 1, mean: float = 0, std: float = 0.008) -> tio.RandomNoise:
385+
return tio.RandomNoise(mean=mean, std=std, p=p)
386386

387-
def _flip(self, axes: Union[Tuple[int, ...], int], probability: float = 1.0) -> tio.RandomFlip:
388-
return tio.RandomFlip(axes=axes, flip_probability=probability)
387+
def _flip(self, p: float = 1, axes: Union[Tuple[int, ...], int] = (0, 1, 2), probability: float = 1.0) -> tio.RandomFlip:
388+
# Randomly choose a single axis from the available axes
389+
if isinstance(axes, tuple):
390+
random_axis = int(np.random.choice(axes))
391+
else:
392+
random_axis = axes
393+
return tio.RandomFlip(axes=random_axis, flip_probability=probability, p=p)
389394

390-
def _elastic_deform(self, num_control_points: int = 9,
395+
def _elastic_deform(self, p: float = 1, num_control_points: int = 9,
391396
max_displacement: int = 7,
392397
locked_borders: int = 2) -> tio.RandomElasticDeformation:
393398
return tio.RandomElasticDeformation(num_control_points=num_control_points,
394399
max_displacement=max_displacement,
395-
locked_borders=locked_borders)
400+
locked_borders=locked_borders, p=p)
396401

397402
def __call__(self, image_batch: torch.Tensor, seg_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
398403
"""
@@ -416,36 +421,42 @@ def __call__(self, image_batch: torch.Tensor, seg_batch: torch.Tensor) -> Tuple[
416421
self._blur(),
417422
self._bias(),
418423
self._noise(),
419-
self._flip(axes=(0, 1)),
420-
self._elastic_deform(),
421-
])
422-
elif self.mode == 'random':
423-
transforms = tio.OneOf([
424-
self._blur(),
425-
self._bias(),
426-
self._noise(),
427-
self._flip(axes=(0, 1)),
424+
self._flip(axes=(0, 1, 2)),
428425
self._elastic_deform()
429426
])
427+
elif self.mode == 'random':
428+
transforms = tio.OneOf({
429+
self._blur() : 0.1,
430+
self._bias() : 0.1,
431+
self._noise() : 0.1,
432+
self._flip(axes=(0, 1, 2)) : 0.35,
433+
self._elastic_deform() : 0.35
434+
})
430435
elif self.mode == 'spatial':
431436
transforms = tio.Compose([
432-
self._flip(axes=(0, 1)),
437+
self._flip(axes=(0, 1, 2)),
433438
self._elastic_deform()
434439
])
435440
elif self.mode == 'intensity':
436-
transforms = tio.Compose([
441+
transforms = tio.OneOf([
437442
self._blur(),
438443
self._bias(),
439444
self._noise()
440445
])
441446
elif self.mode == 'off':
442447
# No augmentation, return original subject
443-
return subject_batch['image'].data, subject_batch['label'].data # type: ignore
448+
return subject_batch['image'].data.unsqueeze(1), subject_batch['label'].data.unsqueeze(1) # type: ignore
444449
else:
445450
raise ValueError(f"Unsupported mode '{self.mode}' for TorchIO augmentations")
446451

447452
# Apply the transform to the subject batch
448453
transformed_subject = transforms(subject_batch)
454+
455+
# # Track which transform was applied (for OneOf modes)
456+
# # TESTING BLOCK, UNCOMMENT THIS FOR DEBUGGING
457+
# if self.mode in ['random', 'intensity']:
458+
# applied_transforms = [str(transform) for transform in transformed_subject.history]
459+
# print(f"Applied transforms in {self.mode} mode: {applied_transforms[-1] if applied_transforms else 'None'}")
449460

450461
# Extract image and label tensors
451462
image_tensor = transformed_subject['image'].data.unsqueeze(1) # type: ignore

library/data_loaders.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def _zooming(self, img_crop: np.ndarray, seg_crop: np.ndarray) -> Tuple[np.ndarr
114114
"""
115115
if img_crop.shape != self.patch_size:
116116
zoom_factors = tuple(float(out_dim) / float(crop_dim) for out_dim, crop_dim in zip(self.patch_size, img_crop.shape))
117-
img_crop = scind.zoom(img_crop, zoom_factors, order=3, mode='nearest')
117+
img_crop = scind.zoom(img_crop, zoom_factors, order=0, mode='nearest')
118118
seg_crop = scind.zoom(seg_crop, zoom_factors, order=0, mode='nearest')
119-
return img_crop, seg_crop
119+
return img_crop.astype(np.float32), seg_crop.astype(np.int8)
120120

121121
def __repr__(self) -> str:
122122
return (f"SingleChannelLoader(\n"
@@ -164,7 +164,7 @@ def __iter__(self):
164164
img_batch[self.batch_multiplier] = large_img_crop
165165
seg_batch[self.batch_multiplier] = large_seg_crop
166166
# Yield the cropped and resized patches as a batch of pytorch tensors
167-
yield torch.from_numpy(img_batch).float(), torch.from_numpy(seg_batch).long()
167+
yield torch.from_numpy(img_batch).float(), torch.from_numpy(seg_batch).ceil().int()
168168

169169
except Exception as e:
170170
logger.error(f"Error generating patch {i}: {e}")

library/loss_func.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ def choose_optimizer(optim_name: str, model_params: Any, lr: float) -> Any:
150150
"""
151151
optimizer_registry = {
152152
'sgd': lambda: torch.optim.SGD(model_params, lr),
153-
'adam': lambda: torch.optim.Adam(model_params, lr)
153+
'adam': lambda: torch.optim.Adam(model_params, lr),
154+
'adamw': lambda: torch.optim.AdamW(model_params, lr)
154155
}
155156

156157
if optim_name not in optimizer_registry:

library/module_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def process_single_image(
399399
resized_image = self._resize_image(image_array, target_size)
400400

401401
# Standardize image
402-
standardized_image = standardiser(resized_image)
402+
standardized_image = normaliser(resized_image)
403403

404404
# Create patches
405405
patches = patchify(standardized_image, (64, 64, 64), 64)
@@ -431,7 +431,7 @@ def predict_all_images(
431431
threshold: float,
432432
connect_threshold: int,
433433
save_mip: bool = False,
434-
save_probability: bool = False
434+
save_probability: bool = True
435435
) -> None:
436436
"""
437437
Process all images in the input directory.

library/train_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def _train_epoch(
232232
total_lr += optimizer.param_groups[0]['lr']
233233
num_batches += 1
234234

235-
logger.info(f"Data loading time: {data_loading_time:.2f}s, Model training time: {model_training_time:.2f}s") # TEST
235+
logger.info(f"\nData loading time: {data_loading_time:.2f}s, Model training time: {model_training_time:.2f}s") # TEST
236236
return total_loss / num_batches, total_lr / num_batches
237237

238238
def train_model(
@@ -269,7 +269,7 @@ def train_model(
269269
# Log progress
270270
tqdm.write(
271271
f'Epoch [{epoch+1}/{self.num_epochs}], '
272-
f'Loss: {avg_loss:.4f}, LR: {avg_lr:.8f}'
272+
f'Loss: {avg_loss:.8f}, LR: {avg_lr:.8f}'
273273
)
274274

275275
# Save model

0 commit comments

Comments
 (0)