
Commit af90a71

lc5211 authored and The tunix Authors committed
[Tunix] Add weight mapping to vLLM tochax implementation for gemma4 models.
PiperOrigin-RevId: 896171461
1 parent 9bb12d3

7 files changed

Lines changed: 565 additions & 23 deletions


tests/generate/utils_test.py

Lines changed: 158 additions & 2 deletions
@@ -1403,19 +1403,175 @@ def test_sglang_jax_1d_kv_bias_alignment(self):
     expected = jnp.tile(src_k_bias, 8)
     self.assertTrue(jnp.allclose(result.params[src_key], expected))
 
+  def test_is_fused_path(self):
+    self.assertTrue(
+        utils.is_fused_path(
+            "vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight"
+        )
+    )
+    self.assertTrue(
+        utils.is_fused_path(
+            "vllm_model.language_model.model.layers.10.mlp.gate_up_proj.weight"
+        )
+    )
+    self.assertFalse(
+        utils.is_fused_path("model.layers.0.self_attn.q_proj.weight")
+    )
+
+  def test_fuse_src_to_same_tgt_params_qkv(self):
+    tgt_path = (
+        "vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight"
+    )
+    q_val = jnp.ones((8, 16, 64)) * 1.0  # (num_heads, d_model, head_dim)
+    kv_val = (
+        jnp.ones((2, 2, 16, 64)) * 2.0
+    )  # (2, num_kv_heads, d_model, head_dim)
+
+    fuse_sources = {}
+    # First call with Q.
+    utils.fuse_src_to_same_tgt_params(
+        q_val,
+        "layers.0.attn.q_einsum.w",
+        fuse_sources,
+        tgt_path,
+        None,
+        tp_size=1,
+    )
+    self.assertLen(fuse_sources[tgt_path], 1)
+
+    # Second call with KV.
+    utils.fuse_src_to_same_tgt_params(
+        kv_val,
+        "layers.0.attn.kv_einsum.w",
+        fuse_sources,
+        tgt_path,
+        None,
+        tp_size=1,
+    )
+
+    # Should be fused now.
+    self.assertLen(fuse_sources[tgt_path], 1)
+    fused_key = "layers.0.attn.qkv_fused"
+    self.assertIn(fused_key, fuse_sources[tgt_path])
+
+    fused_val = fuse_sources[tgt_path][fused_key][0]
+    # Expected shape: ((num_heads + 2 * num_kv_heads) * head_dim, d_model).
+    # Q and KV are laid out as (d_model, heads, head_dim), concatenated
+    # along the heads axis, flattened, then transposed:
+    # q: (8, 16, 64) -> (16, 8, 64)
+    # kv: (2, 2, 16, 64) -> (16, 2, 2, 64) -> (16, 4, 64)
+    # concat(q, k, v) -> (16, 12, 64) -> (16, 768)
+    # transpose -> (768, 16)
+    self.assertEqual(fused_val.shape, (768, 16))
+    self.assertTrue(jnp.allclose(fused_val[:512, :], 1.0))  # Q
+    self.assertTrue(jnp.allclose(fused_val[512:, :], 2.0))  # KV
+
+  def test_fuse_src_to_same_tgt_params_gate_up(self):
+    tgt_path = (
+        "vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight"
+    )
+    gate_val = jnp.ones((16, 32)) * 3.0  # (d_model, hidden)
+    up_val = jnp.ones((16, 32)) * 4.0  # (d_model, hidden)
+
+    fuse_sources = {}
+    utils.fuse_src_to_same_tgt_params(
+        gate_val,
+        "layers.0.mlp.gate_proj.kernel",
+        fuse_sources,
+        tgt_path,
+        None,
+        tp_size=1,
+    )
+    utils.fuse_src_to_same_tgt_params(
+        up_val,
+        "layers.0.mlp.up_proj.kernel",
+        fuse_sources,
+        tgt_path,
+        None,
+        tp_size=1,
+    )
+
+    fused_key = "layers.0.mlp.gate_up_fused"
+    fused_val = fuse_sources[tgt_path][fused_key][0]
+
+    # hidden=32, tp_size=1. Gate and up are stacked:
+    # (2 * hidden, d_model) = (64, 16)
+    self.assertEqual(fused_val.shape, (64, 16))
+    self.assertTrue(jnp.allclose(fused_val[0:32, :], 3.0))  # gate
+    self.assertTrue(jnp.allclose(fused_val[32:64, :], 4.0))  # up
+
+  def test_fuse_src_to_same_tgt_params_gate_up_tp2(self):
+    tgt_path = (
+        "vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight"
+    )
+    gate_val = jnp.ones((16, 32)) * 3.0
+    up_val = jnp.ones((16, 32)) * 4.0
+
+    fuse_sources = {}
+    utils.fuse_src_to_same_tgt_params(
+        gate_val,
+        "layers.0.mlp.gate_proj.kernel",
+        fuse_sources,
+        tgt_path,
+        None,
+        tp_size=2,
+    )
+    utils.fuse_src_to_same_tgt_params(
+        up_val,
+        "layers.0.mlp.up_proj.kernel",
+        fuse_sources,
+        tgt_path,
+        None,
+        tp_size=2,
+    )
+
+    fused_val = fuse_sources[tgt_path]["layers.0.mlp.gate_up_fused"][0]
+
+    # hidden=32, tp_size=2, so chunk_size=16. Chunks are interleaved
+    # per shard: [gate[0:16], up[0:16], gate[16:32], up[16:32]].
+    self.assertEqual(fused_val.shape, (64, 16))
+    self.assertTrue(jnp.allclose(fused_val[0:16, :], 3.0))
+    self.assertTrue(jnp.allclose(fused_val[16:32, :], 4.0))
+    self.assertTrue(jnp.allclose(fused_val[32:48, :], 3.0))
+    self.assertTrue(jnp.allclose(fused_val[48:64, :], 4.0))
+
+  def test_align_shape_moe_gating_einsum(self):
+    val = jnp.ones((2, 2, 128, 16))
+    src_key = "layers.0.moe.gating_einsum"
+    tgt_shape = (2, 16, 256)
+
+    result = utils._align_shape(val, tgt_shape, src_key, tp_size=1)
+    self.assertEqual(result.shape, (2, 16, 256))
+    self.assertTrue(jnp.allclose(result, 1.0))
+
+    # Test with padding.
+    val_small = jnp.ones((2, 2, 100, 16))
+    # chunk_size = 100, padded to 128, so pad_amount = 28; the result
+    # shape should still be (2, 16, 256).
+    result_padded = utils._align_shape(
+        val_small, (2, 16, 256), src_key, tp_size=1
+    )
+    self.assertEqual(result_padded.shape, (2, 16, 256))
+    # Check that the padded area is 0:
+    # gate_chunks (2, 1, 100, 16) -> pad -> (2, 1, 128, 16) -> stack ->
+    # (2, 1, 2, 128, 16) -> reshape (2, 256, 16) -> transpose (2, 16, 256).
+    # The first 100 of the first 128 entries should be 1; the last 28
+    # should be 0.
+    np.testing.assert_array_equal(result_padded[0, 0, 100:128], 0.0)
+    np.testing.assert_array_equal(result_padded[0, 0, 0:100], 1.0)
 
   def test_transfer_state_directly_fuses_moe_weights(self):
     """Tests that wi_0 and wi_1 are fused into wi when target expects it."""
     wi_0_val = jnp.array([[1.0, 2.0], [5.0, 6.0]], dtype=jnp.float32)
     wi_1_val = jnp.array([[3.0, 4.0], [7.0, 8.0]], dtype=jnp.float32)
-
+
     src_state = nnx.Dict(
         layers=nnx.Dict(
             wi_0=nnx.Param(wi_0_val),
             wi_1=nnx.Param(wi_1_val),
         )
     )
-
+
     dst_state = nnx.Dict(
         layers=nnx.Dict(
             wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))
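
For readers tracing the expected QKV layout in the test above, here is a minimal standalone sketch of the fused arrangement it asserts (plain jax.numpy; fuse_qkv is an illustrative helper written for this note, not the utils.fuse_src_to_same_tgt_params API):

import jax.numpy as jnp

def fuse_qkv(q, kv):
  # q: (num_heads, d_model, head_dim); kv: (2, num_kv_heads, d_model, head_dim).
  d_model, head_dim = q.shape[1], q.shape[2]
  q_t = jnp.transpose(q, (1, 0, 2))             # (d_model, num_heads, head_dim)
  kv_t = jnp.transpose(kv, (2, 0, 1, 3))        # (d_model, 2, num_kv, head_dim)
  kv_t = kv_t.reshape(d_model, -1, head_dim)    # (d_model, 2 * num_kv, head_dim)
  fused = jnp.concatenate([q_t, kv_t], axis=1)  # (d_model, num_heads + 2 * num_kv, head_dim)
  return fused.reshape(d_model, -1).T           # ((num_heads + 2 * num_kv) * head_dim, d_model)

fused = fuse_qkv(jnp.ones((8, 16, 64)), jnp.full((2, 2, 16, 64), 2.0))
assert fused.shape == (768, 16)
assert bool(jnp.all(fused[:512] == 1.0)) and bool(jnp.all(fused[512:] == 2.0))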

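Likewise, the tp-aware gate/up layout checked by the tp_size=2 test can be sketched as a per-shard interleave (again an illustrative helper name, not Tunix code):

import jax.numpy as jnp

def interleave_gate_up(gate, up, tp_size):
  # gate, up: (d_model, hidden). Each tensor-parallel shard gets its gate
  # chunk followed by its up chunk, so chunks alternate gate/up per shard.
  hidden = gate.shape[1]
  chunk = hidden // tp_size
  parts = []
  for i in range(tp_size):
    parts.append(gate[:, i * chunk:(i + 1) * chunk])
    parts.append(up[:, i * chunk:(i + 1) * chunk])
  return jnp.concatenate(parts, axis=1).T  # (2 * hidden, d_model)

fused = interleave_gate_up(
    jnp.full((16, 32), 3.0), jnp.full((16, 32), 4.0), tp_size=2
)
assert fused.shape == (64, 16)
assert bool(jnp.all(fused[0:16] == 3.0)) and bool(jnp.all(fused[16:32] == 4.0))
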
tunix/generate/mappings.py

Lines changed: 10 additions & 0 deletions
@@ -43,6 +43,15 @@ def to_hf_mappings(cls, backend: str | None = None):
       )
     return mapping
 
+  @classmethod
+  def key_reference_mappings(cls, backend: str | None = None):
+    mapping = cls.mapping_for(backend).get('key_reference_mappings')
+    if mapping is None:
+      raise RuntimeError(
+          f'{backend} key_reference_mappings missing for {cls.__name__}.'
+      )
+    return mapping
+
   @classmethod
   def lora_to_hf_mappings(cls, backend: str | None = None):
     return cls.mapping_for(backend).get('lora_to_hf_mappings')
@@ -73,6 +82,7 @@ class MappingConfig:
   """
 
   to_hf_mappings: Optional[Dict[str, Any]] = None
+  key_reference_mappings: Optional[Dict[str, Any]] = None
   lora_to_hf_mappings: Optional[Dict[str, Any]] = None
   to_hf_hook_fns: Optional[Dict[str, Any]] = None
   to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
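
A small usage sketch of the new accessor (DemoMapping and its mapping_for payload are hypothetical stand-ins; only the body of key_reference_mappings mirrors the diff above):

from typing import Any, Dict

class DemoMapping:

  @classmethod
  def mapping_for(cls, backend: str | None) -> Dict[str, Any]:
    # A real model-family mapping class would return per-backend mappings
    # here; this payload is illustrative only.
    return {
        'key_reference_mappings': {
            'layers.0.attn.q_einsum.w': 'self_attn.qkv_proj.weight',
        },
    }

  @classmethod
  def key_reference_mappings(cls, backend: str | None = None):
    # Same behavior as the accessor added above: a missing mapping is a
    # hard error rather than a silent None.
    mapping = cls.mapping_for(backend).get('key_reference_mappings')
    if mapping is None:
      raise RuntimeError(
          f'{backend} key_reference_mappings missing for {cls.__name__}.'
      )
    return mapping

print(DemoMapping.key_reference_mappings('vllm'))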
