14 changes: 12 additions & 2 deletions keras_hub/src/utils/transformers/export/hf_exporter.py
@@ -28,22 +28,32 @@
)
from keras_hub.src.utils.transformers.export.qwen import get_qwen_weights_map

# --- Qwen 3 Utils ---
from keras_hub.src.utils.transformers.export.qwen3 import get_qwen3_config
from keras_hub.src.utils.transformers.export.qwen3 import (
get_qwen3_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.qwen3 import get_qwen3_weights_map

MODEL_CONFIGS = {
"GemmaBackbone": get_gemma_config,
"Gemma3Backbone": get_gemma3_config,
"QwenBackbone": get_qwen_config,
"Qwen3Backbone": get_qwen3_config,
}

MODEL_EXPORTERS = {
"GemmaBackbone": get_gemma_weights_map,
"Gemma3Backbone": get_gemma3_weights_map,
"QwenBackbone": get_qwen_weights_map,
"Qwen3Backbone": get_qwen3_weights_map,
}

MODEL_TOKENIZER_CONFIGS = {
"GemmaTokenizer": get_gemma_tokenizer_config,
"Gemma3Tokenizer": get_gemma3_tokenizer_config,
"QwenTokenizer": get_qwen_tokenizer_config,
"Qwen3Tokenizer": get_qwen3_tokenizer_config,
}
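
For context, a minimal sketch of how hf_exporter.py presumably dispatches on these registries by backbone and tokenizer class name (the helper name resolve_export_fns is hypothetical, not part of the PR):

def resolve_export_fns(backbone, tokenizer):
    # Look up the converter functions registered for this architecture.
    backbone_cls = backbone.__class__.__name__
    tokenizer_cls = tokenizer.__class__.__name__
    if backbone_cls not in MODEL_CONFIGS:
        raise ValueError(f"Export is not supported for {backbone_cls}.")
    return (
        MODEL_CONFIGS[backbone_cls],  # backbone -> HF config dict
        MODEL_EXPORTERS[backbone_cls],  # backbone -> HF weights map
        MODEL_TOKENIZER_CONFIGS.get(tokenizer_cls),  # tokenizer -> HF tokenizer config
    )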


@@ -169,8 +179,8 @@ def export_tokenizer(tokenizer, path):
else:
warnings.warn(f"{vocab_spm_path} not found.")

- # 2. BPE Models (Qwen)
- elif tokenizer_type == "QwenTokenizer":
+ # 2. BPE Models (Qwen/Qwen3)
+ elif tokenizer_type in ["QwenTokenizer","Qwen3Tokenizer"]:
Review comment (Contributor, medium):
For better readability and to adhere to standard Python formatting conventions, please add a space after the comma in the list. This is a minor style issue that a formatter like ruff would typically correct.

Suggested change:
- elif tokenizer_type in ["QwenTokenizer","Qwen3Tokenizer"]:
+ elif tokenizer_type in ["QwenTokenizer", "Qwen3Tokenizer"]:

Reference: the style guide requires using ruff for code formatting; this change aligns with common ruff configurations for list formatting. (link)
vocab_json_path = os.path.join(path, "vocabulary.json")
vocab_hf_path = os.path.join(path, "vocab.json")
if os.path.exists(vocab_json_path):
131 changes: 131 additions & 0 deletions keras_hub/src/utils/transformers/export/qwen3.py
@@ -0,0 +1,131 @@
import keras.ops as ops


def get_qwen3_config(backbone):
"""Convert Keras Qwen3 config to Hugging Face Qwen2Config."""
# Qwen3 uses the Qwen2 architecture (RoPE, SwiGLU, RMSNorm)
cfg = backbone.get_config()

return {
# Core dimensions
"vocab_size": cfg["vocabulary_size"],
"hidden_size": cfg["hidden_dim"],
"num_hidden_layers": cfg["num_layers"],
"num_attention_heads": cfg["num_query_heads"],
"num_key_value_heads": cfg["num_key_value_heads"],
"intermediate_size": cfg["intermediate_dim"],

# Architecture details
"hidden_act": "silu",
"rms_norm_eps": cfg["layer_norm_epsilon"],
"rope_theta": cfg["rope_max_wavelength"],
"tie_word_embeddings": cfg["tie_word_embeddings"],

# Defaults
"initializer_range": 0.02,
"use_cache": True,
"attention_dropout": cfg["dropout"],
# HF uses "qwen2" model type for the Qwen family
"model_type": "qwen2",
}
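
Usage sketch: the exporter presumably serializes this dict straight to a config.json in the output directory (the test below reads that file back); path here stands in for an assumed output directory:

import json
import os

config = get_qwen3_config(backbone)
with open(os.path.join(path, "config.json"), "w") as f:
    json.dump(config, f, indent=2)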


def get_qwen3_weights_map(backbone, include_lm_head=False):
"""Create a weights map for a given Qwen3 model."""
weights_map = {}

# 1. Embeddings
weights_map["model.embed_tokens.weight"] = backbone.get_layer(
"token_embedding"
).embeddings

for i in range(backbone.num_layers):
decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

# KerasHub Qwen3LayerNorm uses 'scale'
weights_map[f"model.layers.{i}.input_layernorm.weight"] = (
decoder_layer._self_attention_layernorm.scale
)

weights_map[f"model.layers.{i}.post_attention_layernorm.weight"] = (
decoder_layer._feedforward_layernorm.scale
)

# --- Attention ---
attn_layer = decoder_layer._self_attention_layer

# Helper to map QKV (Reshape -> Transpose -> Bias)
def map_qkv(keras_layer, hf_name):
# Kernel: (Hidden, Heads, Dim) -> (Hidden, Heads*Dim)
# -> (Heads*Dim, Hidden)
k = ops.reshape(keras_layer.kernel, (backbone.hidden_dim, -1))
weights_map[f"model.layers.{i}.self_attn.{hf_name}.weight"] = ops.transpose(k)

# Bias: (Heads, Dim) -> (Heads*Dim)
# Qwen usually includes biases for Q, K, V
if keras_layer.bias is not None:
b = ops.reshape(keras_layer.bias, (-1,))
weights_map[f"model.layers.{i}.self_attn.{hf_name}.bias"] = b

# Access sub-layers (Robust check for _underscore vs public)
# Based on Qwen2, these are usually _query_dense, etc.
q_layer = getattr(attn_layer, "query_dense", getattr(attn_layer, "_query_dense", None))
k_layer = getattr(attn_layer, "key_dense", getattr(attn_layer, "_key_dense", None))
v_layer = getattr(attn_layer, "value_dense", getattr(attn_layer, "_value_dense", None))
o_layer = getattr(attn_layer, "output_dense", getattr(attn_layer, "_output_dense", None))

if q_layer: map_qkv(q_layer, "q_proj")
if k_layer: map_qkv(k_layer, "k_proj")
if v_layer: map_qkv(v_layer, "v_proj")

# Output (O_Proj) - Qwen usually has NO BIAS on output
if o_layer:
# Kernel: (Heads, Dim, Hidden) -> (Heads*Dim, Hidden) ->
# (Hidden, Heads*Dim)
o_k = ops.reshape(o_layer.kernel, (-1, backbone.hidden_dim))
weights_map[f"model.layers.{i}.self_attn.o_proj.weight"] = ops.transpose(o_k)

# --- MLP (SwiGLU) ---
# Gate (With activation)
gate_w = decoder_layer._feedforward_gate_dense.kernel
weights_map[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(gate_w)

# Up (Intermediate)
up_w = decoder_layer._feedforward_intermediate_dense.kernel
weights_map[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(up_w)

# Down (Output)
down_w = decoder_layer._feedforward_output_dense.kernel
weights_map[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(down_w)

# Final Norm
weights_map["model.norm.weight"] = backbone.get_layer(
"sequence_output_layernorm"
).scale

# LM Head
if include_lm_head:
if backbone.tie_word_embeddings:
# If tied, point to input embeddings (Exporter handles cloning)
weights_map["lm_head.weight"] = weights_map["model.embed_tokens.weight"]
else:
lm_head_w = backbone.get_layer("token_embedding").reverse_embeddings
# HF expects (Vocab, Hidden). Keras ReversibleEmbedding stores
# (Vocab, Hidden).
# No transpose needed usually, but check if your version differs.
weights_map["lm_head.weight"] = lm_head_w

return weights_map
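
The QKV mapping above hinges on a reshape followed by a transpose; here is a standalone shape sanity check with keras.ops (the dimensions are illustrative, matching the tiny test model below):

import keras.ops as ops

hidden, heads, head_dim = 64, 4, 16
kernel = ops.ones((hidden, heads, head_dim))  # Keras EinsumDense-style kernel
flat = ops.reshape(kernel, (hidden, -1))  # (hidden, heads * head_dim)
w_hf = ops.transpose(flat)  # (heads * head_dim, hidden), the HF q/k/v_proj layout
assert tuple(w_hf.shape) == (heads * head_dim, hidden)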


def get_qwen3_tokenizer_config(tokenizer):
"""Convert Keras Qwen3 tokenizer config to Hugging Face."""
return {
"tokenizer_class": "Qwen2Tokenizer",
"bos_token": None,
"eos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
"unk_token": None,
"model_max_length": 32768,
}

Review comment (Collaborator), on "model_max_length": 32768:
Is this true for all the model variations?

Review comment (Contributor, medium):
Please add a newline at the end of the file. It's a standard convention for text files and is often enforced by formatters like ruff to prevent issues with file concatenation and some tools.

Reference: the style guide requires using ruff for code formatting; standard ruff configurations enforce a final newline in files. (link)

118 changes: 118 additions & 0 deletions keras_hub/src/utils/transformers/export/qwen3_test.py
@@ -0,0 +1,118 @@
import os
import json
import numpy as np
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

import keras.ops as ops
from keras_hub.src.models.qwen3.qwen3_backbone import Qwen3Backbone
from keras_hub.src.models.qwen3.qwen3_causal_lm import Qwen3CausalLM
from keras_hub.src.models.qwen3.qwen3_causal_lm_preprocessor import (
Qwen3CausalLMPreprocessor,
)
from keras_hub.src.models.qwen3.qwen3_tokenizer import Qwen3Tokenizer
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export.hf_exporter import (
export_to_safetensors,
)


class TestQwen3Export(TestCase):

def test_export_to_hf(self):
# 1. Setup Dummy Tokenizer Assets (BPE)
vocab = {
# Special Tokens
"<|endoftext|>": 0,
"<|im_start|>": 1,
"<|im_end|>": 2,

# Base Characters
"Ġ": 3, "q": 4, "u": 5, "i": 6, "c": 7, "k": 8,

# Merged Tokens
"qu": 9,
"ic": 10,

# Full Words
"The": 11, "quick": 12, "brown": 13, "fox": 14
}

merges = ["q u", "i c"] # Merges imply "qu" and "ic" exist

temp_dir = self.get_temp_dir()
vocab_path = os.path.join(temp_dir, "vocab.json")
merges_path = os.path.join(temp_dir, "merges.txt")

with open(vocab_path, "w") as f: json.dump(vocab, f)
with open(merges_path, "w") as f: f.write("\n".join(merges))

tokenizer = Qwen3Tokenizer(vocabulary=vocab_path, merges=merges_path)

# 2. Create Tiny Qwen3 Backbone
backbone = Qwen3Backbone(
vocabulary_size=len(vocab),
num_layers=2,
num_query_heads=4,
num_key_value_heads=2,
hidden_dim=64,
intermediate_dim=128,
head_dim=16,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
layer_norm_epsilon=1e-6,
dropout=0,
)

# 3. Create Model
preprocessor = Qwen3CausalLMPreprocessor(tokenizer=tokenizer, sequence_length=32)
keras_model = Qwen3CausalLM(backbone=backbone, preprocessor=preprocessor)

# 4. Randomize Weights
rng = np.random.default_rng(42)
weights = keras_model.get_weights()
for i in range(len(weights)):
weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype)
keras_model.set_weights(weights)

# 5. Export
export_path = os.path.join(temp_dir, "export_task")
export_to_safetensors(keras_model, export_path)

# Patch config for dummy vocab compatibility
config_path = os.path.join(export_path, "config.json")
with open(config_path, "r") as f: cfg = json.load(f)
cfg["eos_token_id"] = 0
with open(config_path, "w") as f: json.dump(cfg, f, indent=2)

# 6. Load with Hugging Face (Qwen2 class works for Qwen3)
hf_model = AutoModelForCausalLM.from_pretrained(export_path, trust_remote_code=True)
hf_tokenizer = AutoTokenizer.from_pretrained(export_path)

# 7. Verify Config
hf_config = hf_model.config
self.assertEqual(hf_config.vocab_size, backbone.vocabulary_size)
self.assertEqual(hf_config.num_hidden_layers, backbone.num_layers)
self.assertEqual(hf_config.num_attention_heads, backbone.num_query_heads)
self.assertEqual(hf_config.num_key_value_heads, backbone.num_key_value_heads)
self.assertEqual(hf_config.hidden_size, backbone.hidden_dim)
self.assertEqual(hf_config.intermediate_size, backbone.intermediate_dim)

# 8. Compare Logits
# Using raw IDs to bypass tokenizer quirks with dummy vocab
input_ids = np.array([[1, 2, 4]])

keras_inputs = {
"token_ids": input_ids,
"padding_mask": np.ones_like(input_ids)
}
keras_logits = keras_model(keras_inputs)

import torch

Review comment (Contributor, medium), on "import torch":
For better code organization and readability, it's a best practice to place all imports at the top of the file. Since torch is a requirement for this test to run the Hugging Face model, moving this import to the top makes the dependency explicit.

hf_inputs = {"input_ids": torch.tensor(input_ids)}
hf_logits = hf_model(**hf_inputs).logits

keras_logits_np = ops.convert_to_numpy(keras_logits)
hf_logits_np = hf_logits.detach().cpu().numpy()

self.assertAllClose(keras_logits_np, hf_logits_np, atol=1e-3, rtol=1e-3)
Review comment (Contributor, medium):
Please add a newline at the end of the file. It's a standard convention for text files and is often enforced by formatters like ruff to prevent issues with file concatenation and some tools.

Reference: the style guide requires using ruff for code formatting; standard ruff configurations enforce a final newline in files. (link)
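
To run just this test locally, a standard pytest invocation should work (assuming pytest, transformers, and torch are installed in the environment):

pytest keras_hub/src/utils/transformers/export/qwen3_test.py -v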
