File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -36,11 +36,6 @@ def create_trainer(config: dict) -> "UNetTrainer":
3636 device = config .get ("device" , None )
3737 assert device , "Device not specified in the config file and could not be inferred automatically"
3838 logger .info (f"Using device: { device } " )
39-
40- # use DataParallel if more than 1 GPU available
41- if device == TorchDevice .CUDA and torch .cuda .device_count () > 1 :
42- model = nn .DataParallel (model )
43- logger .info (f"Using { torch .cuda .device_count ()} GPUs for training" )
4439 model .to (device )
4540
4641 # Log the number of learnable parameters
@@ -204,6 +199,11 @@ def __init__(
204199 if not self .checkpoint_dir :
205200 self .checkpoint_dir = os .path .split (pre_trained )[0 ]
206201
202+ # use DataParallel if more than 1 GPU available
203+ if device == TorchDevice .CUDA and torch .cuda .device_count () > 1 :
204+ self .model = nn .DataParallel (self .model )
205+ logger .info (f"Using { torch .cuda .device_count ()} GPUs for training" )
206+
207207 def fit (self ):
208208 for _ in range (self .num_epochs , self .max_num_epochs ):
209209 # train for one epoch
You can’t perform that action at this time.
0 commit comments