We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5d1f3a7 commit 5a023dcCopy full SHA for 5a023dc
1 file changed
src/ml.py
@@ -474,7 +474,7 @@ def load_checkpoint(self):
474
map_location={"cuda:0": "mps", "cuda": "mps"},
475
)
476
else:
477
- checkpoint = torch.load(self.checkpoint + "/model.pth")
+ checkpoint = torch.load(self.checkpoint + "/model.pth",map_location=f"cuda:{local_rank}")
478
embedding_weights = {
479
k: v
480
for k, v in checkpoint.items()
0 commit comments