@@ -51,12 +51,12 @@ def load_filenames(data_list_path: str) -> list:
51
51
52
52
53
53
def prepare_data (
54
- train_files : list ,
55
- device : torch .device ,
56
- cache_rate : float ,
57
- num_workers : int = 2 ,
58
- batch_size : int = 1 ,
59
- include_body_region : bool = False
54
+ train_files : list ,
55
+ device : torch .device ,
56
+ cache_rate : float ,
57
+ num_workers : int = 2 ,
58
+ batch_size : int = 1 ,
59
+ include_body_region : bool = False ,
60
60
) -> DataLoader :
61
61
"""
62
62
Prepare training data.
@@ -78,11 +78,11 @@ def _load_data_from_file(file_path, key):
78
78
return torch .FloatTensor (json .load (f )[key ])
79
79
80
80
train_transforms_list = [
81
- monai .transforms .LoadImaged (keys = ["image" ]),
82
- monai .transforms .EnsureChannelFirstd (keys = ["image" ]),
83
- monai .transforms .Lambdad (keys = "spacing" , func = lambda x : _load_data_from_file (x , "spacing" )),
84
- monai .transforms .Lambdad (keys = "spacing" , func = lambda x : x * 1e2 ),
85
- ]
81
+ monai .transforms .LoadImaged (keys = ["image" ]),
82
+ monai .transforms .EnsureChannelFirstd (keys = ["image" ]),
83
+ monai .transforms .Lambdad (keys = "spacing" , func = lambda x : _load_data_from_file (x , "spacing" )),
84
+ monai .transforms .Lambdad (keys = "spacing" , func = lambda x : x * 1e2 ),
85
+ ]
86
86
if include_body_region :
87
87
train_transforms_list += [
88
88
monai .transforms .Lambdad (
@@ -202,7 +202,7 @@ def train_one_epoch(
202
202
logger : logging .Logger ,
203
203
local_rank : int ,
204
204
amp : bool = True ,
205
- include_body_region : bool = False
205
+ include_body_region : bool = False ,
206
206
) -> torch .Tensor :
207
207
"""
208
208
Train the model for one epoch.
@@ -284,9 +284,10 @@ def train_one_epoch(
284
284
# predict velocity
285
285
loss = loss_pt (model_output .float (), (images - noise ).float ())
286
286
else :
287
- raise ValueError ("noise scheduler prediction type has to be chosen from " ,
288
- f"[{ DDPMPredictionType .EPSILON } ,{ DDPMPredictionType .SAMPLE } ,{ DDPMPredictionType .V_PREDICTION } ]"
289
- )
287
+ raise ValueError (
288
+ "noise scheduler prediction type has to be chosen from " ,
289
+ f"[{ DDPMPredictionType .EPSILON } ,{ DDPMPredictionType .SAMPLE } ,{ DDPMPredictionType .V_PREDICTION } ]" ,
290
+ )
290
291
291
292
if amp :
292
293
scaler .scale (loss ).backward ()
@@ -349,7 +350,12 @@ def save_checkpoint(
349
350
350
351
351
352
def diff_model_train (
352
- env_config_path : str , model_config_path : str , model_def_path : str , num_gpus : int , amp : bool = True , include_body_region : bool = False
353
+ env_config_path : str ,
354
+ model_config_path : str ,
355
+ model_def_path : str ,
356
+ num_gpus : int ,
357
+ amp : bool = True ,
358
+ include_body_region : bool = False ,
353
359
) -> None :
354
360
"""
355
361
Main function to train a diffusion model.
@@ -400,9 +406,11 @@ def diff_model_train(
400
406
)[local_rank ]
401
407
402
408
train_loader = prepare_data (
403
- train_files , device , args .diffusion_unet_train ["cache_rate" ],
409
+ train_files ,
410
+ device ,
411
+ args .diffusion_unet_train ["cache_rate" ],
404
412
batch_size = args .diffusion_unet_train ["batch_size" ],
405
- include_body_region = include_body_region
413
+ include_body_region = include_body_region ,
406
414
)
407
415
408
416
unet = load_unet (args , device , logger )
@@ -438,7 +446,7 @@ def diff_model_train(
438
446
logger ,
439
447
local_rank ,
440
448
amp = amp ,
441
- include_body_region = include_body_region
449
+ include_body_region = include_body_region ,
442
450
)
443
451
444
452
loss_torch = loss_torch .tolist ()
@@ -479,7 +487,14 @@ def diff_model_train(
479
487
)
480
488
parser .add_argument ("--num_gpus" , type = int , default = 1 , help = "Number of GPUs to use for training" )
481
489
parser .add_argument ("--no_amp" , dest = "amp" , action = "store_false" , help = "Disable automatic mixed precision training" )
482
- parser .add_argument ("--include_body_region" , dest = "include_body_region" , action = "store_true" , help = "Whether to include body region in data" )
490
+ parser .add_argument (
491
+ "--include_body_region" ,
492
+ dest = "include_body_region" ,
493
+ action = "store_true" ,
494
+ help = "Whether to include body region in data" ,
495
+ )
483
496
484
497
args = parser .parse_args ()
485
- diff_model_train (args .env_config , args .model_config , args .model_def , args .num_gpus , args .amp , args .include_body_region )
498
+ diff_model_train (
499
+ args .env_config , args .model_config , args .model_def , args .num_gpus , args .amp , args .include_body_region
500
+ )
0 commit comments