Skip to content

Commit be4ff50

Browse files
fix state dict hook for early fusion models (#2317)
Co-authored-by: JessicaZhong <[email protected]>
1 parent d3b39cf commit be4ff50

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

torchtune/modules/model_fusion/_early_fusion.py

+5 -2
Original file line number | Diff line number | Diff line change
@@ -137,8 +137,11 @@ def _state_dict_hook(module, state_dict, prefix, *args, **kwargs):
137 137
[!Note] This update changes the order of the OrderedDict
138 138
"""
139 139
for n, p in module.tok_embeddings.named_parameters():
140-
state_dict[f"{prefix}decoder.tok_embeddings.{n}"] = p
141-
del state_dict[f"{prefix}tok_embeddings.{n}"]
140+
orig_key = f"{prefix}tok_embeddings.{n}"
141+
if orig_key in state_dict:
142+
# preserve the original tensor with its requires_grad state
143+
state_dict[f"{prefix}decoder.tok_embeddings.{n}"] = state_dict[orig_key]
144+
del state_dict[orig_key]
142 145

143 146
@staticmethod
144 147
def _load_state_dict_hook(module, state_dict, prefix, *args, **kwargs):

0 commit comments

Comments
 (0)