@@ -102,7 +102,7 @@ def size(self) -> int:
102
102
def getBackendName (self ) -> str :
103
103
raise NotImplementedError ("not implemented" )
104
104
105
- def register (self , name : str ) -> None :
105
+ def register (self , name : str ) -> BaseProcessGroup :
106
106
"""
107
107
Registers the process group with the global registry. This enables usage
108
108
with things like functional_collectives which are compilable.
@@ -113,32 +113,42 @@ def register(self, name: str) -> None:
113
113
name: name must be a unique name for this process group
114
114
"""
115
115
116
- self ._group_name = f"{ self .getBackendName ()} :{ name } "
117
- _register_process_group (self .group_name , self )
116
+ group_name = f"{ self .getBackendName ()} :{ name } "
118
117
119
- # This is needed for DeviceMesh to work
118
+ # This is needed for DeviceMesh and functional collectives to work.
120
119
# Resizable worlds don't fit well into DeviceMesh so we register a world
121
120
# size 1 PG.
122
- _world .pg_map [self ] = (None , None )
123
- _world .pg_names [self ] = self ._group_name
124
- _world .pg_to_tag [self ] = self ._group_name
125
- _world .tags_to_pg .setdefault (self ._group_name , []).append (self )
126
- # these PGs can be resized so we lie about the rank mapping
127
- _world .pg_group_ranks [self ] = {get_rank (): 0 }
121
+
122
+ def create_pg (
123
+ prefix_store : PrefixStore , rank : int , world_size : int , timeout : float
124
+ ) -> ProcessGroup :
125
+ return self
126
+
127
+ dist .Backend .register_backend (group_name , create_pg )
128
+
129
+ return dist .new_group (
130
+ ranks = [dist .get_rank ()],
131
+ backend = group_name ,
132
+ group_desc = group_name ,
133
+ timeout = timedelta (seconds = 60.0 ), # this timeout isn't used
134
+ )
128
135
129
136
@property
130
137
def group_name (self ) -> str :
131
138
if self ._group_name is None :
132
139
raise ValueError ("ProcessGroup name not set" )
133
140
return self ._group_name
134
141
142
+ def _set_group_name (self , name : str ) -> None :
143
+ self ._group_name = name
144
+
135
145
def unregister (self ) -> None :
136
146
"""
137
147
Unregisters the process group with the global registry.
138
148
139
149
Must be registered first.
140
150
"""
141
- _unregister_process_group (self . group_name )
151
+ dist . destroy_process_group (self )
142
152
143
153
144
154
class ProcessGroupWrapper (ProcessGroup ):
0 commit comments