Skip to content

Commit 6e804c4

Browse files
authored
fix checkpoint conversion for tied weights llama3 (#1370)
Small fix and tests for llama3 te > hf conversion with tied weights Signed-off-by: Peter St. John <[email protected]>
1 parent 5ff89f0 commit 6e804c4

File tree

5 files changed

+31
-6
lines changed

5 files changed

+31
-6
lines changed

bionemo-recipes/models/amplify/src/amplify/state.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ def scale_weights(ctx):
158158
_params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad)
159159
target_state.pop(name)
160160
else:
161-
print(f"Unexpected key: {name} not in checkpoint but in model.")
161+
print(f"Unexpected key: {name} not in target model but is in source model.")
162162

163163
for key, val in _params.items():
164164
_module, _key = target, key
@@ -190,7 +190,7 @@ def scale_weights(ctx):
190190
keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
191191
keys = [key for key in keys if key not in state_dict_ignored_entries]
192192
if len(keys) != 0:
193-
raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")
193+
raise RuntimeError(f"Additional keys: {keys} in target model but not in source model.")
194194

195195
if hasattr(target, "tie_weights"):
196196
target.tie_weights()

bionemo-recipes/models/esm2/src/esm/state.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ def scale_weights(ctx):
158158
_params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad)
159159
target_state.pop(name)
160160
else:
161-
print(f"Unexpected key: {name} not in checkpoint but in model.")
161+
print(f"Unexpected key: {name} not in target model but is in source model.")
162162

163163
for key, val in _params.items():
164164
_module, _key = target, key
@@ -190,7 +190,7 @@ def scale_weights(ctx):
190190
keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
191191
keys = [key for key in keys if key not in state_dict_ignored_entries]
192192
if len(keys) != 0:
193-
raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")
193+
raise RuntimeError(f"Additional keys: {keys} in target model but not in source model.")
194194

195195
if hasattr(target, "tie_weights"):
196196
target.tie_weights()

bionemo-recipes/models/llama3/convert.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,7 @@ def convert_llama_te_to_hf(model_te: NVLlamaForCausalLM, **config_kwargs) -> Lla
126126
fn=state.TransformFns.split_fc1,
127127
),
128128
],
129+
state_dict_ignored_entries=model_hf._tied_weights_keys,
129130
)
130131

131132
output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone()

bionemo-recipes/models/llama3/state.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ def scale_weights(ctx):
158158
_params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad)
159159
target_state.pop(name)
160160
else:
161-
print(f"Unexpected key: {name} not in checkpoint but in model.")
161+
print(f"Unexpected key: {name} not in target model but is in source model.")
162162

163163
for key, val in _params.items():
164164
_module, _key = target, key
@@ -190,7 +190,7 @@ def scale_weights(ctx):
190190
keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
191191
keys = [key for key in keys if key not in state_dict_ignored_entries]
192192
if len(keys) != 0:
193-
raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")
193+
raise RuntimeError(f"Additional keys: {keys} in target model but not in source model.")
194194

195195
if hasattr(target, "tie_weights"):
196196
target.tie_weights()

bionemo-recipes/models/llama3/tests/test_convert.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,18 @@ def test_convert_hf_to_te_with_bf16():
7272
convert_llama_hf_to_te(model_hf)
7373

7474

75+
def test_convert_hf_to_te_with_bf16_tied_weights():
76+
config = AutoConfig.from_pretrained(
77+
"nvidia/Llama-3.1-8B-Instruct-FP8",
78+
dtype=torch.bfloat16,
79+
num_hidden_layers=2,
80+
tie_word_embeddings=True,
81+
)
82+
model_hf = LlamaForCausalLM(config)
83+
model_hf.to(dtype=torch.bfloat16) # I think the original llama3 model doesn't initialize in bf16.
84+
convert_llama_hf_to_te(model_hf)
85+
86+
7587
def test_convert_te_to_hf_with_bf16():
7688
config = NVLlamaConfig.from_pretrained(
7789
"nvidia/Llama-3.1-8B-Instruct-FP8", dtype=torch.bfloat16, num_hidden_layers=2
@@ -81,6 +93,18 @@ def test_convert_te_to_hf_with_bf16():
8193
convert_llama_te_to_hf(model_te)
8294

8395

96+
def test_convert_te_to_hf_with_bf16_tied_weights():
97+
config = NVLlamaConfig.from_pretrained(
98+
"nvidia/Llama-3.1-8B-Instruct-FP8",
99+
dtype=torch.bfloat16,
100+
num_hidden_layers=2,
101+
tie_word_embeddings=True,
102+
)
103+
model_te = NVLlamaForCausalLM(config)
104+
model_te.to(dtype=torch.float32) # I think the original llama3 model doesn't initialize in bf16.
105+
convert_llama_te_to_hf(model_te)
106+
107+
84108
@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 models.")
85109
@pytest.mark.parametrize(
86110
"upstream_model_name", ["meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"]

0 commit comments

Comments (0)