Commit 58a436d

manager_integ_tests: added multi rank recovery and sync tests (#40)

1 parent a52d746

5 files changed: +225 −53 lines

pyproject.toml (+2 −1)

@@ -19,7 +19,8 @@ dependencies = [
 dev = [
     "pytest",
     "black",
-    "pyre-check"
+    "pyre-check",
+    "parameterized"
 ]

 [tool.maturin]
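
The new "parameterized" dev dependency backs the multi-rank recovery and sync
tests named in the commit title. A minimal sketch of the pattern (a
hypothetical test, not one of the files changed here), expanding one test body
into a case per rank count:

    # Hypothetical illustration of the parameterized pattern; the real tests
    # live in this commit's integration test files (not shown above).
    from unittest import TestCase

    from parameterized import parameterized


    class RecoveryTest(TestCase):
        @parameterized.expand(
            [
                ("one_rank", 1),
                ("two_ranks", 2),
            ]
        )
        def test_recovery(self, name: str, num_ranks: int) -> None:
            # expands into test_recovery_0_one_rank and test_recovery_1_two_ranks
            self.assertGreaterEqual(num_ranks, 1)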

torchft/ddp.py (+6)

@@ -52,6 +52,12 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
         super().__init__(
             module,
             process_group=pg,
+            # HACK: This forces the reducer to never rebuild buckets.
+            # The reducer normally rebuilds the buckets after the first training
+            # step which can improve performance but is incompatible with
+            # torchft as it will cause the buckets to diverge for recovering
+            # replicas.
+            find_unused_parameters=True,
             # pyre-fixme[6]: got object
             **kwargs,
         )
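
For context, this leans on stock DDP behavior: with find_unused_parameters=False
(the default) the reducer records the gradient order seen in the first backward
pass and rebuilds its buckets before the second step; True suppresses that
rebuild, so every replica keeps the initial, deterministic bucket layout even
if it joins mid-training. A runnable single-process sketch (gloo, world size 1,
not from this repo):

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = DDP(
        nn.Linear(4, 2),
        # suppresses the post-first-step bucket rebuild, as in the hunk above
        find_unused_parameters=True,
    )
    model(torch.randn(8, 4)).sum().backward()
    dist.destroy_process_group()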

torchft/manager.py (+54 −16)
@@ -45,10 +45,9 @@
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup

-logger: logging.Logger = logging.getLogger(__name__)
-
 MANAGER_ADDR_KEY: str = "manager_addr"
 MANAGER_DEFAULT_PORT: int = int(os.environ.get("TORCHFT_MANAGER_PORT", 29511))
+REPLICA_ID_KEY: str = "replica_id"

 T = TypeVar("T")
@@ -132,7 +131,9 @@ def _manager_state_dict() -> Dict[str, T]:
             }

         self._ckpt_server = CheckpointServer[Dict[str, T]](_manager_state_dict)
-        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._executor = ThreadPoolExecutor(
+            max_workers=1, thread_name_prefix="async_quorum"
+        )
         self._quorum_future: Optional[concurrent.futures.Future] = None

         self._store = TCPStore(
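
The thread_name_prefix here is for observability: the quorum worker now shows
up as async_quorum_0 in stack dumps and in log records that include
threadName. A standalone sketch:

    import threading
    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="async_quorum")
    # prints "async_quorum_0" -- the name visible in py-spy/faulthandler dumps
    print(executor.submit(lambda: threading.current_thread().name).result())
    executor.shutdown()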
@@ -163,10 +164,16 @@ def _manager_state_dict() -> Dict[str, T]:
             )

             self._store.set(MANAGER_ADDR_KEY, addr)
+            self._store.set(REPLICA_ID_KEY, replica_id)

         addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
         self._client = ManagerClient(addr, timeout=timeout)

+        replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8")
+        self._logger = _ManagerLogger(
+            manager=self, replica_id=replica_id or "", rank=rank
+        )
+
         self._step = 0
         self._quorum_id = -1
         self._errored: Optional[Exception] = None
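
The replica id rides the same TCPStore rendezvous already used for the manager
address: the hosting rank set()s it and every rank get()s it back before
constructing its _ManagerLogger. A self-contained sketch of that exchange
(single process, throwaway port):

    from torch.distributed import TCPStore

    # the hosting process creates the store as master and publishes the id
    store = TCPStore("localhost", 29600, world_size=1, is_master=True)
    store.set("replica_id", "replica_0")

    # any process attached to the same store reads it back; get() returns bytes
    replica_id = store.get("replica_id").decode("utf-8")
    print(replica_id)  # -> replica_0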
@@ -230,6 +237,7 @@ def callback(
             ) -> torch.Tensor:
                 nonlocal grad

+                # check for exceptions
                 fut.value()

                 grad /= self.num_participants()
@@ -241,7 +249,9 @@ def callback(
             return fut

         except Exception as e:
-            logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
+            self._logger.exception(
+                f"got exception in all reduce -- skipping remaining: {e}"
+            )
             self.report_error(e)

             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
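
These two hunks sit inside the manager's future-chaining allreduce path:
fut.value() inside a then() callback re-raises any upstream error (the new
"# check for exceptions" comment marks exactly that), and the except branch
turns it into report_error() plus a substitute future. A minimal standalone
sketch of that propagation:

    import torch

    def callback(fut: torch.futures.Future) -> int:
        # value() re-raises if the upstream future holds an exception
        return fut.value() + 1

    fut: torch.futures.Future = torch.futures.Future()
    chained = fut.then(callback)
    fut.set_exception(RuntimeError("allreduce failed"))

    try:
        chained.wait()
    except RuntimeError as e:
        print(f"propagated: {e}")  # -> propagated: allreduce failed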
@@ -294,7 +304,9 @@ def callback(
         try:
             return fut.value()
         except Exception as e:
-            logger.exception(f"got exception in future -- skipping remaining: {e}")
+            self._logger.exception(
+                f"got exception in future -- skipping remaining: {e}"
+            )
             self.report_error(e)
             return default
@@ -328,12 +340,13 @@ def step(self) -> None:
         if not self._use_async_quorum:
             self._quorum_future.result()

-            # eagerly apply pending state_dict so we can run the forwards pass
-            self._apply_pending_state_dict()
+            if self._healing:
+                # eagerly apply pending state_dict so we can run the forwards pass
+                self._apply_pending_state_dict()

-            # we are forcing healing at the beginning so we're in a good state
-            # and don't need to zero_grad
-            self._healing = False
+                # we are forcing healing at the beginning so we're in a good state
+                # and don't need to zero_grad
+                self._healing = False

     def _async_quorum(self) -> None:
         (
@@ -374,21 +387,23 @@ def _async_quorum(self) -> None:
             self._participating_rank = None

         if quorum_id != self._quorum_id:
-            logger.info(f"{replica_rank=} reconfiguring for quorum_id {quorum_id}")
             store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
+
+            self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
             self._quorum_id = quorum_id

         # See manager.rs for healing conditions
         if heal:
             self._healing = True
-            logger.info(f"{replica_rank}= healing required")
-
-            logger.info(f"fetching checkpoint server address from {address}")
+            self._logger.info(
+                f"healing required, fetching checkpoint server address from {address=} {max_step=}"
+            )
             primary_client = ManagerClient(address, timeout=self._timeout)
             checkpoint_server_address = primary_client.checkpoint_address(self._rank)

+            self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
             self._pending_state_dict = CheckpointServer.load_from_address(
                 checkpoint_server_address
             )
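
The healing path is: mark the replica as healing, ask the quorum primary for
its checkpoint server address, then pull the staged state dict over the
network. CheckpointServer's wire format isn't shown in this diff; a generic,
hypothetical sketch of the load_from_address idea (not torchft's actual
implementation):

    from io import BytesIO
    from urllib.request import urlopen

    import torch

    def load_from_address(address: str) -> object:
        # hypothetical: fetch a serialized state_dict from a peer's checkpoint
        # server and deserialize it locally (torch.load unpickles, so the
        # address must be trusted)
        with urlopen(address, timeout=60.0) as response:
            data = response.read()
        return torch.load(BytesIO(data))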
@@ -406,8 +421,9 @@ def _apply_pending_state_dict(self) -> None:
         assert self._quorum_future is not None, "must call step before should_commit"
         self._quorum_future.result()

-        assert self._pending_state_dict is not None, "checkpoint was not staged"
+        self._logger.info("applying pending state dict")

+        assert self._pending_state_dict is not None, "checkpoint was not staged"
         self._load_state_dict(self._pending_state_dict["user"])
         self._pending_state_dict = None
@@ -450,7 +466,7 @@ def should_commit(self) -> bool:
         should_commit = self._client.should_commit(
             self._rank, self._step, local_should_commit
         )
-        logger.info(
+        self._logger.info(
             f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
         )
@@ -534,3 +550,25 @@ def is_participating(self) -> bool:
             assert self._use_async_quorum
             return False
         return True
+
+
+class _ManagerLogger:
+    def __init__(self, manager: Manager, replica_id: str, rank: int) -> None:
+        self._logger: logging.Logger = logging.getLogger(__name__)
+        self._replica_id = replica_id
+        self._rank = rank
+        self._manager = manager
+
+    def prefix(self) -> str:
+        return (
+            f"[{self._replica_id}/{self._rank} - step {self._manager.current_step()}]"
+        )
+
+    def info(self, msg: str) -> None:
+        self._logger.info(f"{self.prefix()} {msg}")
+
+    def warn(self, msg: str) -> None:
+        self._logger.warn(f"{self.prefix()} {msg}")
+
+    def exception(self, msg: str) -> None:
+        self._logger.exception(f"{self.prefix()} {msg}")
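
With this class, every log line from a process is tagged
[<replica_id>/<rank> - step <N>], which is what makes interleaved
multi-replica test logs readable. A quick usage sketch (a duck-typed stub
stands in for Manager, since the logger only calls current_step(); assumes
_ManagerLogger is importable from torchft.manager):

    import logging

    from torchft.manager import _ManagerLogger

    logging.basicConfig(level=logging.INFO)

    class _StubManager:
        # duck-typed stand-in for Manager; only current_step() is used
        def current_step(self) -> int:
            return 7

    log = _ManagerLogger(manager=_StubManager(), replica_id="replica_0", rank=1)
    log.info("should_commit=True")
    # INFO:torchft.manager:[replica_0/1 - step 7] should_commit=True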
