Skip to content

Commit 5a023dc

Browse files
committed
distribute
1 parent 5d1f3a7 commit 5a023dc

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/ml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def load_checkpoint(self):
474474
map_location={"cuda:0": "mps", "cuda": "mps"},
475475
)
476476
else:
477-
checkpoint = torch.load(self.checkpoint + "/model.pth")
477+
checkpoint = torch.load(self.checkpoint + "/model.pth",map_location=f"cuda:{local_rank}")
478478
embedding_weights = {
479479
k: v
480480
for k, v in checkpoint.items()

0 commit comments

Comments
 (0)