
Commit a484e4f

manager: rename start_step to start_quorum and move step changes to should_commit (#46)

1 parent 78c5721 commit a484e4f
File tree

5 files changed: +164 -47 lines
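
The user-visible effect of the rename is easiest to see in a training loop. A minimal sketch, assuming the torchft Manager API shown in the diffs below; the model, optimizer, and manager construction are stand-ins for illustration, not code from this commit:

import torch
import torch.nn as nn
from torchft.manager import Manager  # import path assumed

manager: Manager = ...  # construction omitted; requires a lighthouse + process group
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(10):
    manager.start_quorum()  # was start_step(); no longer bumps self._step

    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    for p in model.parameters():  # fault-tolerant gradient averaging
        if p.grad is not None:
            manager.allreduce_grad(p.grad)

    if manager.should_commit():  # step/batch counters now advance here
        optimizer.step()
    optimizer.zero_grad()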

src/manager.rs (+68 -2)
@@ -239,7 +239,8 @@ impl ManagerService for Arc<Manager> {
             .await
             .map_err(|e| Status::internal(e.to_string()))?;
 
-        let participants = &quorum.participants;
+        let mut participants = quorum.participants.clone();
+        participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
 
         let mut replica_rank = 10000000000;
         for (i, p) in participants.iter().enumerate() {
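
Sorting the participants by replica_id means every manager derives the same replica_rank for a given quorum, independent of the order in which replicas joined. Roughly the same ranking logic in Python (field names mirror the Rust structs; the data is made up):

participants = [{"replica_id": "rep_1"}, {"replica_id": "rep_0"}]
participants.sort(key=lambda p: p["replica_id"])

# each manager finds its rank by position in the sorted list
my_replica_id = "rep_1"
replica_rank = next(
    i for i, p in enumerate(participants) if p["replica_id"] == my_replica_id
)
assert replica_rank == 1  # rep_0 -> 0, rep_1 -> 1, regardless of join order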
@@ -266,7 +267,7 @@ impl ManagerService for Arc<Manager> {
         // Decide whether we should be healing:
         // 1. if we're not at the max step
         // 2. if everyone is at the first step and we're not the primary
-        let heal = max_step != req.step || max_step == 1 && primary.replica_id != self.replica_id;
+        let heal = max_step != req.step || max_step == 0 && primary.replica_id != self.replica_id;
         if heal {
             info!(
                 "healing is required step={}, max_step={}",
@@ -475,4 +476,69 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_get_quorum_heal_first_step() -> Result<()> {
+        let lighthouse = Lighthouse::new(LighthouseOpt {
+            bind: "[::]:0".to_string(),
+            join_timeout_ms: 100,
+            min_replicas: 2,
+            quorum_tick_ms: 100,
+        })
+        .await?;
+        let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
+
+        let mut manager_futs: Vec<tokio::task::JoinHandle<Result<ManagerQuorumResponse>>> =
+            Vec::new();
+
+        for replica_id in 0..2 {
+            let lighthouse_addr = lighthouse.address();
+            manager_futs.push(tokio::spawn(async move {
+                let manager = Manager::new(
+                    format!("rep_{}", replica_id),
+                    lighthouse_addr,
+                    "addr".to_string(),
+                    "[::]:0".to_string(),
+                    "store_addr".to_string(),
+                    1, // world size
+                )
+                .await?;
+                let manager_fut = tokio::spawn(manager.clone().run());
+
+                let mut client =
+                    manager_client_new(manager.address(), Duration::from_secs(10)).await?;
+
+                let request = tonic::Request::new(ManagerQuorumRequest {
+                    rank: 0,
+                    step: 0,
+                    checkpoint_server_addr: "addr".to_string(),
+                });
+
+                let result = client.quorum(request).await?.into_inner();
+
+                manager_fut.abort();
+
+                Ok(result)
+            }));
+        }
+
+        let resp_a = manager_futs.swap_remove(0).await??;
+        let resp_b = manager_futs.swap_remove(0).await??;
+
+        lighthouse_fut.abort();
+
+        assert_eq!(resp_a.quorum_id, 1);
+        assert_eq!(resp_a.max_step, 0);
+        assert_eq!(resp_a.replica_rank, 0);
+        assert_eq!(resp_a.replica_world_size, 2);
+        assert_eq!(resp_a.heal, false);
+
+        assert_eq!(resp_b.quorum_id, 1);
+        assert_eq!(resp_b.max_step, 0);
+        assert_eq!(resp_b.replica_rank, 1);
+        assert_eq!(resp_b.replica_world_size, 2);
+        assert_eq!(resp_b.heal, true);
+
+        Ok(())
+    }
 }

torchft/manager.py (+32 -14)
@@ -182,7 +182,6 @@ def _manager_state_dict() -> Dict[str, T]:
         self._batches_committed = 0
 
         # first step is 1
-        self._should_step = True
         self._participating_rank: Optional[int] = None
         self._participating_world_size: int = 0
 
@@ -218,8 +217,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
             fut.set_result(grad)
             return fut
 
-        assert self._quorum_future is not None, "must call step before allreduce_grad"
-        self._quorum_future.result()
+        self.wait_quorum()
 
         if not self.is_participating():
             grad.zero_()
@@ -315,21 +313,28 @@ def callback(
         self._pending_work.append(cast(torch.futures.Future[object], fut))
         return fut
 
-    def start_step(self) -> None:
+    def start_quorum(self, allow_heal: bool = True) -> None:
         """
         .. note::
             We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
 
        Computes a new quorum (potentially asynchronously) and readies the
        manager for a new step.
 
-        Must be called before the forwards pass of each step for best
+        It's best practice to call this before the forwards pass of each step for
        performance as computing quorum may take some time.
+
+        If allow_heal is set, the manager will attempt to heal either
+        synchronously before returning or asynchronously prior to any network
+        calls.
+
+        Args:
+            allow_heal: whether to allow healing at the beginning of the step
        """
 
-        if self._should_step:
-            self._step += 1
-            self._batches_committed += self.num_participants()
+        # wait for previous quorum to complete
+        if self._quorum_future is not None:
+            self._quorum_future.result()
 
        self._errored = None
        self._healing = False
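
A short sketch of the new allow_heal escape hatch, reusing the manager stand-in from the loop sketch above; the comments describe the behavior visible in the hunks below:

# a step that must not heal, e.g. when the caller knows replica state is
# already consistent
manager.start_quorum(allow_heal=False)
# with allow_heal=False, _async_quorum keeps (max_rank, max_world_size) as
# the participation numbers and skips the "if heal and allow_heal" branch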
@@ -338,9 +343,9 @@ def start_step(self) -> None:
         # TODO: we should really be wrapping this whole section in a try-except
         # block to allow gracefully recovering from issues in PG setup and quorum.
 
-        self._quorum_future = self._executor.submit(self._async_quorum)
+        self._quorum_future = self._executor.submit(self._async_quorum, allow_heal)
         if not self._use_async_quorum:
-            self._quorum_future.result()
+            self.wait_quorum()
 
         if self._healing:
             # eagerly apply pending state_dict so we can run the forwards pass
@@ -350,7 +355,18 @@
             # and don't need to zero_grad
             self._healing = False
 
-    def _async_quorum(self) -> None:
+    def wait_quorum(self) -> None:
+        """
+        Wait for the quorum to complete.
+
+        ProcessGroup will be in a healthy state after this returns.
+        """
+        assert (
+            self._quorum_future is not None
+        ), "must call start_quorum before wait_quorum"
+        self._quorum_future.result()
+
+    def _async_quorum(self, allow_heal: bool) -> None:
         (
             quorum_id,
             replica_rank,
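
With async quorum enabled, start_quorum() returns immediately and wait_quorum() is the explicit synchronization point; allreduce_grad() now calls it implicitly (see the allreduce_grad hunk above). A hypothetical overlap of quorum computation with data loading, where load_next_batch is an assumed helper:

manager.start_quorum()
batch = load_next_batch()  # overlap: load data while the quorum resolves
manager.wait_quorum()      # ProcessGroup is healthy after this returns
out = model(batch)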
@@ -372,7 +388,7 @@ def _async_quorum(self) -> None:
         # workers will be healthy.
         self._participating_rank, self._participating_world_size = (
             (max_rank, max_world_size)
-            if self._use_async_quorum
+            if self._use_async_quorum or not allow_heal
             else (replica_rank, replica_world_size)
         )
 
@@ -397,7 +413,7 @@ def _async_quorum(self) -> None:
             self._quorum_id = quorum_id
 
         # See manager.rs for healing conditions
-        if heal:
+        if heal and allow_heal:
             self._healing = True
             self._logger.info(
                 f"healing required, fetching checkpoint server address from {address=} {max_step=}"
@@ -475,7 +491,9 @@ def should_commit(self) -> bool:
         self._ckpt_server.disallow_checkpoint()
 
         # decide whether we're in a healthy state to increase the step count
-        self._should_step = should_commit
+        if should_commit:
+            self._step += 1
+            self._batches_committed += self.num_participants()
 
         return should_commit
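
Because the counters moved here, a step that fails to commit is retried under the same step number after the next quorum rather than being silently counted. A sketch of the observable semantics, assuming a current_step() accessor that reads self._step:

before = manager.current_step()
if manager.should_commit():
    optimizer.step()  # parameters are consistent across the quorum
    assert manager.current_step() == before + 1
else:
    optimizer.zero_grad()  # discard the bad step; counters unchanged
    assert manager.current_step() == before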
