Skip to content

Commit 10540dd

Browse files
authored
Initial Qwen 3 support (OpenNMT#1943)
1 parent 768e7fe commit 10540dd

2 files changed

Lines changed: 157 additions & 0 deletions

File tree

docs/guides/transformers.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ CTranslate2 supports selected models from Hugging Face's [Transformers](https://
2222
* GPT-NeoX
2323
* OPT
2424
* Pegasus
25+
* Qwen 2.5
26+
* Qwen 3
2527
* T5
2628
* Whisper
2729
* XLM-RoBERTa
@@ -485,6 +487,44 @@ output = tokenizer.decode(results[0].sequences_ids[0])
485487
print(output)
486488
```
487489

490+
## Qwen 3
491+
492+
[Qwen 3](https://github.com/QwenLM/Qwen3) is a collection of large language models developed by the Alibaba Group. A key feature is the ability to switch between a "thinking mode" for complex reasoning and a "non-thinking mode" for efficient general chat.
493+
494+
To convert a model:
495+
496+
```bash
497+
ct2-transformers-converter --model Qwen/Qwen3-4B --quantization float16 --output_dir qwen3-4b-ct2
498+
```
499+
500+
Usage Sample
501+
502+
You can use the converted model for text generation with `ctranslate2.Generator`. For Qwen 3 instruction-tuned models, you should use the Hugging Face tokenizer's `apply_chat_template` method to correctly format your prompts, especially when dealing with the optional "thinking mode". Currently, MoE model variants are not supported.
503+
504+
```python
505+
import ctranslate2
506+
import transformers
507+
508+
generator = ctranslate2.Generator("qwen3-4b-ct2")
509+
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
510+
511+
def generate(prompt):
512+
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt, add_special_tokens=False))
513+
results = generator.generate_batch([tokens], max_length=2048, sampling_temperature=0.7, include_prompt_in_result=False)
514+
return tokenizer.decode(results[0].sequences_ids[0])
515+
516+
prompt_base = """<|im_start|>user
517+
A train leaves Station A at 60 mph heading towards Station B, 300 miles away. At the same time, another train leaves Station B at 40 mph heading towards Station A. When will they meet and how far from Station A?
518+
<|im_end|>
519+
<|im_start|>assistant"""
520+
521+
print("Non-thinking:\n" + "-"*60)
522+
print(generate(prompt_base + "\n<think></think>\n"))
523+
524+
print("\nThinking:\n" + "="*60)
525+
print(generate(prompt_base))
526+
```
527+
488528
## T5
489529

490530
[T5](https://huggingface.co/docs/transformers/model_doc/t5) is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which each task is converted into a text-to-text format.

python/ctranslate2/converters/transformers.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2260,6 +2260,123 @@ def set_decoder(self, spec, module):
22602260
gc.collect()
22612261

22622262

2263+
@register_loader("Qwen3Config")
class Qwen3Loader(ModelLoader):
    """Converts Hugging Face Qwen3ForCausalLM checkpoints to a CTranslate2 decoder spec."""

    @property
    def architecture_name(self):
        return "Qwen3ForCausalLM"

    def get_model_spec(self, model):
        """Build the TransformerDecoderModelSpec from the HF config and copy the weights."""
        num_layers = model.config.num_hidden_layers
        num_heads = model.config.num_attention_heads
        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
        # Qwen3 configs usually carry an explicit head_dim; fall back to the
        # conventional derivation when the attribute is absent.
        head_dim = getattr(
            model.config, "head_dim", model.config.hidden_size // num_heads
        )

        if num_heads_kv == num_heads:
            num_heads_kv = None

        rope_scaling = getattr(model.config, "rope_scaling", None)
        if rope_scaling:
            # Newer HF configs use the "rope_type" key; older ones use "type".
            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
            if rotary_scaling_type is None:
                # Use the resolved rope_type here: indexing rope_scaling["type"]
                # would raise KeyError for configs that only define "rope_type".
                raise NotImplementedError(
                    "RoPE scaling type '%s' is not yet implemented. "
                    "The following RoPE scaling types are currently supported: %s"
                    % (rope_type, ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
                )
            # Read the factor only after the type check so an unsupported
            # config without "factor" still reports NotImplementedError.
            rotary_scaling_factor = rope_scaling["factor"]
        else:
            rotary_scaling_type = None
            rotary_scaling_factor = 1

        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
            num_layers,
            num_heads,
            activation=common_spec.Activation.SWISH,
            pre_norm=True,
            ffn_glu=True,
            rms_norm=True,
            # Use the computed head_dim (with fallback) instead of reading
            # model.config.head_dim directly, which would raise AttributeError
            # on configs that do not define the attribute.
            rotary_dim=head_dim,
            rotary_interleave=False,
            rotary_scaling_type=rotary_scaling_type,
            rotary_scaling_factor=rotary_scaling_factor,
            rotary_base=getattr(model.config, "rope_theta", 10000),
            num_heads_kv=num_heads_kv,
            head_dim=head_dim,
            qk_norm=True,
        )

        self.set_decoder(spec.decoder, model.model)
        self.set_linear(spec.decoder.projection, model.lm_head)
        return spec

    def get_vocabulary(self, model, tokenizer):
        """Pad the tokenizer vocabulary with placeholder tokens up to config.vocab_size."""
        tokens = super().get_vocabulary(model, tokenizer)
        extra_ids = model.config.vocab_size - len(tokens)
        for i in range(extra_ids):
            tokens.append("<extra_id_%d>" % i)
        return tokens

    def set_vocabulary(self, spec, tokens):
        spec.register_vocabulary(tokens)

    def set_config(self, config, model, tokenizer):
        # Qwen tokenizers may define no BOS token; fall back to the pad token.
        config.bos_token = (
            tokenizer.bos_token
            if tokenizer.bos_token is not None
            else tokenizer.pad_token
        )
        config.eos_token = tokenizer.eos_token
        config.unk_token = (
            tokenizer.unk_token if tokenizer.unk_token is not None else ""
        )
        config.layer_norm_epsilon = model.config.rms_norm_eps

    def set_layer_norm(self, spec, layer_norm):
        # RMSNorm has only a scale (gamma), no bias.
        spec.gamma = layer_norm.weight

    def set_decoder(self, spec, module):
        """Copy embeddings, per-layer attention/FFN weights, and the final norm."""
        spec.scale_embeddings = False
        self.set_embeddings(spec.embeddings, module.embed_tokens)
        self.set_layer_norm(spec.layer_norm, module.norm)

        for layer_spec, layer in zip(spec.layer, module.layers):
            self.set_layer_norm(
                layer_spec.self_attention.layer_norm, layer.input_layernorm
            )
            self.set_layer_norm(
                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
            )

            # Qwen3 applies RMSNorm to the query and key projections (QK-norm).
            self.set_layer_norm(
                layer_spec.self_attention.q_norm, layer.self_attn.q_norm
            )
            self.set_layer_norm(
                layer_spec.self_attention.k_norm, layer.self_attn.k_norm
            )

            # Fuse the separate Q/K/V projections into a single linear layer.
            split_layers = [common_spec.LinearSpec() for _ in range(3)]
            self.set_linear(split_layers[0], layer.self_attn.q_proj)
            self.set_linear(split_layers[1], layer.self_attn.k_proj)
            self.set_linear(split_layers[2], layer.self_attn.v_proj)
            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)

            self.set_linear(
                layer_spec.self_attention.linear[1],
                layer.self_attn.o_proj,
            )

            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

            # Drop the copied HF modules to limit peak memory during conversion.
            delattr(layer, "self_attn")
            delattr(layer, "mlp")
2378+
2379+
22632380
@register_loader("MixFormerSequentialConfig")
22642381
class MixFormerSequentialLoader(ModelLoader):
22652382
@property

0 commit comments

Comments
 (0)