Skip to content

Commit 73691fa

Browse files
authored
Add Qwen3 models (#11)
* qwen reranker * encoder * make qwen3forcausallm -> qwen3model for embedding * pre-commit * fix running script * add note * add comment * adhoc for pooling configuration in qwen3 embedding model * pre-commit
1 parent 0619506 commit 73691fa

6 files changed

Lines changed: 341 additions & 12 deletions

File tree

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2025 Rebellions Inc. All rights reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
17+
import fire
18+
import torch
19+
from vllm import AsyncEngineArgs, AsyncLLMEngine, PoolingParams
20+
21+
22+
def get_input_prompts() -> list[str]:
    """Build the example embedding inputs: two instructed queries
    followed by two plain documents (queries first, documents second)."""

    def _with_instruction(task_description: str, query: str) -> str:
        # Prefix the query with its task instruction, as required by the
        # Qwen3 embedding prompt format.
        return f'Instruct: {task_description}\nQuery:{query}'

    # Each query must come with a one-sentence instruction that
    # describes the task.
    instruction = ('Given a web search query, '
                   'retrieve relevant passages that answer the query')

    raw_queries = ['What is the capital of China?', 'Explain gravity']
    instructed_queries = [
        _with_instruction(instruction, q) for q in raw_queries
    ]

    corpus = [
        "The capital of China is Beijing.",
        ("Gravity is a force that attracts two bodies towards each other. "
         "It gives weight to physical objects and "
         "is responsible for the movement of planets around the sun.")
    ]

    return instructed_queries + corpus
45+
46+
47+
async def embed(engine: AsyncLLMEngine, prompt: str, model: str,
                requst_id: int):
    """Submit one embedding request to the engine and return the final
    (completed) request output.

    NOTE(review): `requst_id` is a typo for `request_id`, and `model` is
    unused; both are kept unchanged so existing callers keep working.
    """
    print(f"embed request_id={requst_id}, prompt={prompt}")
    stream = engine.encode(prompt, PoolingParams(), str(requst_id))

    # Drain the async stream; the last yielded item is the finished result.
    last_output = None
    async for partial in stream:
        last_output = partial
    return last_output
62+
63+
64+
async def main(
    batch_size: int,
    max_seq_len: int,
    kvcache_block_size: int,
    num_input_prompt: int,
    model_id: str,
):
    """Embed the example prompts and print query/document similarity scores.

    Args:
        batch_size: maximum number of concurrent sequences in the engine.
        max_seq_len: maximum model length and batched-token budget.
        kvcache_block_size: KV-cache block size for the engine.
        num_input_prompt: number of query/document pairs to embed.
        model_id: model path or identifier to load.

    Raises:
        RuntimeError: if the prompt list does not contain exactly
            ``2 * num_input_prompt`` entries (queries then documents).
    """
    engine_args = AsyncEngineArgs(model=model_id,
                                  device="auto",
                                  max_num_seqs=batch_size,
                                  max_num_batched_tokens=max_seq_len,
                                  max_model_len=max_seq_len,
                                  block_size=kvcache_block_size,
                                  task="embed")

    engine = AsyncLLMEngine.from_engine_args(engine_args)
    prompt_list = get_input_prompts()
    # BUG FIX: the original passed two comma-separated strings to
    # RuntimeError (yielding a tuple message instead of concatenation) and
    # only rejected over-long lists with `>`, so a too-short list would
    # silently mis-slice the score matrix below. Enforce exact length.
    if len(prompt_list) != 2 * num_input_prompt:
        raise RuntimeError("The len(QUERIES) and len(DOCUMENTS) "
                           "should be equal with 2 * `num_input_prompt`.")
    futures = []
    for i, p in enumerate(prompt_list):
        if i == num_input_prompt * 2:
            break
        futures.append(
            asyncio.create_task(
                embed(
                    engine,
                    prompt=p,
                    model=model_id,
                    requst_id=i,
                )))

    outputs = await asyncio.gather(*futures)

    # Rows are queries, columns are documents: scores[i][j] = q_i . d_j.
    embeddings = torch.stack([o.outputs.data for o in outputs])
    scores = (embeddings[:num_input_prompt] @ embeddings[num_input_prompt:].T)

    print(f"scores: {scores.tolist()}")
103+
104+
105+
def entry_point(
    batch_size: int = 1,
    max_seq_len: int = 32768,
    kvcache_block_size: int = 32768,
    num_input_prompt: int = 2,
    model_id: str = "/qwen3-0.6b-b1-embedding",
):
    """CLI entry point: run the async embedding demo to completion.

    FIX: replaced the deprecated ``asyncio.get_event_loop()`` /
    ``run_until_complete`` pattern with ``asyncio.run``, which creates
    and cleanly closes a fresh event loop (the old pattern warns and can
    fail on Python 3.10+ when no loop is running).
    """
    asyncio.run(
        main(
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            kvcache_block_size=kvcache_block_size,
            num_input_prompt=num_input_prompt,
            model_id=model_id,
        ))


if __name__ == "__main__":
    fire.Fire(entry_point)
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright 2025 Rebellions Inc. All rights reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import math
17+
18+
import fire
19+
from transformers import AutoTokenizer
20+
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
21+
from vllm.inputs.data import TokensPrompt
22+
23+
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
24+
25+
26+
def format_instruction(instruction, query, doc):
    """Build the two-message chat used by the Qwen3 reranker: a fixed
    yes/no judging system prompt plus one instruct/query/document turn."""
    system_content = ("Judge whether the Document meets the requirements "
                      "based on the Query and the Instruct provided. "
                      "Note that the answer can only be \"yes\" or \"no\".")
    user_content = (
        f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}")
    return [
        {
            "role": "system",
            "content": system_content,
        },
        {
            "role": "user",
            "content": user_content,
        },
    ]
40+
41+
42+
def process_inputs(pairs, instruction, max_length, suffix_tokens, tokenizer):
    """Tokenize (query, doc) pairs into ``TokensPrompt`` objects,
    truncating each chat to ``max_length`` tokens before appending the
    fixed no-thinking assistant suffix."""
    chats = [format_instruction(instruction, q, d) for q, d in pairs]
    token_lists = tokenizer.apply_chat_template(chats,
                                                tokenize=True,
                                                add_generation_prompt=False,
                                                enable_thinking=False)
    # Truncate first so that truncation never eats into the suffix.
    return [
        TokensPrompt(prompt_token_ids=ids[:max_length] + suffix_tokens)
        for ids in token_lists
    ]
53+
54+
55+
def get_input_prompts(model_id, max_length, suffix_tokens,
                      tokenizer) -> "list[TokensPrompt]":
    """Build tokenized reranker prompts for the example query/doc pairs.

    FIX: the return annotation claimed ``list[str]`` but the function
    returns the ``TokensPrompt`` list produced by ``process_inputs``.
    ``model_id`` is unused but kept for signature compatibility.

    Args:
        model_id: unused (kept for caller compatibility).
        max_length: total token budget per prompt, including the suffix.
        suffix_tokens: token ids of the fixed assistant suffix.
        tokenizer: tokenizer providing ``apply_chat_template``.
    """
    task = ('Given a web search query, '
            'retrieve relevant passages that answer the query')
    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]
    documents = [
        "The capital of China is Beijing.",
        ("Gravity is a force that attracts two bodies towards each other. "
         "It gives weight to physical objects and "
         "is responsible for the movement of planets around the sun.")
    ]

    pairs = list(zip(queries, documents))
    # Reserve room for the suffix so truncation + suffix stays within
    # max_length tokens overall.
    inputs = process_inputs(pairs, task, max_length - len(suffix_tokens),
                            suffix_tokens, tokenizer)

    return inputs
75+
76+
77+
async def generate(engine: AsyncLLMEngine, prompt_tokens: list[int],
                   model: str, requst_id: int, true_token: int,
                   false_token: int):
    """Run one single-token reranker request and return its final output.

    Sampling is greedy, emits exactly one token, and restricts the
    vocabulary to the "yes"/"no" token ids so that ``compute_logits``
    can turn the returned logprobs into a relevance score.

    NOTE(review): `requst_id` is a typo for `request_id`, and `model` is
    unused; both are kept unchanged so existing callers keep working.
    """
    print(f"generate request_id={requst_id}, prompt_tokens={prompt_tokens}")

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=1,
        logprobs=20,
        allowed_token_ids=[true_token, false_token],
    )

    # FIX: dropped the dead `example_input` dict — its "stream" and
    # "temperature" keys were never read; only the request-id string was.
    results_generator = engine.generate(
        prompt_tokens,
        sampling_params,
        str(requst_id),
    )

    # Drain the stream; the last yielded item is the finished result.
    final_output = None
    async for request_output in results_generator:
        final_output = request_output
    return final_output
105+
106+
107+
def compute_logits(outputs, true_token, false_token):
    """Convert per-request "yes"/"no" logprobs into P(yes) scores.

    For each request output, reads the logprob map of the last generated
    position; a token absent from that map is assigned a logit of -10
    (a near-zero probability floor). Returns one score per output.
    """
    scores = []
    for request_output in outputs:
        token_logprobs = request_output.outputs[0].logprobs[-1]

        def _logit_of(token_id):
            # Fall back to -10 when the token did not make the returned
            # top-logprobs set.
            if token_id in token_logprobs:
                return token_logprobs[token_id].logprob
            return -10

        yes_weight = math.exp(_logit_of(true_token))
        no_weight = math.exp(_logit_of(false_token))
        scores.append(yes_weight / (yes_weight + no_weight))
    return scores
124+
125+
126+
async def main(
    batch_size: int,
    max_seq_len: int,
    kvcache_block_size: int,
    num_input_prompt: int,
    model_id: str,
):
    """Score the example query/document pairs with the Qwen3 reranker
    and print the resulting P(yes) relevance scores."""
    engine_args = AsyncEngineArgs(model=model_id,
                                  device="auto",
                                  max_num_seqs=batch_size,
                                  max_num_batched_tokens=max_seq_len,
                                  max_model_len=max_seq_len,
                                  block_size=kvcache_block_size)

    # The tokenizer is needed both for chat templating and to locate the
    # "yes"/"no" answer token ids used to constrain decoding.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    suffix_tokens = tokenizer.encode(SUFFIX, add_special_tokens=False)

    yes_id = tokenizer("yes", add_special_tokens=False).input_ids[0]
    no_id = tokenizer("no", add_special_tokens=False).input_ids[0]

    engine = AsyncLLMEngine.from_engine_args(engine_args)
    prompts = get_input_prompts(model_id, max_seq_len, suffix_tokens,
                                tokenizer)

    # Fan out up to num_input_prompt concurrent generate() calls.
    tasks = [
        asyncio.create_task(
            generate(engine,
                     prompt_tokens=prompt,
                     model=model_id,
                     requst_id=idx,
                     true_token=yes_id,
                     false_token=no_id))
        for idx, prompt in enumerate(prompts[:num_input_prompt])
    ]

    results = await asyncio.gather(*tasks)
    score = compute_logits(results, yes_id, no_id)
    print(f"scores: {score}")
167+
168+
169+
def entry_point(
    batch_size: int = 1,
    max_seq_len: int = 32768,
    kvcache_block_size: int = 32768,
    num_input_prompt: int = 2,
    model_id: str = "/qwen3-0.6b-b1",
):
    """CLI entry point: run the async reranker demo to completion.

    FIX: replaced the deprecated ``asyncio.get_event_loop()`` /
    ``run_until_complete`` pattern with ``asyncio.run``, which creates
    and cleanly closes a fresh event loop (the old pattern warns and can
    fail on Python 3.10+ when no loop is running).
    """
    asyncio.run(
        main(
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            kvcache_block_size=kvcache_block_size,
            num_input_prompt=num_input_prompt,
            model_id=model_id,
        ))


if __name__ == "__main__":
    fire.Fire(entry_point)

vllm_rbln/model_executor/models/optimum/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def load_model(
6969

7070

7171
__all__ = [
72-
"load_model",
73-
"get_rbln_model_info",
74-
"ModelInputForRBLN",
72+
"load_model", "get_rbln_model_info", "ModelInputForRBLN",
73+
"RBLNOptimumForEncoderModel"
7574
]

vllm_rbln/model_executor/models/optimum/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"ExaoneForCausalLM": ("exaone", "RBLNExaoneForCausalLM"),
3939
"Qwen2ForCausalLM": ("qwen2", "RBLNQwen2ForCausalLM"),
4040
"OPTForCausalLM": ("opt", "RBLNOPTForCausalLM"),
41+
"Qwen3ForCausalLM": ("qwen3", "RBLNQwen3ForCausalLM"),
4142
}
4243

4344
_RBLN_ENCODER_DECODER_MODELS: Dict[str, Tuple[str, str]] = {
@@ -71,6 +72,7 @@
7172
"XLMRobertaForSequenceClassification":
7273
("xlm_roberta_classification", "RBLNXLMRobertaForSequenceClassification"),
7374
"XLMRobertaModel": ("xlm_roberta", "RBLNXLMRobertaModel"),
75+
"Qwen3Model": ("qwen3", "RBLNQwen3Model"),
7476
}
7577

7678
_RBLN_SUPPORTED_MODELS = {

vllm_rbln/model_executor/models/optimum/encoder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,17 @@ def forward(self, model_input: ModelInputForRBLN) -> torch.Tensor:
114114
if token_type_ids:
115115
kwargs["token_type_ids"] = token_type_ids
116116
else:
117-
rbln_model_input_names = self.rbln_model_config.model_input_names
118-
if "token_type_ids" in rbln_model_input_names:
119-
kwargs["token_type_ids"] = torch.zeros_like(input_ids)
117+
model_input_names = getattr(self.rbln_model_config,
118+
"model_input_names", None)
119+
if model_input_names is not None:
120+
rbln_model_input_names = \
121+
self.rbln_model_config.model_input_names
122+
if "token_type_ids" in rbln_model_input_names:
123+
kwargs["token_type_ids"] = torch.zeros_like(input_ids)
120124

121125
embeds = self.model.forward(**kwargs)
122126

123127
hidden_states = embeds[0]
124-
125128
if isinstance(hidden_states, tuple):
126129
hidden_states = hidden_states[0]
127130

0 commit comments

Comments
 (0)