manager_integ_tests: added multi rank recovery and sync tests (#40)
d4l3k authored Dec 14, 2024
1 parent a52d746 commit 58a436d
Showing 5 changed files with 225 additions and 53 deletions.
pyproject.toml: 3 changes (2 additions, 1 deletion)
@@ -19,7 +19,8 @@ dependencies = [
 dev = [
     "pytest",
     "black",
-    "pyre-check"
+    "pyre-check",
+    "parameterized"
 ]
 
 [tool.maturin]
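The new `parameterized` dev dependency is presumably what drives the multi-rank test matrix. As a hedged sketch (the test class, names, and parameters below are illustrative, not the actual tests from this commit), `parameterized.expand` generates one test case per tuple:

    # Hypothetical sketch of a parameterized integration test; illustrative only.
    import unittest

    from parameterized import parameterized


    class ManagerIntegTest(unittest.TestCase):
        @parameterized.expand(
            [
                ("one_rank", 1),
                ("two_ranks", 2),
            ]
        )
        def test_recovery(self, name: str, num_ranks: int) -> None:
            # expands into test_recovery_0_one_rank, test_recovery_1_two_ranks
            self.assertGreaterEqual(num_ranks, 1)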
torchft/ddp.py: 6 changes (6 additions, 0 deletions)
@@ -52,6 +52,12 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
         super().__init__(
             module,
             process_group=pg,
+            # HACK: This forces the reducer to never rebuild buckets.
+            # The reducer normally rebuilds the buckets after the first training
+            # step which can improve performance but is incompatible with
+            # torchft as it will cause the buckets to diverge for recovering
+            # replicas.
+            find_unused_parameters=True,
             # pyre-fixme[6]: got object
             **kwargs,
         )
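For context, `find_unused_parameters=True` is a stock `torch.nn.parallel.DistributedDataParallel` option, and disabling bucket rebuilds is a side effect of it: the reducer keeps its initial bucket assignment, so a replica that recovers mid-run builds the same buckets as replicas that never failed. A rough standalone sketch of the equivalent plain-DDP call (the process group argument is illustrative):

    # Illustrative plain-DDP equivalent; torchft passes its own managed
    # process group rather than the default world.
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP


    def wrap(module: nn.Module, pg: dist.ProcessGroup) -> DDP:
        return DDP(
            module,
            process_group=pg,
            # side effect: the reducer never rebuilds buckets, keeping the
            # bucket layout identical across original and recovered replicas
            find_unused_parameters=True,
        )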
torchft/manager.py: 70 changes (54 additions, 16 deletions)
@@ -45,10 +45,9 @@
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
 
-logger: logging.Logger = logging.getLogger(__name__)
-
 MANAGER_ADDR_KEY: str = "manager_addr"
 MANAGER_DEFAULT_PORT: int = int(os.environ.get("TORCHFT_MANAGER_PORT", 29511))
+REPLICA_ID_KEY: str = "replica_id"
 
 T = TypeVar("T")

@@ -132,7 +131,9 @@ def _manager_state_dict() -> Dict[str, T]:
         }
 
         self._ckpt_server = CheckpointServer[Dict[str, T]](_manager_state_dict)
-        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._executor = ThreadPoolExecutor(
+            max_workers=1, thread_name_prefix="async_quorum"
+        )
         self._quorum_future: Optional[concurrent.futures.Future] = None
 
         self._store = TCPStore(
@@ -163,10 +164,16 @@ def _manager_state_dict() -> Dict[str, T]:
         )
 
         self._store.set(MANAGER_ADDR_KEY, addr)
+        self._store.set(REPLICA_ID_KEY, replica_id)
 
         addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
         self._client = ManagerClient(addr, timeout=timeout)
 
+        replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8")
+        self._logger = _ManagerLogger(
+            manager=self, replica_id=replica_id or "", rank=rank
+        )
+
         self._step = 0
         self._quorum_id = -1
         self._errored: Optional[Exception] = None
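The surrounding store traffic is the standard TCPStore rendezvous pattern: one process publishes a value with `set` and every process reads it back with `get`, which returns bytes. A standalone sketch, with a hypothetical port:

    # Minimal TCPStore publish/read sketch (standalone; not torchft code).
    from datetime import timedelta

    from torch.distributed import TCPStore

    store = TCPStore("localhost", 29512, is_master=True, timeout=timedelta(seconds=10))
    store.set("manager_addr", "host:29511")           # publishing side
    addr = store.get("manager_addr").decode("utf-8")  # readers decode the bytes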
@@ -230,6 +237,7 @@ def callback(
             ) -> torch.Tensor:
                 nonlocal grad
 
+                # check for exceptions
                 fut.value()
 
                 grad /= self.num_participants()
@@ -241,7 +249,9 @@
             return fut
 
         except Exception as e:
-            logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
+            self._logger.exception(
+                f"got exception in all reduce -- skipping remaining: {e}"
+            )
             self.report_error(e)
 
             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
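The `fut.value()` call at the top of the callback is what surfaces a failed allreduce: inside a `then` continuation, `value()` re-raises the stored exception, which the `except` block above then routes to `report_error`. A self-contained sketch of that propagation:

    # Standalone sketch of exception propagation through torch.futures.
    import torch


    def callback(fut: torch.futures.Future) -> int:
        value = fut.value()  # re-raises if the future holds an exception
        return value + 1


    fut = torch.futures.Future()
    chained = fut.then(callback)
    fut.set_exception(RuntimeError("allreduce failed"))
    try:
        chained.wait()  # re-raises the RuntimeError from the callback
    except RuntimeError as e:
        print(f"caught: {e}")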
@@ -294,7 +304,9 @@ def callback(
         try:
             return fut.value()
         except Exception as e:
-            logger.exception(f"got exception in future -- skipping remaining: {e}")
+            self._logger.exception(
+                f"got exception in future -- skipping remaining: {e}"
+            )
             self.report_error(e)
             return default

@@ -328,12 +340,13 @@ def step(self) -> None:
         if not self._use_async_quorum:
             self._quorum_future.result()
 
-            # eagerly apply pending state_dict so we can run the forwards pass
-            self._apply_pending_state_dict()
+            if self._healing:
+                # eagerly apply pending state_dict so we can run the forwards pass
+                self._apply_pending_state_dict()
 
-            # we are forcing healing at the beginning so we're in a good state
-            # and don't need to zero_grad
-            self._healing = False
+                # we are forcing healing at the beginning so we're in a good state
+                # and don't need to zero_grad
+                self._healing = False
 
     def _async_quorum(self) -> None:
         (
@@ -374,21 +387,23 @@ def _async_quorum(self) -> None:
         self._participating_rank = None
 
         if quorum_id != self._quorum_id:
-            logger.info(f"{replica_rank=} reconfiguring for quorum_id {quorum_id}")
             store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
+
+            self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
             self._quorum_id = quorum_id
 
         # See manager.rs for healing conditions
         if heal:
             self._healing = True
-            logger.info(f"{replica_rank}= healing required")
-
-            logger.info(f"fetching checkpoint server address from {address}")
+            self._logger.info(
+                f"healing required, fetching checkpoint server address from {address=} {max_step=}"
+            )
             primary_client = ManagerClient(address, timeout=self._timeout)
             checkpoint_server_address = primary_client.checkpoint_address(self._rank)
 
+            self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
             self._pending_state_dict = CheckpointServer.load_from_address(
                 checkpoint_server_address
             )
@@ -406,8 +421,9 @@ def _apply_pending_state_dict(self) -> None:
         assert self._quorum_future is not None, "must call step before should_commit"
         self._quorum_future.result()
 
-        assert self._pending_state_dict is not None, "checkpoint was not staged"
+        self._logger.info("applying pending state dict")
+
+        assert self._pending_state_dict is not None, "checkpoint was not staged"
         self._load_state_dict(self._pending_state_dict["user"])
         self._pending_state_dict = None
@@ -450,7 +466,7 @@ def should_commit(self) -> bool:
         should_commit = self._client.should_commit(
             self._rank, self._step, local_should_commit
         )
-        logger.info(
+        self._logger.info(
             f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
         )
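Taken together, `step()` and `should_commit()` bracket each iteration: quorum (and any healing) is resolved up front, and the optimizer update is applied only when the replicas agree the step succeeded. A hypothetical loop under those assumptions (`manager`, `model`, `optimizer`, and `dataloader` are illustrative names, not API confirmed by this diff):

    # Hypothetical training-loop sketch based on the methods shown above;
    # the real torchft integration may wire these calls differently.
    for batch in dataloader:
        manager.step()               # resolve quorum; heal from a peer checkpoint if needed
        optimizer.zero_grad()
        loss = model(batch).sum()    # forward pass runs against the healed state
        loss.backward()              # grads are allreduced through the manager
        if manager.should_commit():  # commit only when no replica errored
            optimizer.step()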

@@ -534,3 +550,25 @@ def is_participating(self) -> bool:
             assert self._use_async_quorum
             return False
         return True
+
+
+class _ManagerLogger:
+    def __init__(self, manager: Manager, replica_id: str, rank: int) -> None:
+        self._logger: logging.Logger = logging.getLogger(__name__)
+        self._replica_id = replica_id
+        self._rank = rank
+        self._manager = manager
+
+    def prefix(self) -> str:
+        return (
+            f"[{self._replica_id}/{self._rank} - step {self._manager.current_step()}]"
+        )
+
+    def info(self, msg: str) -> None:
+        self._logger.info(f"{self.prefix()} {msg}")
+
+    def warn(self, msg: str) -> None:
+        self._logger.warn(f"{self.prefix()} {msg}")
+
+    def exception(self, msg: str) -> None:
+        self._logger.exception(f"{self.prefix()} {msg}")
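Every message is stamped with replica, rank, and current step, which keeps interleaved logs from multiple replicas attributable. With illustrative values:

    # Illustrative usage, assuming replica_id="replica_0", rank=1, and
    # a manager whose current_step() returns 5:
    log = _ManagerLogger(manager=manager, replica_id="replica_0", rank=1)
    log.info("healing required")
    # -> "[replica_0/1 - step 5] healing required"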