Skip to content

Commit d96690a

Browse files
AlienKevin, Claude, and Gemini 3 Flash
authored
[RL] Support Qwen 2.5 in RL weight transfer and model registry (#2456)
This PR adds support for Qwen 2.5 models in the RL pipeline.
- Updates weight transfer logic and model mappings (handles bias keys and MHA/GQA differences).
- Registers `Qwen2ForCausalLM` in the `tpu_inference` model registry to fix missing-architecture errors.
Fixes #2446
Co-authored-by: Claude Opus 4.5 <[email protected]>
Co-authored-by: Gemini 3 Flash <[email protected]>
1 parent 28a0dfe commit d96690a

File tree

3 files changed

+118
-20
lines changed

3 files changed

+118
-20
lines changed

lib/marin/src/marin/rl/environments/inference_ctx/vllm.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,25 @@ def _render_messages_to_tokens(self, messages: list[Message]) -> list[int]:
124124
"""
125125
return self.renderer.build_generation_prompt(messages)
126126

127+
@staticmethod
128+
def _patch_tpu_inference_registry():
129+
"""Register Qwen2ForCausalLM in tpu_inference if not present."""
130+
try:
131+
from tpu_inference.models.common import model_loader
132+
133+
if "Qwen2ForCausalLM" not in model_loader._MODEL_REGISTRY:
134+
logger.info("Patching tpu_inference to support Qwen2ForCausalLM")
135+
from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
136+
137+
model_loader.register_model("Qwen2ForCausalLM", Qwen2ForCausalLM)
138+
except ImportError:
139+
logger.exception("Failed to patch tpu_inference registry")
140+
raise
141+
127142
@staticmethod
128143
def _get_llm_engine(inference_config: vLLMInferenceContextConfig):
144+
vLLMInferenceContext._patch_tpu_inference_registry()
145+
129146
if inference_config.mode == InferenceMode.SYNC:
130147
if LLM is None:
131148
raise ImportError("vLLM is not installed. Please install it with: pip install vllm")

lib/marin/src/marin/rl/environments/inference_ctx/vllm_utils.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ def levanter_qwen_to_vllm_mapping():
6666
{
6767
"model.layers.*.self_attn.q_norm": ("model.layers.*.self_attn.q_norm.scale", (None,)),
6868
"model.layers.*.self_attn.k_norm": ("model.layers.*.self_attn.k_norm.scale", (None,)),
69+
"model.layers.*.self_attn.q_proj_bias": (
70+
"model.layers.*.self_attn.q_proj.bias",
71+
("model", None),
72+
),
73+
"model.layers.*.self_attn.k_proj_bias": (
74+
"model.layers.*.self_attn.k_proj.bias",
75+
("model", None),
76+
),
77+
"model.layers.*.self_attn.v_proj_bias": (
78+
"model.layers.*.self_attn.v_proj.bias",
79+
("model", None),
80+
),
6981
}
7082
)
7183
return mapping
@@ -80,9 +92,12 @@ def levanter_qwen_to_vllm_mapping():
8092
"k_proj": (2, 0, 1),
8193
"v_proj": (2, 0, 1),
8294
"o_proj": (1, 2, 0),
95+
"q_proj_bias": (0, 1),
96+
"k_proj_bias": (0, 1),
97+
"v_proj_bias": (0, 1),
8398
}
8499

85-
MODEL_MAPPINGS = {
100+
_MODEL_MAPPINGS = {
86101
"meta-llama/Llama-3.2-1B-Instruct": levanter_llama_to_vllm_mapping(),
87102
"meta-llama/Llama-3.2-3B-Instruct": levanter_llama_to_vllm_mapping(),
88103
"Qwen/Qwen3-0.6B": levanter_qwen_to_vllm_mapping(),
@@ -92,7 +107,7 @@ def levanter_qwen_to_vllm_mapping():
92107
"marin-community/marin-8b-instruct": levanter_llama_to_vllm_mapping(),
93108
}
94109

95-
MODEL_TRANSPOSE_KEYS = {
110+
_MODEL_TRANSPOSE_KEYS = {
96111
"meta-llama/Llama-3.2-1B-Instruct": llama_transpose_keys,
97112
"meta-llama/Llama-3.2-3B-Instruct": llama_transpose_keys,
98113
"Qwen/Qwen3-0.6B": llama_transpose_keys,
@@ -101,3 +116,42 @@ def levanter_qwen_to_vllm_mapping():
101116
"Qwen/Qwen3-8B": llama_transpose_keys,
102117
"marin-community/marin-8b-instruct": llama_transpose_keys,
103118
}
119+
120+
121+
def _infer_mapping(model_name: str) -> dict:
122+
"""Infer the vLLM mapping for a model name, falling back to substring matching."""
123+
if model_name in _MODEL_MAPPINGS:
124+
return _MODEL_MAPPINGS[model_name]
125+
if "Qwen2.5" in model_name:
126+
return levanter_qwen_to_vllm_mapping()
127+
raise KeyError(f"No MODEL_MAPPING registered for model: {model_name}")
128+
129+
130+
def _infer_transpose_keys(model_name: str) -> dict:
131+
"""Infer the transpose keys for a model name, falling back to substring matching."""
132+
if model_name in _MODEL_TRANSPOSE_KEYS:
133+
return _MODEL_TRANSPOSE_KEYS[model_name]
134+
if "Qwen2.5" in model_name:
135+
return llama_transpose_keys
136+
raise KeyError(f"No MODEL_TRANSPOSE_KEYS registered for model: {model_name}")
137+
138+
139+
class _FallbackDict:
140+
"""Dict-like object that supports fallback lookup by substring matching."""
141+
142+
def __init__(self, fallback):
143+
self._fallback = fallback
144+
145+
def __getitem__(self, key):
146+
return self._fallback(key)
147+
148+
def __contains__(self, key):
149+
try:
150+
self._fallback(key)
151+
return True
152+
except KeyError:
153+
return False
154+
155+
156+
MODEL_MAPPINGS = _FallbackDict(_infer_mapping)
157+
MODEL_TRANSPOSE_KEYS = _FallbackDict(_infer_transpose_keys)

lib/marin/src/marin/rl/weight_utils.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@
1818
from levanter.models.lm_model import LmHeadModel
1919

2020

21+
def _get_nnx_key_name(split_key: list[str]) -> str:
22+
"""
23+
Determine the NNX key name from the split Levanter key.
24+
If the key ends in 'bias', append '_bias' to the parameter name.
25+
Otherwise (e.g. 'weight'), use the parameter name directly.
26+
"""
27+
key_name = split_key[-2]
28+
if split_key[-1] == "bias":
29+
key_name = f"{key_name}_bias"
30+
return key_name
31+
32+
2133
def levanter_to_nnx_state(levanter_model: LmHeadModel) -> dict:
2234
# The format of this state dict is flat like:
2335
# model.layers.0.self_attn.q_proj.weight -> jax array
@@ -46,7 +58,7 @@ def levanter_to_nnx_state(levanter_model: LmHeadModel) -> dict:
4658
# vLLM expects the weights to be padded to the next multiple of 128. I assume this is
4759
# because they want to use Pallas kernels which have this requirement.
4860
if "self_attn" in split_key_without_weight:
49-
if "q_proj" in split_key_without_weight:
61+
if "q_proj" in split_key_without_weight and len(value.shape) == 4:
5062
kv_heads, q_heads_per_group, head_size, embed = value.shape
5163
value = value.reshape(kv_heads * q_heads_per_group, head_size, embed)
5264

@@ -67,7 +79,7 @@ def levanter_to_nnx_state(levanter_model: LmHeadModel) -> dict:
6779
# pad 3rd dimension to 128 (e.g., (8, 2048, 64) -> (8, 2048, 128))
6880
value = jnp.pad(value, ((0, 0), (0, 0), (0, next_multiple_of_128 - head_size)))
6981

70-
current[split_key_without_weight[-1]] = nnx.Param(value)
82+
current[_get_nnx_key_name(split_key)] = nnx.Param(value)
7183
return nnx.State(nested_state_dict)
7284

7385

@@ -89,31 +101,46 @@ def levanter_state_dict_to_nnx_state_on_cpu(state_dict: dict) -> dict:
89101
current[part] = {}
90102
current = current[part]
91103

92-
# for q, k, v projections, we need to pad the 2nd dimension to next multiple of 128
93-
# vLLM expects the weights to be padded to the next multiple of 128. I assume this is
94-
# because they want to use Pallas kernels which have this requirement.
104+
# vLLM requires weights/biases to be padded to the nearest multiple of 128 for Pallas kernels.
95105
if "self_attn" in split_key_without_weight:
106+
is_bias = split_key[-1] == "bias"
107+
108+
# Flatten grouped query heads -> (Total Heads, Head Dim, [Embed]) for vLLM
96109
if "q_proj" in split_key_without_weight:
97-
kv_heads, q_heads_per_group, head_size, embed = value.shape
98-
value = value.reshape(kv_heads * q_heads_per_group, head_size, embed)
110+
if len(value.shape) == 4:
111+
# Weight: (KV, Group, HeadSize, Embed) -> (Heads, HeadSize, Embed)
112+
kv_heads, q_heads_per_group, head_size, embed = value.shape
113+
value = value.reshape(kv_heads * q_heads_per_group, head_size, embed)
114+
elif len(value.shape) == 3 and is_bias:
115+
# Bias: (KV, Group, HeadSize) -> (Heads, HeadSize)
116+
kv_heads, q_heads_per_group, head_size = value.shape
117+
value = value.reshape(kv_heads * q_heads_per_group, head_size)
99118

119+
# Pad the head dimension (dim 1) for Q/K/V projections
100120
if (
101121
"q_proj" in split_key_without_weight
102122
or "k_proj" in split_key_without_weight
103123
or "v_proj" in split_key_without_weight
104124
):
105-
_heads, head_size, embed = value.shape
106-
next_multiple_of_128 = ((head_size + 127) // 128) * 128
107-
if head_size < next_multiple_of_128:
108-
# pad 2nd dimension to 128 (e.g., (8, 64, 2048) -> (8, 128, 2048))
109-
value = jnp.pad(value, ((0, 0), (0, next_multiple_of_128 - head_size), (0, 0)))
125+
pad_axis = 1
126+
if len(value.shape) >= 2:
127+
head_size = value.shape[pad_axis]
128+
next_multiple_of_128 = ((head_size + 127) // 128) * 128
129+
130+
if head_size < next_multiple_of_128:
131+
padding = [(0, 0)] * len(value.shape)
132+
padding[pad_axis] = (0, next_multiple_of_128 - head_size)
133+
value = jnp.pad(value, padding)
134+
135+
# Pad o_proj weights along the head dimension (dim 2)
110136
elif "o_proj" in split_key_without_weight:
111-
embed, _heads, head_size = value.shape
112-
next_multiple_of_128 = ((head_size + 127) // 128) * 128
113-
if head_size < next_multiple_of_128:
114-
# pad 3rd dimension to 128 (e.g., (8, 2048, 64) -> (8, 2048, 128))
115-
value = jnp.pad(value, ((0, 0), (0, 0), (0, next_multiple_of_128 - head_size)))
137+
# Weight: (Embed, Heads, HeadSize). Skip bias as it is 1D (Embed,) or handled differently.
138+
if not is_bias and len(value.shape) == 3:
139+
embed, _heads, head_size = value.shape
140+
next_multiple_of_128 = ((head_size + 127) // 128) * 128
141+
if head_size < next_multiple_of_128:
142+
value = jnp.pad(value, ((0, 0), (0, 0), (0, next_multiple_of_128 - head_size)))
116143

117-
current[split_key_without_weight[-1]] = nnx.Param(value)
144+
current[_get_nnx_key_name(split_key)] = nnx.Param(value)
118145

119146
return nnx.State(nested_state_dict)

0 commit comments

Comments
 (0)