@@ -102,7 +102,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:
102
102
self .assertEqual (manager ._step , 0 )
103
103
self .assertEqual (manager .batches_committed (), 0 )
104
104
105
- manager .step ()
105
+ manager .start_step ()
106
106
manager .allreduce_grad (torch .tensor ([1.0 ])).wait ()
107
107
self .assertEqual (len (manager ._pending_work ), 1 )
108
108
self .assertTrue (manager .should_commit ())
@@ -113,7 +113,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:
113
113
# pyre-ignore[16]: _pg is mocked
114
114
self .assertEqual (manager ._pg .allreduce .call_count , 1 )
115
115
116
- manager .step ()
116
+ manager .start_step ()
117
117
self .assertEqual (manager .batches_committed (), 2 )
118
118
119
119
@patch ("torchft.manager.ManagerClient" , autospec = True )
@@ -140,7 +140,7 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
140
140
self .assertEqual (manager ._quorum_id , - 1 )
141
141
self .assertEqual (manager ._step , 0 )
142
142
143
- manager .step ()
143
+ manager .start_step ()
144
144
manager .allreduce_grad (torch .tensor ([1.0 ])).wait ()
145
145
self .assertFalse (manager ._healing )
146
146
self .assertTrue (manager .is_participating ())
@@ -182,7 +182,7 @@ def test_quorum_heal_async_not_enough_participants(
182
182
self .assertEqual (manager ._quorum_id , - 1 )
183
183
self .assertEqual (manager ._step , 0 )
184
184
185
- manager .step ()
185
+ manager .start_step ()
186
186
assert manager ._quorum_future is not None
187
187
manager ._quorum_future .result ()
188
188
self .assertTrue (manager ._healing )
@@ -206,7 +206,7 @@ def test_quorum_heal_async_not_enough_participants(
206
206
self .assertEqual (self .load_state_dict .call_count , 1 )
207
207
208
208
# failed to commit so no step
209
- manager .step ()
209
+ manager .start_step ()
210
210
self .assertEqual (manager ._step , 20 )
211
211
self .assertEqual (manager .batches_committed (), 0 )
212
212
@@ -234,7 +234,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
234
234
self .assertEqual (manager ._quorum_id , - 1 )
235
235
self .assertEqual (manager ._step , 0 )
236
236
237
- manager .step ()
237
+ manager .start_step ()
238
238
assert manager ._quorum_future is not None
239
239
manager ._quorum_future .result ()
240
240
self .assertTrue (manager ._healing )
@@ -256,7 +256,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
256
256
257
257
self .assertEqual (self .load_state_dict .call_count , 1 )
258
258
259
- manager .step ()
259
+ manager .start_step ()
260
260
self .assertEqual (manager ._step , 21 )
261
261
self .assertEqual (manager .batches_committed (), 1 )
262
262
@@ -280,7 +280,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
280
280
self .assertEqual (manager ._quorum_id , - 1 )
281
281
self .assertEqual (manager ._step , 0 )
282
282
283
- manager .step ()
283
+ manager .start_step ()
284
284
manager .allreduce_grad (torch .tensor ([1.0 ])).wait ()
285
285
# pyre-ignore[16]: _pg is mocked
286
286
self .assertEqual (manager ._pg .allreduce .call_count , 1 )
@@ -314,7 +314,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
314
314
2 , # max_world_size
315
315
False , # heal
316
316
)
317
- manager .step ()
317
+ manager .start_step ()
318
318
319
319
self .assertFalse (manager ._errored )
320
320
@@ -343,7 +343,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
343
343
False , # heal
344
344
)
345
345
346
- manager .step ()
346
+ manager .start_step ()
347
347
manager .allreduce_grad (torch .tensor ([1.0 ])).wait ()
348
348
self .assertTrue (manager .should_commit ())
349
349
@@ -375,13 +375,13 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
375
375
self .assertEqual (manager ._step , 0 )
376
376
self .assertEqual (manager .batches_committed (), 0 )
377
377
378
- manager .step ()
378
+ manager .start_step ()
379
379
manager .allreduce_grad (torch .tensor ([1.0 ])).wait ()
380
380
381
381
self .assertEqual (manager .is_participating (), rank != 2 )
382
382
self .assertEqual (manager .num_participants (), 2 )
383
383
384
- manager .step ()
384
+ manager .start_step ()
385
385
self .assertEqual (manager .batches_committed (), 2 )
386
386
387
387
@patch ("torchft.manager.ManagerClient" , autospec = True )
0 commit comments