5858 AllToAllOptions ,
5959 BarrierOptions ,
6060 BroadcastOptions ,
61+ GroupName ,
6162 ReduceOp ,
6263 ReduceScatterOptions ,
6364 Work ,
@@ -133,7 +134,7 @@ def __init__(self, *args: object, **kwargs: object) -> None:
133134 # pyre-fixme[6]: got object
134135 super ().__init__ (* args , ** kwargs )
135136
136- self ._group_name : Optional [str ] = None
137+ self ._group_name : Optional [GroupName ] = None
137138
138139 # pyre-fixme[14]: inconsistent override
139140 def allgather (
@@ -313,7 +314,7 @@ def size(self) -> int:
313314 def getBackendName (self ) -> str :
314315 raise NotImplementedError ("not implemented" )
315316
316- def _register (self , name : str ) -> str :
317+ def _register (self , name : str ) -> GroupName :
317318 group_name = f"{ self .getBackendName ()} :{ name } "
318319
319320 # This is needed for DeviceMesh and functional collectives to work.
@@ -332,7 +333,7 @@ def create_pg(
332333 devices .append ("xpu" )
333334 dist .Backend .register_backend (group_name , create_pg , devices = devices )
334335
335- return group_name
336+ return GroupName ( group_name )
336337
337338 def register (self , name : str ) -> "ProcessGroup" :
338339 """
@@ -355,12 +356,12 @@ def register(self, name: str) -> "ProcessGroup":
355356 )
356357
357358 @property
358- def group_name (self ) -> str :
359+ def group_name (self ) -> GroupName :
359360 if self ._group_name is None :
360361 raise ValueError ("ProcessGroup name not set" )
361362 return self ._group_name
362363
363- def _set_group_name (self , name : str ) -> None :
364+ def _set_group_name (self , name : GroupName ) -> None :
364365 self ._group_name = name
365366
366367 def unregister (self ) -> None :
0 commit comments