Skip to content

Commit 58e6028

Browse files
committed
Typing
1 parent ff0f31f commit 58e6028

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torchft/manager.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -534,8 +534,11 @@ def _apply_pending_state_dict(self) -> None:
534534

535535
self._logger.info("applying pending state dict")
536536

537-
assert self._pending_state_dict is not None, "checkpoint was not staged"
537+
assert (
538+
self._load_state_dict is not None
539+
), "user load_state_dict is not initialized."
538540
self._load_state_dict(self._pending_state_dict["user"])
541+
assert self._pending_state_dict is not None, "checkpoint was not staged"
539542
self._pending_state_dict = None
540543
self._logger.info("Loaded state dict.")
541544

@@ -607,6 +610,7 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
607610
self._batches_committed = state_dict["batches_committed"]
608611

609612
def _manager_state_dict(self) -> Dict[str, object]:
613+
assert self._user_state_dict is not None, "user state_dict is not initialized."
610614
return {
611615
"user": self._user_state_dict(),
612616
"torchft": self.state_dict(),

0 commit comments

Comments
 (0)