@@ -208,6 +208,37 @@ def test_baby_gloo_timeout(self) -> None:
208
208
with self .assertRaisesRegex (TimeoutError , "timed out after 0.01 seconds" ):
209
209
a .configure (store_addr , 0 , 2 )
210
210
211
def test_reconfigure_baby_process_group(self) -> None:
    """Check that a second ``configure`` call replaces the worker state.

    The process group is configured twice against different store
    prefixes. The future thread, future queue and process captured after
    the first call must be distinct from those captured after the second
    call, the first set must be stopped/closed, and the second set must
    be running/open.
    """
    store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )
    store_addr = f"localhost:{store.port}/prefix"

    pg = ProcessGroupBabyGloo()
    pg.configure(store_addr, 0, 1)
    # Snapshot the internals created by the first configure() so they can
    # still be inspected after the group has been reconfigured.
    future_thread_1 = pg._future_thread
    future_queue_1 = pg._future_queue
    p_1 = pg._p

    # Reconfigure under a different prefix on the same store.
    store_addr = f"localhost:{store.port}/prefix2"
    pg.configure(store_addr, 0, 1)
    future_thread_2 = pg._future_thread
    future_queue_2 = pg._future_queue
    p_2 = pg._p

    # Reconfiguring must create new objects, so compare by identity
    # rather than by equality.
    self.assertIsNot(future_thread_1, future_thread_2)
    self.assertIsNot(future_queue_1, future_queue_2)
    self.assertIsNot(p_1, p_2)

    # Everything from the first configuration must be torn down.
    self.assertFalse(future_thread_1.is_alive())
    # NOTE(review): `_closed` is a private attribute of the queue object;
    # kept because the original test relies on it.
    self.assertTrue(future_queue_1._closed)
    self.assertFalse(p_1.is_alive())

    # Everything from the second configuration must still be live.
    self.assertTrue(future_thread_2.is_alive())
    self.assertFalse(future_queue_2._closed)
    self.assertTrue(p_2.is_alive())
240
+
241
+
211
242
def test_dummy (self ) -> None :
212
243
pg = ProcessGroupDummy (0 , 1 )
213
244
m = nn .Linear (3 , 4 )
0 commit comments