@@ -1403,19 +1403,175 @@ def test_sglang_jax_1d_kv_bias_alignment(self):
    expected = jnp.tile(src_k_bias, 8)
    self.assertTrue(jnp.allclose(result.params[src_key], expected))

  def test_is_fused_path(self):
    self.assertTrue(
        utils.is_fused_path(
            "vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight"
        )
    )
    self.assertTrue(
        utils.is_fused_path(
            "vllm_model.language_model.model.layers.10.mlp.gate_up_proj.weight"
        )
    )
    self.assertFalse(
        utils.is_fused_path("model.layers.0.self_attn.q_proj.weight")
    )

  def test_fuse_src_to_same_tgt_params_qkv(self):
    tgt_path = (
        "vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight"
    )
    q_val = jnp.ones((8, 16, 64)) * 1.0  # (num_heads, d_model, head_dim)
    kv_val = (
        jnp.ones((2, 2, 16, 64)) * 2.0
    )  # (2, num_kv_heads, d_model, head_dim)

    fuse_sources = {}
    # First call with Q.
    utils.fuse_src_to_same_tgt_params(
        q_val,
        "layers.0.attn.q_einsum.w",
        fuse_sources,
        tgt_path,
        None,
        tp_size=1,
    )
    self.assertLen(fuse_sources[tgt_path], 1)

    # Second call with KV.
    utils.fuse_src_to_same_tgt_params(
        kv_val,
        "layers.0.attn.kv_einsum.w",
        fuse_sources,
        tgt_path,
        None,
        tp_size=1,
    )

    # Should be fused now.
    self.assertLen(fuse_sources[tgt_path], 1)
    fused_key = "layers.0.attn.qkv_fused"
    self.assertIn(fused_key, fuse_sources[tgt_path])

    fused_val = fuse_sources[tgt_path][fused_key][0]
    # Expected shape is ((num_heads + 2 * num_kv_heads) * head_dim, d_model):
    # q: (8, 16, 64) -> (16, 8, 64)
    # kv: (2, 2, 16, 64) -> (16, 2, 2, 64) -> (16, 4, 64)
    # concat(q, k, v) -> (16, 12, 64) -> (16, 768)
    # transpose -> (768, 16)
    self.assertEqual(fused_val.shape, (768, 16))
    self.assertTrue(jnp.allclose(fused_val[:512, :], 1.0))  # Q
    self.assertTrue(jnp.allclose(fused_val[512:, :], 2.0))  # KV

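  # Reference sketch: rebuilds the fused QKV layout by hand with plain jnp
  # ops, assuming the transpose/concat/reshape order described in the
  # comments above. It does not call into utils; it only mirrors the shape
  # and the Q/KV regions asserted there.
  def test_qkv_fused_reference_layout_sketch(self):
    q = jnp.ones((8, 16, 64)) * 1.0  # (num_heads, d_model, head_dim)
    kv = jnp.ones((2, 2, 16, 64)) * 2.0  # (2, num_kv_heads, d_model, head_dim)
    q_t = jnp.transpose(q, (1, 0, 2))  # (d_model, num_heads, head_dim)
    kv_t = jnp.transpose(kv, (2, 0, 1, 3)).reshape(16, 4, 64)  # (16, 4, 64)
    fused = jnp.concatenate([q_t, kv_t], axis=1).reshape(16, -1).T  # (768, 16)
    self.assertEqual(fused.shape, (768, 16))
    self.assertTrue(jnp.allclose(fused[:512, :], 1.0))  # Q block
    self.assertTrue(jnp.allclose(fused[512:, :], 2.0))  # KV block
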
  def test_fuse_src_to_same_tgt_params_gate_up(self):
    tgt_path = (
        "vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight"
    )
    gate_val = jnp.ones((16, 32)) * 3.0  # (d_model, hidden)
    up_val = jnp.ones((16, 32)) * 4.0  # (d_model, hidden)

    fuse_sources = {}
    utils.fuse_src_to_same_tgt_params(
        gate_val,
        "layers.0.mlp.gate_proj.kernel",
        fuse_sources,
        tgt_path,
        None,
        tp_size=1,
    )
    utils.fuse_src_to_same_tgt_params(
        up_val,
        "layers.0.mlp.up_proj.kernel",
        fuse_sources,
        tgt_path,
        None,
        tp_size=1,
    )

    fused_key = "layers.0.mlp.gate_up_fused"
    fused_val = fuse_sources[tgt_path][fused_key][0]

    # Hidden=32, tp_size=1. Gate and up are stacked:
    # (2 * hidden, d_model) = (64, 16)
    self.assertEqual(fused_val.shape, (64, 16))
    self.assertTrue(jnp.allclose(fused_val[0:32, :], 3.0))  # gate
    self.assertTrue(jnp.allclose(fused_val[32:64, :], 4.0))  # up

  def test_fuse_src_to_same_tgt_params_gate_up_tp2(self):
    tgt_path = (
        "vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight"
    )
    gate_val = jnp.ones((16, 32)) * 3.0
    up_val = jnp.ones((16, 32)) * 4.0

    fuse_sources = {}
    utils.fuse_src_to_same_tgt_params(
        gate_val,
        "layers.0.mlp.gate_proj.kernel",
        fuse_sources,
        tgt_path,
        None,
        tp_size=2,
    )
    utils.fuse_src_to_same_tgt_params(
        up_val,
        "layers.0.mlp.up_proj.kernel",
        fuse_sources,
        tgt_path,
        None,
        tp_size=2,
    )

    fused_val = fuse_sources[tgt_path]["layers.0.mlp.gate_up_fused"][0]

    # Hidden=32, tp_size=2, chunk_size=16. The chunks are interleaved as
    # [gate[0:16], up[0:16], gate[16:32], up[16:32]].
    self.assertEqual(fused_val.shape, (64, 16))
    self.assertTrue(jnp.allclose(fused_val[0:16, :], 3.0))
    self.assertTrue(jnp.allclose(fused_val[16:32, :], 4.0))
    self.assertTrue(jnp.allclose(fused_val[32:48, :], 3.0))
    self.assertTrue(jnp.allclose(fused_val[48:64, :], 4.0))

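  # Reference sketch: rebuilds the tp_size=2 interleaving by hand, assuming
  # each tensor-parallel shard keeps its gate chunk followed by its up chunk.
  # It does not call into utils; it only mirrors the slices asserted above.
  def test_gate_up_tp2_reference_layout_sketch(self):
    gate = jnp.ones((16, 32)) * 3.0  # (d_model, hidden)
    up = jnp.ones((16, 32)) * 4.0  # (d_model, hidden)
    tp_size = 2
    chunk = 32 // tp_size  # 16 hidden columns per shard
    parts = []
    for i in range(tp_size):
      parts.append(gate[:, i * chunk:(i + 1) * chunk])
      parts.append(up[:, i * chunk:(i + 1) * chunk])
    fused = jnp.concatenate(parts, axis=1).T  # (2 * hidden, d_model) = (64, 16)
    self.assertEqual(fused.shape, (64, 16))
    for start, expected in ((0, 3.0), (16, 4.0), (32, 3.0), (48, 4.0)):
      self.assertTrue(jnp.allclose(fused[start:start + 16, :], expected))
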
  def test_align_shape_moe_gating_einsum(self):
    val = jnp.ones((2, 2, 128, 16))
    src_key = "layers.0.moe.gating_einsum"
    tgt_shape = (2, 16, 256)

    result = utils._align_shape(val, tgt_shape, src_key, tp_size=1)
    self.assertEqual(result.shape, (2, 16, 256))
    self.assertTrue(jnp.allclose(result, 1.0))

    # Test with padding.
    val_small = jnp.ones((2, 2, 100, 16))
    # chunk_size = 100, padded = 128, pad_amount = 28; the result shape
    # should still be (2, 16, 256).
    result_padded = utils._align_shape(
        val_small, (2, 16, 256), src_key, tp_size=1
    )
    self.assertEqual(result_padded.shape, (2, 16, 256))
    # Check that the padded area is 0:
    # gate chunks (2, 1, 100, 16) -> pad -> (2, 1, 128, 16) -> stack ->
    # (2, 1, 2, 128, 16) -> reshape -> (2, 256, 16) -> transpose ->
    # (2, 16, 256). The first 100 of the first 128 should be 1 and the
    # last 28 should be 0.
    np.testing.assert_array_equal(result_padded[0, 0, 100:128], 0.0)
    np.testing.assert_array_equal(result_padded[0, 0, 0:100], 1.0)

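  # Reference sketch: one assumed reading of the pad-and-stack layout for the
  # gating einsum (split along axis 1, zero-pad 100 -> 128, flatten the chunk
  # axis, and move d_model in front of it). It does not call _align_shape; it
  # only mirrors the padded/unpadded regions asserted above.
  def test_moe_gating_reference_layout_sketch(self):
    val = jnp.ones((2, 2, 100, 16))
    chunks = [val[:, i] for i in range(2)]  # two arrays of (2, 100, 16)
    padded = [
        jnp.pad(c, ((0, 0), (0, 28), (0, 0))) for c in chunks
    ]  # each (2, 128, 16)
    stacked = jnp.stack(padded, axis=1)  # (2, 2, 128, 16)
    result = jnp.transpose(stacked.reshape(2, 256, 16), (0, 2, 1))  # (2, 16, 256)
    self.assertEqual(result.shape, (2, 16, 256))
    np.testing.assert_array_equal(result[0, 0, 0:100], 1.0)
    np.testing.assert_array_equal(result[0, 0, 100:128], 0.0)
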
  def test_transfer_state_directly_fuses_moe_weights(self):
    """Tests that wi_0 and wi_1 are fused into wi when target expects it."""
    wi_0_val = jnp.array([[1.0, 2.0], [5.0, 6.0]], dtype=jnp.float32)
    wi_1_val = jnp.array([[3.0, 4.0], [7.0, 8.0]], dtype=jnp.float32)

    src_state = nnx.Dict(
        layers=nnx.Dict(
            wi_0=nnx.Param(wi_0_val),
            wi_1=nnx.Param(wi_1_val),
        )
    )

    dst_state = nnx.Dict(
        layers=nnx.Dict(
            wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))