Commit 2036b54

Qwen - Fix Preset Loader + Add Causal LM Test (#2193)
* load tie embedding param from config
* add causal lm test for qwen + bug fix
* address comment
1 parent: 003f897
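
For context, `tie_word_embeddings` controls weight tying: when true, the output projection reuses the (transposed) input embedding matrix rather than a separate weight, which is why the flag must survive preset conversion from the checkpoint config. A minimal NumPy sketch of the idea (illustrative shapes and names, not keras-hub's actual implementation):

    import numpy as np

    # Tied word embeddings: one matrix serves as both the input embedding
    # table and, transposed, the output projection to vocabulary logits.
    vocab_size, hidden_dim = 8, 4
    embedding = np.random.rand(vocab_size, hidden_dim)

    token_ids = np.array([3, 5])
    hidden = embedding[token_ids]   # embed: shape (2, hidden_dim)
    logits = hidden @ embedding.T   # project back: shape (2, vocab_size)
    print(logits.shape)             # (2, 8)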

3 files changed (+125, -1)

3 files changed

+125
-1
lines changed

keras_hub/src/models/qwen/qwen_backbone.py (+1, -1)

@@ -168,7 +168,7 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
         self.tie_word_embeddings = tie_word_embeddings
-        self.use_sliding_window_attention = (use_sliding_window_attention,)
+        self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size

     def get_config(self):
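
The one-line fix above removes a trailing comma that silently turned the flag into a one-element tuple, which is truthy regardless of the flag's value. A standalone illustration of the pitfall:

    # A trailing comma inside parentheses builds a one-element tuple,
    # not a parenthesized value.
    flag = (False,)
    print(type(flag))   # <class 'tuple'>
    print(bool(flag))   # True -- any non-empty tuple is truthy

    flag = False        # the corrected assignment
    print(bool(flag))   # False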
keras_hub/src/models/qwen/qwen_causal_lm_test.py (new file, +123)

@@ -0,0 +1,123 @@
+from unittest.mock import patch
+
+import pytest
+from keras import ops
+
+from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
+from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM
+from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
+    QwenCausalLMPreprocessor,
+)
+from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
+from keras_hub.src.tests.test_case import TestCase
+
+
+class QwenCausalLMTest(TestCase):
+    def setUp(self):
+        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+        self.vocab += ["<|endoftext|>"]
+        self.vocab += ["<|eot_id|>"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.preprocessor = QwenCausalLMPreprocessor(
+            QwenTokenizer(vocabulary=self.vocab, merges=self.merges),
+            sequence_length=7,
+        )
+        self.backbone = QwenBackbone(
+            vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+            num_layers=2,
+            num_query_heads=4,
+            num_key_value_heads=2,
+            hidden_dim=8,
+            intermediate_dim=16,
+        )
+        self.init_kwargs = {
+            "preprocessor": self.preprocessor,
+            "backbone": self.backbone,
+        }
+        self.train_data = ([" airplane at airport", " airplane at airport"],)
+        self.input_data = self.preprocessor(*self.train_data)[0]
+
+    def test_causal_lm_basics(self):
+        self.run_task_test(
+            cls=QwenCausalLM,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 7, 8),
+        )
+
+    def test_generate(self):
+        causal_lm = QwenCausalLM(**self.init_kwargs)
+        # String input.
+        prompt = " airplane at airport"
+        output = causal_lm.generate(" airplane at airport")
+        self.assertTrue(prompt in output)
+        # Int tensor input.
+        prompt_ids = self.preprocessor.generate_preprocess([prompt])
+        causal_lm.preprocessor = None
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
+        # Assert prompt is in output in token id space.
+        self.assertAllEqual(
+            outputs["token_ids"][:, :5],
+            prompt_ids["token_ids"][:, :5],
+        )
+        self.assertAllEqual(
+            outputs["padding_mask"][:, :5],
+            prompt_ids["padding_mask"][:, :5],
+        )
+
+    def test_generate_strip_prompt(self):
+        causal_lm = QwenCausalLM(**self.init_kwargs)
+        prompt = " airplane at airport"
+        output = causal_lm.generate(prompt, strip_prompt=True)
+        self.assertFalse(output.startswith(prompt))
+
+    def test_early_stopping(self):
+        causal_lm = QwenCausalLM(**self.init_kwargs)
+        call_with_cache = causal_lm.call_with_cache
+
+        def wrapper(*args, **kwargs):
+            """Modify output logits to always favor end_token_id"""
+            logits, hidden_states, cache = call_with_cache(*args, **kwargs)
+            index = self.preprocessor.tokenizer.end_token_id
+            update = ops.ones_like(logits)[:, :, index] * 1.0e9
+            update = ops.expand_dims(update, axis=-1)
+            logits = ops.slice_update(logits, (0, 0, index), update)
+            return logits, hidden_states, cache
+
+        with patch.object(causal_lm, "call_with_cache", wraps=wrapper):
+            prompt = [" airplane at airport", " airplane"]
+            output = causal_lm.generate(prompt)
+            # We should immediately abort and output the prompt.
+            self.assertEqual(prompt, output)
+
+    def test_generate_compilation(self):
+        causal_lm = QwenCausalLM(**self.init_kwargs)
+        # Assert we do not recompile with successive calls.
+        causal_lm.generate(" airplane at airport")
+        first_fn = causal_lm.generate_function
+        causal_lm.generate(" airplane at airport")
+        second_fn = causal_lm.generate_function
+        self.assertEqual(first_fn, second_fn)
+        # Assert we do recompile after compile is called.
+        causal_lm.compile(sampler="greedy")
+        self.assertIsNone(causal_lm.generate_function)
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=QwenCausalLM,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in QwenCausalLM.presets:
+            self.run_preset_test(
+                cls=QwenCausalLM,
+                preset=preset,
+                input_data=self.input_data,
+            )
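
Beyond these unit tests, the end-to-end effect of the preset-loader fix can be exercised roughly as below. This is a sketch, not part of the commit: "qwen2.5_0.5b_en" is an assumed preset name (substitute any Qwen preset registered in keras-hub), and it assumes `QwenCausalLM` is exported under `keras_hub.models`.

    import keras_hub

    # With the fix, a checkpoint whose config sets tie_word_embeddings
    # now loads with the flag propagated into the backbone.
    causal_lm = keras_hub.models.QwenCausalLM.from_preset("qwen2.5_0.5b_en")
    print(causal_lm.generate(" airplane at airport", max_length=30))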

keras_hub/src/utils/transformers/convert_qwen.py (+1)

@@ -18,6 +18,7 @@ def convert_backbone_config(transformers_config):
         "rope_max_wavelength": transformers_config["rope_theta"],
         "use_sliding_window": transformers_config["use_sliding_window"],
         "sliding_window_size": transformers_config["sliding_window"],
+        "tie_word_embeddings": transformers_config["tie_word_embeddings"],
     }


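To make the converter change concrete, here is an illustrative stand-in for the mapping `convert_backbone_config` performs. The config dict is a hypothetical, truncated Hugging Face config (real Qwen configs carry many more keys), shown only to trace where the new key flows:

    # Hypothetical, truncated Hugging Face config -- illustrative only.
    hf_config = {
        "rope_theta": 1000000.0,
        "use_sliding_window": False,
        "sliding_window": 32768,
        "tie_word_embeddings": True,
    }

    # Mirror of the mapping shown in the diff above.
    backbone_kwargs = {
        "rope_max_wavelength": hf_config["rope_theta"],
        "use_sliding_window": hf_config["use_sliding_window"],
        "sliding_window_size": hf_config["sliding_window"],
        # New in this commit: propagate weight tying from the checkpoint.
        "tie_word_embeddings": hf_config["tie_word_embeddings"],
    }
    assert backbone_kwargs["tie_word_embeddings"] is True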