Skip to content

Commit 9e3bbd9

Browse files
Devmate Botfacebook-github-bot
authored andcommitted
Fix for T247806505 ("Your diff, D88512803, broke one test")
Differential Revision: D88632010
1 parent ee51839 commit 9e3bbd9

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

torchft/process_group.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
AllToAllOptions,
5959
BarrierOptions,
6060
BroadcastOptions,
61+
GroupName,
6162
ReduceOp,
6263
ReduceScatterOptions,
6364
Work,
@@ -133,7 +134,7 @@ def __init__(self, *args: object, **kwargs: object) -> None:
133134
# pyre-fixme[6]: got object
134135
super().__init__(*args, **kwargs)
135136

136-
self._group_name: Optional[str] = None
137+
self._group_name: Optional[GroupName] = None
137138

138139
# pyre-fixme[14]: inconsistent override
139140
def allgather(
@@ -313,7 +314,7 @@ def size(self) -> int:
313314
def getBackendName(self) -> str:
314315
raise NotImplementedError("not implemented")
315316

316-
def _register(self, name: str) -> str:
317+
def _register(self, name: str) -> GroupName:
317318
group_name = f"{self.getBackendName()}:{name}"
318319

319320
# This is needed for DeviceMesh and functional collectives to work.
@@ -332,7 +333,7 @@ def create_pg(
332333
devices.append("xpu")
333334
dist.Backend.register_backend(group_name, create_pg, devices=devices)
334335

335-
return group_name
336+
return GroupName(group_name)
336337

337338
def register(self, name: str) -> "ProcessGroup":
338339
"""
@@ -355,12 +356,12 @@ def register(self, name: str) -> "ProcessGroup":
355356
)
356357

357358
@property
358-
def group_name(self) -> str:
359+
def group_name(self) -> GroupName:
359360
if self._group_name is None:
360361
raise ValueError("ProcessGroup name not set")
361362
return self._group_name
362363

363-
def _set_group_name(self, name: str) -> None:
364+
def _set_group_name(self, name: GroupName) -> None:
364365
self._group_name = name
365366

366367
def unregister(self) -> None:

0 commit comments

Comments
 (0)