Skip to content

Commit 8cb3662

Browse files
author
yexin
committed
Fix bug when CUDA_VISIBLE_DEVICES is set
1 parent 5379ead commit 8cb3662

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

checkpoint_engine/distributed/hccl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def init_process_group(
205205
self.port = port
206206
self.rank = rank
207207
self.world_size = world_size
208-
self.device = torch.device("npu", rank)
208+
self.device = torch.device("npu", torch.npu.current_device())
209209

210210
self.pg = StatelessProcessGroup.create(
211211
host, port, rank, world_size, store_timeout=int(timeout.total_seconds())

checkpoint_engine/distributed/nccl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def init_process_group(
122122
self.port = port
123123
self.rank = rank
124124
self.world_size = world_size
125-
self.device = torch.device("cuda", rank)
125+
self.device = torch.device("cuda", torch.cuda.current_device())
126126

127127
self.pg = StatelessProcessGroup.create(
128128
host, port, rank, world_size, store_timeout=int(timeout.total_seconds())

0 commit comments

Comments
 (0)