@@ -375,24 +375,29 @@ def __init__(self, mode: str = 'spatial'):
375375
376376 self .mode = mode
377377
378- def _blur (self , std : float = 0.85 ) -> tio .RandomBlur :
379- return tio .RandomBlur (std = std )
378+ def _blur (self , p : float = 1 , std : float = 0.85 ) -> tio .RandomBlur :
379+ return tio .RandomBlur (std = std , p = p )
380380
381- def _bias (self , coefficients : float = 0.15 , order : int = 3 ) -> tio .RandomBiasField :
382- return tio .RandomBiasField (coefficients = coefficients , order = order )
381+ def _bias (self , p : float = 1 , coefficients : float = 0.15 , order : int = 3 ) -> tio .RandomBiasField :
382+ return tio .RandomBiasField (coefficients = coefficients , order = order , p = p )
383383
384- def _noise (self , mean : float = 0 , std : float = 0.008 ) -> tio .RandomNoise :
385- return tio .RandomNoise (mean = mean , std = std )
384+ def _noise (self , p : float = 1 , mean : float = 0 , std : float = 0.008 ) -> tio .RandomNoise :
385+ return tio .RandomNoise (mean = mean , std = std , p = p )
386386
387- def _flip (self , axes : Union [Tuple [int , ...], int ], probability : float = 1.0 ) -> tio .RandomFlip :
388- return tio .RandomFlip (axes = axes , flip_probability = probability )
387+ def _flip (self , p : float = 1 , axes : Union [Tuple [int , ...], int ] = (0 , 1 , 2 ), probability : float = 1.0 ) -> tio .RandomFlip :
388+ # Randomly choose a single axis from the available axes
389+ if isinstance (axes , tuple ):
390+ random_axis = int (np .random .choice (axes ))
391+ else :
392+ random_axis = axes
393+ return tio .RandomFlip (axes = random_axis , flip_probability = probability , p = p )
389394
390- def _elastic_deform (self , num_control_points : int = 9 ,
395+ def _elastic_deform (self , p : float = 1 , num_control_points : int = 9 ,
391396 max_displacement : int = 7 ,
392397 locked_borders : int = 2 ) -> tio .RandomElasticDeformation :
393398 return tio .RandomElasticDeformation (num_control_points = num_control_points ,
394399 max_displacement = max_displacement ,
395- locked_borders = locked_borders )
400+ locked_borders = locked_borders , p = p )
396401
397402 def __call__ (self , image_batch : torch .Tensor , seg_batch : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
398403 """
@@ -416,36 +421,42 @@ def __call__(self, image_batch: torch.Tensor, seg_batch: torch.Tensor) -> Tuple[
416421 self ._blur (),
417422 self ._bias (),
418423 self ._noise (),
419- self ._flip (axes = (0 , 1 )),
420- self ._elastic_deform (),
421- ])
422- elif self .mode == 'random' :
423- transforms = tio .OneOf ([
424- self ._blur (),
425- self ._bias (),
426- self ._noise (),
427- self ._flip (axes = (0 , 1 )),
424+ self ._flip (axes = (0 , 1 , 2 )),
428425 self ._elastic_deform ()
429426 ])
427+ elif self .mode == 'random' :
428+ transforms = tio .OneOf ({
429+ self ._blur () : 0.1 ,
430+ self ._bias () : 0.1 ,
431+ self ._noise () : 0.1 ,
432+ self ._flip (axes = (0 , 1 , 2 )) : 0.35 ,
433+ self ._elastic_deform () : 0.35
434+ })
430435 elif self .mode == 'spatial' :
431436 transforms = tio .Compose ([
432- self ._flip (axes = (0 , 1 )),
437+ self ._flip (axes = (0 , 1 , 2 )),
433438 self ._elastic_deform ()
434439 ])
435440 elif self .mode == 'intensity' :
436- transforms = tio .Compose ([
441+ transforms = tio .OneOf ([
437442 self ._blur (),
438443 self ._bias (),
439444 self ._noise ()
440445 ])
441446 elif self .mode == 'off' :
442447 # No augmentation, return original subject
443- return subject_batch ['image' ].data , subject_batch ['label' ].data # type: ignore
448+ return subject_batch ['image' ].data . unsqueeze ( 1 ) , subject_batch ['label' ].data . unsqueeze ( 1 ) # type: ignore
444449 else :
445450 raise ValueError (f"Unsupported mode '{ self .mode } ' for TorchIO augmentations" )
446451
447452 # Apply the transform to the subject batch
448453 transformed_subject = transforms (subject_batch )
454+
455+ # # Track which transform was applied (for OneOf modes)
456+ # # TESTING BLOCK, UNCOMMENT THIS FOR DEBUGGING
457+ # if self.mode in ['random', 'intensity']:
458+ # applied_transforms = [str(transform) for transform in transformed_subject.history]
459+ # print(f"Applied transforms in {self.mode} mode: {applied_transforms[-1] if applied_transforms else 'None'}")
449460
450461 # Extract image and label tensors
451462 image_tensor = transformed_subject ['image' ].data .unsqueeze (1 ) # type: ignore
0 commit comments