We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c69eb47 commit a5f28f5Copy full SHA for a5f28f5
1 file changed
src/ml.py
@@ -488,7 +488,7 @@ def load_checkpoint(self):
488
self.model.eval()
489
490
491
- self.model = self.model.cuda(self.local_rank)
+ self.model = self.model.cuda(f"cuda:{self.local_rank}")
492
self.model = nn.parallel.DistributedDataParallel(self.model,
493
device_ids=[self.local_rank],
494
output_device=self.local_rank,
0 commit comments