Skip to content

Commit 12cb82b

Browse files
authored
[Reland] Reinitialize ModelWrapper.cache_state_dict because set_model_state_dict has side effects (#977)
Reland #971 due to ghstack issues.
1 parent 6ff539a commit 12cb82b

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

torchtitan/components/checkpoint.py

+5
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
6565
options=StateDictOptions(strict=False),
6666
)
6767
list(map(func, self.model))
68+
# `set_model_state_dict()` does change the keys of the input state_dict,
69+
# we will need to reinitialize the cache_state_dict.
70+
self.cache_state_dict = {
71+
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
72+
}
6873

6974

7075
class Terminate:

0 commit comments

Comments
 (0)