Skip to content

Commit a984103

Browse files
authored
update qwen conversion script (#2207)
1 parent 2036b54 commit a984103

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

tools/checkpoint_conversion/convert_qwen_checkpoints.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def main(_):
106106
hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt")
107107
hf_model.eval()
108108

109-
keras_hub_model = keras_hub.models.QwenBackbone.from_preset(
109+
keras_hub_backbone = keras_hub.models.QwenBackbone.from_preset(
110110
f"hf://{hf_preset}"
111111
)
112112
keras_hub_tokenizer = keras_hub.models.QwenTokenizer.from_preset(
@@ -117,9 +117,18 @@ def main(_):
117117

118118
# === Check that the models and tokenizers outputs match ===
119119
test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
120-
test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer)
120+
test_model(keras_hub_backbone, keras_hub_tokenizer, hf_model, hf_tokenizer)
121121
print("\n-> Tests passed!")
122122

123+
preprocessor = keras_hub.models.Qwen2CausalLMPreprocessor(
124+
keras_hub_tokenizer
125+
)
126+
keras_hub_model = keras_hub.models.Qwen2CausalLM(
127+
keras_hub_backbone, preprocessor
128+
)
129+
130+
keras_hub_model.save_to_preset(f"./{preset}")
131+
123132

124133
if __name__ == "__main__":
125134
flags.mark_flag_as_required("preset")

0 commit comments

Comments
 (0)