Skip to content

Commit

Permalink
Change how TorchFT manages user_state_dict (#87)
Browse files Browse the repository at this point in the history
* Change how TorchFT manages user_state_dict

This PR closes some state_dict gaps when integrating with TorchTitan:
1. User state_dict() and load_state_dict() functions can be initialized lazily.
2. Change weights_only to False for torch.load as we may have to load some non-tensor states.
  • Loading branch information
fegin authored Jan 30, 2025
1 parent fa1630d commit 6e4ae38
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 6 deletions.
4 changes: 3 additions & 1 deletion torchft/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
data = f.read()

reader = io.BytesIO(data)
return torch.load(reader, weights_only=True)
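# We have to set weights_only to False as there are some non-tensor
# states like lr_scheduler.
# NOTE: weights_only=False makes torch.load unpickle arbitrary objects,
# which can execute code; the address must point to a trusted peer.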
return torch.load(reader, weights_only=False)

def address(self) -> str:
"""
Expand Down
19 changes: 14 additions & 5 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ class Manager:
def __init__(
self,
pg: "ProcessGroup",
load_state_dict: Callable[[T], None],
state_dict: Callable[[], T],
load_state_dict: Optional[Callable[[T], None]],
state_dict: Optional[Callable[[], T]],
min_replica_size: int,
use_async_quorum: bool = True,
timeout: timedelta = timedelta(seconds=60),
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(
transferring checkpoints to recovering replicas
"""
self._load_state_dict = load_state_dict
self._state_dict = state_dict
self._user_state_dict = state_dict
self._pending_state_dict: Optional[Dict[str, object]] = None
self._use_async_quorum = use_async_quorum
self._timeout = timeout
Expand All @@ -159,8 +159,6 @@ def __init__(
world_size = world_size or int(os.environ["WORLD_SIZE"])
self._min_replica_size = min_replica_size

self._user_state_dict = state_dict

if checkpoint_transport is None:
checkpoint_transport = CheckpointServer[Dict[str, T]](
timeout=timeout,
Expand Down Expand Up @@ -226,6 +224,12 @@ def __init__(
self._participating_rank: Optional[int] = None
self._participating_world_size: int = 0

def set_state_dict_fns(
    self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
) -> None:
    """
    Register the user state dict callbacks after construction.

    Replaces the ``load_state_dict`` and ``state_dict`` callables stored on
    the manager, so they can be supplied lazily when ``None`` was passed to
    ``__init__``.

    Args:
        load_state_dict: function that loads a user state dict
        state_dict: function that returns the current user state dict

    Raises:
        TypeError: if either argument is not callable
    """
    # Validate both callbacks before storing either one, so a bad call
    # fails here and leaves the previously registered pair untouched.
    for name, fn in (
        ("load_state_dict", load_state_dict),
        ("state_dict", state_dict),
    ):
        if not callable(fn):
            raise TypeError(
                f"{name} must be callable, got {type(fn).__name__}"
            )

    self._load_state_dict = load_state_dict
    self._user_state_dict = state_dict

def shutdown(self, wait: bool = True) -> None:
"""
Shutdown the manager and checkpoint server.
Expand Down Expand Up @@ -531,8 +535,12 @@ def _apply_pending_state_dict(self) -> None:
self._logger.info("applying pending state dict")

assert self._pending_state_dict is not None, "checkpoint was not staged"
assert (
self._load_state_dict is not None
), "user load_state_dict is not initialized."
self._load_state_dict(self._pending_state_dict["user"])
self._pending_state_dict = None
self._logger.info("Loaded state dict.")

def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
"""
Expand Down Expand Up @@ -602,6 +610,7 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
self._batches_committed = state_dict["batches_committed"]

def _manager_state_dict(self) -> Dict[str, object]:
assert self._user_state_dict is not None, "user state_dict is not initialized."
return {
"user": self._user_state_dict(),
"torchft": self.state_dict(),
Expand Down
31 changes: 31 additions & 0 deletions torchft/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,37 @@ def test_state_dict(self, client_mock: MagicMock) -> None:
self.assertEqual(manager.current_step(), 1234)
self.assertEqual(manager.batches_committed(), 2345)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_user_state_dict(self, client_mock: MagicMock) -> None:
    """
    Check that ``_manager_state_dict`` picks up callbacks installed via
    ``set_state_dict_fns``.
    """
    # The "torchft" portion is expected to be identical before and after
    # the user callbacks are swapped.
    torchft_state = {"step": 0, "batches_committed": 0}

    manager = self._create_manager()

    before = manager._manager_state_dict()
    self.assertEqual(before, {"user": {}, "torchft": torchft_state})

    manager.set_state_dict_fns(self.load_state_dict, lambda: {"new_state": 1})

    after = manager._manager_state_dict()
    self.assertEqual(
        after, {"user": {"new_state": 1}, "torchft": torchft_state}
    )

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_happy(self, client_mock: MagicMock) -> None:
manager = self._create_manager()
Expand Down

0 comments on commit 6e4ae38

Please sign in to comment.