Skip to content

Commit c7a7f1e

Browse files
authored
Specifiy the devices when registering the backend to avoid warnings (#16)
pytest process_group_test.py
1 parent 829d26c commit c7a7f1e

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torchft/process_group.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,11 @@ def create_pg(
124124
) -> ProcessGroup:
125125
return self
126126

127-
dist.Backend.register_backend(group_name, create_pg)
127+
if torch.cuda.is_available():
128+
devices = ["cuda", "cpu"]
129+
else:
130+
devices = ["cpu"]
131+
dist.Backend.register_backend(group_name, create_pg, devices=devices)
128132

129133
return dist.new_group(
130134
ranks=[dist.get_rank()],

0 commit comments

Comments
 (0)