Skip to content

Commit 78c5721

Browse files
authored
manager: rename step to start_step + small shutdown fix (#44)
1 parent 8a22dc8 commit 78c5721

File tree

4 files changed

+19
-17
lines changed

4 files changed

+19
-17
lines changed

torchft/manager.py

+5 -3
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,7 @@ def shutdown(self) -> None:
193193
self._ckpt_server.shutdown()
194194
if self._manager is not None:
195195
self._manager.shutdown()
196+
self._executor.shutdown()
196197

197198
def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
198199
"""
@@ -314,15 +315,16 @@ def callback(
314315
self._pending_work.append(cast(torch.futures.Future[object], fut))
315316
return fut
316317

317-
def step(self) -> None:
318+
def start_step(self) -> None:
318319
"""
319320
.. note::
320321
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
321322
322-
Must be called before the forwards pass of each step.
323-
324323
Computes a new quorum (potentially asynchronously) and readies the
325324
manager for a new step.
325+
326+
Must be called before the forwards pass of each step for best
327+
performance as computing quorum may take some time.
326328
"""
327329

328330
if self._should_step:

torchft/manager_test.py

+12 -12
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:
102102
self.assertEqual(manager._step, 0)
103103
self.assertEqual(manager.batches_committed(), 0)
104104

105-
manager.step()
105+
manager.start_step()
106106
manager.allreduce_grad(torch.tensor([1.0])).wait()
107107
self.assertEqual(len(manager._pending_work), 1)
108108
self.assertTrue(manager.should_commit())
@@ -113,7 +113,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:
113113
# pyre-ignore[16]: _pg is mocked
114114
self.assertEqual(manager._pg.allreduce.call_count, 1)
115115

116-
manager.step()
116+
manager.start_step()
117117
self.assertEqual(manager.batches_committed(), 2)
118118

119119
@patch("torchft.manager.ManagerClient", autospec=True)
@@ -140,7 +140,7 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
140140
self.assertEqual(manager._quorum_id, -1)
141141
self.assertEqual(manager._step, 0)
142142

143-
manager.step()
143+
manager.start_step()
144144
manager.allreduce_grad(torch.tensor([1.0])).wait()
145145
self.assertFalse(manager._healing)
146146
self.assertTrue(manager.is_participating())
@@ -182,7 +182,7 @@ def test_quorum_heal_async_not_enough_participants(
182182
self.assertEqual(manager._quorum_id, -1)
183183
self.assertEqual(manager._step, 0)
184184

185-
manager.step()
185+
manager.start_step()
186186
assert manager._quorum_future is not None
187187
manager._quorum_future.result()
188188
self.assertTrue(manager._healing)
@@ -206,7 +206,7 @@ def test_quorum_heal_async_not_enough_participants(
206206
self.assertEqual(self.load_state_dict.call_count, 1)
207207

208208
# failed to commit so no step
209-
manager.step()
209+
manager.start_step()
210210
self.assertEqual(manager._step, 20)
211211
self.assertEqual(manager.batches_committed(), 0)
212212

@@ -234,7 +234,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
234234
self.assertEqual(manager._quorum_id, -1)
235235
self.assertEqual(manager._step, 0)
236236

237-
manager.step()
237+
manager.start_step()
238238
assert manager._quorum_future is not None
239239
manager._quorum_future.result()
240240
self.assertTrue(manager._healing)
@@ -256,7 +256,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
256256

257257
self.assertEqual(self.load_state_dict.call_count, 1)
258258

259-
manager.step()
259+
manager.start_step()
260260
self.assertEqual(manager._step, 21)
261261
self.assertEqual(manager.batches_committed(), 1)
262262

@@ -280,7 +280,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
280280
self.assertEqual(manager._quorum_id, -1)
281281
self.assertEqual(manager._step, 0)
282282

283-
manager.step()
283+
manager.start_step()
284284
manager.allreduce_grad(torch.tensor([1.0])).wait()
285285
# pyre-ignore[16]: _pg is mocked
286286
self.assertEqual(manager._pg.allreduce.call_count, 1)
@@ -314,7 +314,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
314314
2, # max_world_size
315315
False, # heal
316316
)
317-
manager.step()
317+
manager.start_step()
318318

319319
self.assertFalse(manager._errored)
320320

@@ -343,7 +343,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
343343
False, # heal
344344
)
345345

346-
manager.step()
346+
manager.start_step()
347347
manager.allreduce_grad(torch.tensor([1.0])).wait()
348348
self.assertTrue(manager.should_commit())
349349

@@ -375,13 +375,13 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
375375
self.assertEqual(manager._step, 0)
376376
self.assertEqual(manager.batches_committed(), 0)
377377

378-
manager.step()
378+
manager.start_step()
379379
manager.allreduce_grad(torch.tensor([1.0])).wait()
380380

381381
self.assertEqual(manager.is_participating(), rank != 2)
382382
self.assertEqual(manager.num_participants(), 2)
383383

384-
manager.step()
384+
manager.start_step()
385385
self.assertEqual(manager.batches_committed(), 2)
386386

387387
@patch("torchft.manager.ManagerClient", autospec=True)

torchft/optim.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ def state_dict(self) -> object:
4545
return self.optim.state_dict()
4646

4747
def zero_grad(self, set_to_none: bool = True) -> None:
48-
self.manager.step()
48+
self.manager.start_step()
4949
self.optim.zero_grad(set_to_none)
5050

5151
def step(self, closure: Optional[object] = None) -> None:

torchft/optim_test.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def test_optimizer_wrapper(self) -> None:
3232
optim.load_state_dict(optim.state_dict())
3333

3434
optim.zero_grad()
35-
self.assertEqual(manager.step.call_count, 1)
35+
self.assertEqual(manager.start_step.call_count, 1)
3636

3737
manager.should_commit.return_value = True
3838
optim.step()

0 commit comments

Comments
 (0)