Skip to content

Commit 231b80f

Browse files
authored
Merge pull request #134 from JCasaraconn/fix-resume-with-dataparallel
fix resume from checkpoint when using multiple GPUs
2 parents 71bf82b + a3c5197 commit 231b80f

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pytorch3dunet/unet3d/trainer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,6 @@ def create_trainer(config: dict) -> "UNetTrainer":
3636
device = config.get("device", None)
3737
assert device, "Device not specified in the config file and could not be inferred automatically"
3838
logger.info(f"Using device: {device}")
39-
40-
# use DataParallel if more than 1 GPU available
41-
if device == TorchDevice.CUDA and torch.cuda.device_count() > 1:
42-
model = nn.DataParallel(model)
43-
logger.info(f"Using {torch.cuda.device_count()} GPUs for training")
4439
model.to(device)
4540

4641
# Log the number of learnable parameters
@@ -204,6 +199,11 @@ def __init__(
204199
if not self.checkpoint_dir:
205200
self.checkpoint_dir = os.path.split(pre_trained)[0]
206201

202+
# use DataParallel if more than 1 GPU available
203+
if device == TorchDevice.CUDA and torch.cuda.device_count() > 1:
204+
self.model = nn.DataParallel(self.model)
205+
logger.info(f"Using {torch.cuda.device_count()} GPUs for training")
206+
207207
def fit(self):
208208
for _ in range(self.num_epochs, self.max_num_epochs):
209209
# train for one epoch

0 commit comments

Comments (0)