diff --git a/tests/generate/utils_test.py b/tests/generate/utils_test.py index f14ff5941..1d51148c6 100644 --- a/tests/generate/utils_test.py +++ b/tests/generate/utils_test.py @@ -1290,6 +1290,355 @@ def test_transfer_state_directly_implicit_layers_container(self): jnp.array(200.0), ) + def test_transfer_state_directly_gemma4_scanned_blocks(self): + """Tests Gemma4 (MaxText) source layout: scanned_blocks wrapper + size-1 scan dim. + + Gemma4 in MaxText stores per-layer params under + `decoder.scanned_blocks.layers_X.` and keeps the scan axis as a + size-1 dim inside each chunk (e.g. wi_0 shape `(experts, 1, embed, mlp)`). + The destination drops both the wrapper and the size-1 dim. + """ + # Source: per-layer chunks under scanned_blocks with size-1 scan dim at + # axis 1 (MaxText's default scan_axis). + src_state = nnx.Dict( + base=nnx.Dict( + decoder=nnx.Dict( + decoder_norm=nnx.Dict( + scale=nnx.Param(jnp.array([1.0, 2.0, 3.0])) + ), + scanned_blocks=nnx.Dict( + layers_0=nnx.Dict( + mlp=nnx.Dict( + wo=nnx.Param( + jnp.arange(12, dtype=jnp.float32) + .reshape(2, 1, 2, 3) + ) + ), + norm=nnx.Dict( + scale=nnx.Param(jnp.array([[10.0], [11.0], [12.0]])) + ), + ), + layers_1=nnx.Dict( + mlp=nnx.Dict( + wo=nnx.Param( + (jnp.arange(12, dtype=jnp.float32) + 100) + .reshape(2, 1, 2, 3) + ) + ), + norm=nnx.Dict( + scale=nnx.Param(jnp.array([[20.0], [21.0], [22.0]])) + ), + ), + ), + ) + ) + ) + + # Destination: no scanned_blocks wrapper and no size-1 scan dim. + dst_state = nnx.Dict( + decoder=nnx.Dict( + decoder_norm=nnx.Dict( + scale=nnx.Param(jnp.zeros((3,), dtype=jnp.float32)) + ), + layers_0=nnx.Dict( + mlp=nnx.Dict( + wo=nnx.Param(jnp.zeros((2, 2, 3), dtype=jnp.float32)) + ), + norm=nnx.Dict( + scale=nnx.Param(jnp.zeros((3,), dtype=jnp.float32)) + ), + ), + layers_1=nnx.Dict( + mlp=nnx.Dict( + wo=nnx.Param(jnp.zeros((2, 2, 3), dtype=jnp.float32)) + ), + norm=nnx.Dict( + scale=nnx.Param(jnp.zeros((3,), dtype=jnp.float32)) + ), + ), + ) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly( + src_state, dst_state, reshard_fn=mock_reshard, scan_axis=1 + ) + + # Non-layer param under decoder transfers unchanged. + np.testing.assert_array_equal( + dst_state['decoder']['decoder_norm']['scale'][...], + jnp.array([1.0, 2.0, 3.0]), + ) + # Each layer chunk is squeezed along the size-1 scan_axis. + np.testing.assert_array_equal( + dst_state['decoder']['layers_0']['mlp']['wo'][...], + jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3), + ) + np.testing.assert_array_equal( + dst_state['decoder']['layers_0']['norm']['scale'][...], + jnp.array([10.0, 11.0, 12.0]), + ) + np.testing.assert_array_equal( + dst_state['decoder']['layers_1']['mlp']['wo'][...], + (jnp.arange(12, dtype=jnp.float32) + 100).reshape(2, 2, 3), + ) + np.testing.assert_array_equal( + dst_state['decoder']['layers_1']['norm']['scale'][...], + jnp.array([20.0, 21.0, 22.0]), + ) + + def test_transfer_state_directly_gemma4_multi_chunk_scanned_blocks(self): + """Tests Gemma4 (MaxText) source layout where layers are split across + multiple scanned chunks (e.g. `layers_0` for layers [0,1], `layers_2` for + layers [2,3]) — each chunk has `chunk_size` consecutive layers stacked at + `scan_axis`. The destination is fully unrolled. + """ + # Source: two chunks of 2 layers each, scan dim of size 2 at axis 1. + # `wo` per-layer shape is (2, 3); chunks store (2, 2, 3). + # `layer_scalar` per-layer shape is (1,); chunks store (1, 2). 
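+    # Worked example of the consecutive chunk mapping exercised below: global
+    # layer 3 lives in chunk `layers_2` at within-chunk slot 3 - 2 == 1, so
+    # its `wo` comes from that chunk's [:, 1, :] slice.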
+ src_state = nnx.Dict( + base=nnx.Dict( + decoder=nnx.Dict( + scanned_blocks=nnx.Dict( + layers_0=nnx.Dict( + mlp=nnx.Dict( + wo=nnx.Param( + jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3) + ) + ), + layer_scalar=nnx.Param(jnp.array([[1.0, 2.0]])), + ), + layers_2=nnx.Dict( + mlp=nnx.Dict( + wo=nnx.Param( + (jnp.arange(12, dtype=jnp.float32) + 100) + .reshape(2, 2, 3) + ) + ), + layer_scalar=nnx.Param(jnp.array([[3.0, 4.0]])), + ), + ), + ) + ) + ) + + # Destination: 4 fully-unrolled layers, no scanned_blocks wrapper, no + # scan dim. + dst_state = nnx.Dict( + decoder=nnx.Dict( + layers_0=nnx.Dict( + mlp=nnx.Dict(wo=nnx.Param(jnp.zeros((2, 3), dtype=jnp.float32))), + layer_scalar=nnx.Param(jnp.zeros((1,), dtype=jnp.float32)), + ), + layers_1=nnx.Dict( + mlp=nnx.Dict(wo=nnx.Param(jnp.zeros((2, 3), dtype=jnp.float32))), + layer_scalar=nnx.Param(jnp.zeros((1,), dtype=jnp.float32)), + ), + layers_2=nnx.Dict( + mlp=nnx.Dict(wo=nnx.Param(jnp.zeros((2, 3), dtype=jnp.float32))), + layer_scalar=nnx.Param(jnp.zeros((1,), dtype=jnp.float32)), + ), + layers_3=nnx.Dict( + mlp=nnx.Dict(wo=nnx.Param(jnp.zeros((2, 3), dtype=jnp.float32))), + layer_scalar=nnx.Param(jnp.zeros((1,), dtype=jnp.float32)), + ), + ) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly( + src_state, dst_state, reshard_fn=mock_reshard, scan_axis=1 + ) + + # layers_0 direct-matches src layers_0 chunk; within-chunk slot 0. + np.testing.assert_array_equal( + dst_state['decoder']['layers_0']['mlp']['wo'][...], + jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3)[:, 0, :], + ) + np.testing.assert_array_equal( + dst_state['decoder']['layers_0']['layer_scalar'][...], + jnp.array([1.0]), + ) + # layers_1 falls into scanned-candidate path; src layers_0 chunk, slot 1. + np.testing.assert_array_equal( + dst_state['decoder']['layers_1']['mlp']['wo'][...], + jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3)[:, 1, :], + ) + np.testing.assert_array_equal( + dst_state['decoder']['layers_1']['layer_scalar'][...], + jnp.array([2.0]), + ) + # layers_2 direct-matches src layers_2 chunk; within-chunk slot 0. + np.testing.assert_array_equal( + dst_state['decoder']['layers_2']['mlp']['wo'][...], + (jnp.arange(12, dtype=jnp.float32) + 100).reshape(2, 2, 3)[:, 0, :], + ) + np.testing.assert_array_equal( + dst_state['decoder']['layers_2']['layer_scalar'][...], + jnp.array([3.0]), + ) + # layers_3 falls into scanned-candidate path; src layers_2 chunk, slot 1. + np.testing.assert_array_equal( + dst_state['decoder']['layers_3']['mlp']['wo'][...], + (jnp.arange(12, dtype=jnp.float32) + 100).reshape(2, 2, 3)[:, 1, :], + ) + np.testing.assert_array_equal( + dst_state['decoder']['layers_3']['layer_scalar'][...], + jnp.array([4.0]), + ) + + def test_transfer_state_directly_gemma4_moe_interleaved_chunks(self): + """Gemma4 MaxText MoE: per-slot `layers_K` chunks with `num_reps>1` along + `scan_axis`. Target layer X maps to (slot=X%n, rep=X//n) where n is the + number of chunk keys (NUM_SLOTS). Within `layers_{slot}` the scan-axis + index is `rep`. Source `wi_0`/`wi_1` fuse into target `wi`. + """ + num_slots = 3 + num_reps = 2 + num_experts = 2 + embed = 2 + inner = 2 + total_layers = num_slots * num_reps + + # Build distinct values per (slot, rep, experts, embed, inner). 
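+    # Each chunk has shape (num_experts, num_reps, embed, inner); the `base`
+    # offset (slot * 1000) keeps slots distinguishable. e.g. with num_slots=3
+    # and num_reps=2, global layer 4 maps to slot 4 % 3 == 1 ('layers_1') at
+    # scan-axis index 4 // 3 == 1.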
+ def _chunk(base): + vals = jnp.arange( + num_experts * num_reps * embed * inner, dtype=jnp.float32 + ).reshape(num_experts, num_reps, embed, inner) + return vals + base + + src_state = nnx.Dict( + base=nnx.Dict( + decoder=nnx.Dict( + scanned_blocks=nnx.Dict(**{ + f'layers_{slot}': nnx.Dict( + mlp=nnx.Dict( + wi_0=nnx.Param(_chunk(slot * 1000.0)), + wi_1=nnx.Param(_chunk(slot * 1000.0 + 500.0)), + ) + ) + for slot in range(num_slots) + }), + ) + ) + ) + + dst_state = nnx.Dict( + decoder=nnx.Dict(**{ + f'layers_{i}': nnx.Dict( + mlp=nnx.Dict( + wi=nnx.Param( + jnp.zeros( + (num_experts, embed, 2 * inner), dtype=jnp.float32 + ) + ) + ) + ) + for i in range(total_layers) + }) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly( + src_state, dst_state, reshard_fn=mock_reshard, scan_axis=1 + ) + + for i in range(total_layers): + slot = i % num_slots + rep = i // num_slots + wi_0_full = _chunk(slot * 1000.0) + wi_1_full = _chunk(slot * 1000.0 + 500.0) + # Per-layer slice at scan_axis=1. + expected_wi_0 = wi_0_full[:, rep, :, :] + expected_wi_1 = wi_1_full[:, rep, :, :] + expected = jnp.concatenate([expected_wi_0, expected_wi_1], axis=-1) + np.testing.assert_array_equal( + dst_state['decoder'][f'layers_{i}']['mlp']['wi'][...], + expected, + ) + + def test_transfer_state_directly_gemma4_dense_interleaved_chunks(self): + """Gemma4 MaxText dense path: per-slot `layers_K` chunks with `num_reps>1` + for a non-MoE param. Exercises Candidate C (non-MoE scanned chunk) under + the interleaved layout where chunk K = slot and within-chunk index = rep. + """ + num_slots = 3 + num_reps = 2 + in_dim = 2 + out_dim = 3 + total_layers = num_slots * num_reps + + def _chunk(base): + # Shape (in_dim, num_reps, out_dim) — scan axis at position 1. 
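+      # e.g. in_dim=2, num_reps=2, out_dim=3 gives a (2, 2, 3) chunk whose
+      # per-layer slice is chunk[:, rep, :].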
+ return jnp.arange( + in_dim * num_reps * out_dim, dtype=jnp.float32 + ).reshape(in_dim, num_reps, out_dim) + base + + src_state = nnx.Dict( + base=nnx.Dict( + decoder=nnx.Dict( + scanned_blocks=nnx.Dict(**{ + f'layers_{slot}': nnx.Dict( + mlp=nnx.Dict(wo=nnx.Param(_chunk(slot * 1000.0))) + ) + for slot in range(num_slots) + }), + ) + ) + ) + + dst_state = nnx.Dict( + decoder=nnx.Dict(**{ + f'layers_{i}': nnx.Dict( + mlp=nnx.Dict( + wo=nnx.Param(jnp.zeros((in_dim, out_dim), dtype=jnp.float32)) + ) + ) + for i in range(total_layers) + }) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly( + src_state, dst_state, reshard_fn=mock_reshard, scan_axis=1 + ) + + for i in range(total_layers): + slot = i % num_slots + rep = i // num_slots + expected = _chunk(slot * 1000.0)[:, rep, :] + np.testing.assert_array_equal( + dst_state['decoder'][f'layers_{i}']['mlp']['wo'][...], + expected, + ) + + def test_transfer_state_directly_keeps_scanned_blocks_when_target_has_it(self): + """When destination also has scanned_blocks, source paths are not lifted.""" + src_state = nnx.Dict( + decoder=nnx.Dict( + scanned_blocks=nnx.Dict( + layers_0=nnx.Dict(weight=nnx.Param(jnp.array([1.0, 2.0]))), + ), + ) + ) + dst_state = nnx.Dict( + decoder=nnx.Dict( + scanned_blocks=nnx.Dict( + layers_0=nnx.Dict( + weight=nnx.Param(jnp.zeros((2,), dtype=jnp.float32)) + ), + ), + ) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly(src_state, dst_state, reshard_fn=mock_reshard) + + np.testing.assert_array_equal( + dst_state['decoder']['scanned_blocks']['layers_0']['weight'][...], + jnp.array([1.0, 2.0]), + ) + def test_transfer_state_directly_with_dtype_casting(self): """Tests that transfer_state_directly correctly casts dtypes (e.g., f32 to bf16).""" # Source state in float32 @@ -1588,6 +1937,133 @@ def test_transfer_state_directly_fuses_moe_weights_scanned_to_unrolled(self): jnp.concatenate([wi_0_val[:, 1, :], wi_1_val[:, 1, :]], axis=-1), ) + def test_transfer_state_directly_fuses_per_expert_scale_unscanned(self): + """fuse_expert_scales=True pre-multiplies `wo` by sibling `per_expert_scale` + on the source side (unscanned src/tgt). Mirrors MaxText's init-time + `wo *= per_expert_scale[:, None, None]` so that targets initialized with + `model_call_mode=='inference'` and `fuse_expert_scales=True` receive the + pre-fused product. `per_expert_scale` is still transferred to its + (now-vestigial) target slot. + """ + wo_val = jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3) + scale_val = jnp.array([2.0, 3.0], dtype=jnp.float32) + src_state = nnx.Dict( + moe=nnx.Dict( + wo=nnx.Param(wo_val), + per_expert_scale=nnx.Param(scale_val), + ) + ) + dst_state = nnx.Dict( + moe=nnx.Dict( + wo=nnx.Param(jnp.zeros((2, 2, 3), dtype=jnp.float32)), + per_expert_scale=nnx.Param(jnp.zeros((2,), dtype=jnp.float32)), + ) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly( + src_state, dst_state, reshard_fn=mock_reshard, + fuse_expert_scales=True, + ) + + expected_wo = wo_val * scale_val[:, None, None] + np.testing.assert_array_equal( + dst_state['moe']['wo'][...], expected_wo + ) + # Scale is still transferred to its (vestigial) slot. 
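+    # (Its value is untouched by the fusion, which only rewrites `wo`.)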
+ np.testing.assert_array_equal( + dst_state['moe']['per_expert_scale'][...], scale_val + ) + + def test_transfer_state_directly_per_expert_scale_disabled_by_default(self): + """Without the flag, `wo` and `per_expert_scale` transfer separately — + the new fusion is purely additive and off-by-default.""" + wo_val = jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3) + scale_val = jnp.array([2.0, 3.0], dtype=jnp.float32) + src_state = nnx.Dict( + moe=nnx.Dict( + wo=nnx.Param(wo_val), + per_expert_scale=nnx.Param(scale_val), + ) + ) + dst_state = nnx.Dict( + moe=nnx.Dict( + wo=nnx.Param(jnp.zeros((2, 2, 3), dtype=jnp.float32)), + per_expert_scale=nnx.Param(jnp.zeros((2,), dtype=jnp.float32)), + ) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly(src_state, dst_state, reshard_fn=mock_reshard) + + np.testing.assert_array_equal(dst_state['moe']['wo'][...], wo_val) + np.testing.assert_array_equal( + dst_state['moe']['per_expert_scale'][...], scale_val + ) + + def test_transfer_state_directly_fuses_per_expert_scale_scanned(self): + """Gemma4-style scanned src: per_expert_scale has shape (E, num_reps) and + wo has shape (E, num_reps, d_in, d_out) along scan_axis=1. Fusion happens + on the scanned source before the per-layer unrolling. + """ + num_experts = 2 + num_reps = 2 + d_in = 2 + d_out = 3 + wo_val = jnp.arange( + num_experts * num_reps * d_in * d_out, dtype=jnp.float32 + ).reshape(num_experts, num_reps, d_in, d_out) + scale_val = jnp.array( + [[2.0, 3.0], [4.0, 5.0]], dtype=jnp.float32 + ) # shape (E=2, L=2) + + src_state = nnx.Dict( + base=nnx.Dict( + decoder=nnx.Dict( + scanned_blocks=nnx.Dict( + layers_0=nnx.Dict( + moe=nnx.Dict( + wo=nnx.Param(wo_val), + per_expert_scale=nnx.Param(scale_val), + ) + ), + ) + ) + ) + ) + + dst_state = nnx.Dict( + decoder=nnx.Dict(**{ + f'layers_{i}': nnx.Dict( + moe=nnx.Dict( + wo=nnx.Param(jnp.zeros((num_experts, d_in, d_out), + dtype=jnp.float32)), + per_expert_scale=nnx.Param(jnp.zeros((num_experts,), + dtype=jnp.float32)), + ) + ) + for i in range(num_reps) + }) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly( + src_state, dst_state, reshard_fn=mock_reshard, + scan_axis=1, fuse_expert_scales=True, + ) + + # Each unrolled layer's wo == src_wo[:, rep] * src_scale[:, rep, None, None]. + for i in range(num_reps): + expected_wo = wo_val[:, i, :, :] * scale_val[:, i, None, None] + np.testing.assert_array_equal( + dst_state['decoder'][f'layers_{i}']['moe']['wo'][...], + expected_wo, + ) + np.testing.assert_array_equal( + dst_state['decoder'][f'layers_{i}']['moe']['per_expert_scale'][...], + scale_val[:, i], + ) + def test_transfer_state_directly_delete_dst_buffers_no_chunking(self): """delete_dst_buffers=True must never pass deleted arrays to reshard_fn.""" src_val = jnp.array([1.0, 2.0, 3.0]) diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py index 0e31f125c..ffe0df7cc 100644 --- a/tunix/generate/utils.py +++ b/tunix/generate/utils.py @@ -1192,6 +1192,40 @@ def _align_to_model_shape( return _align_per_axis(src_val, tgt_val.shape, tgt_sharding, key_path) +def _maybe_squeeze_scan_axis( + src_val: jax.Array | np.ndarray, + tgt_val: jax.Array | np.ndarray, + scan_axis: int, + key_path: str, +) -> jax.Array | np.ndarray: + """Drops a vestigial size-1 scan dim from src when target has one less rank. 
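+
+  For example (shapes only): src (2, 1, 2, 3) with scan_axis=1 against a
+  tgt of shape (2, 2, 3) squeezes to (2, 2, 3); any pair whose ranks or
+  scan-dim size don't fit this pattern is returned untouched.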
+ + Some MaxText layouts (notably Gemma4's `scanned_blocks` containers) keep a + size-1 scan dimension on every per-layer chunk even though each chunk + already corresponds to a single layer. The destination tensors do not carry + that axis, so we squeeze it out before the per-axis alignment helper sees + the rank mismatch. + """ + if not (hasattr(src_val, 'shape') and hasattr(tgt_val, 'shape')): + return src_val + if src_val.shape == tgt_val.shape: + return src_val + if len(src_val.shape) != len(tgt_val.shape) + 1: + return src_val + if scan_axis < 0 or scan_axis >= len(src_val.shape): + return src_val + if src_val.shape[scan_axis] != 1: + return src_val + logging.info( + 'Squeezing size-1 scan dim on %s: %s -> %s', + key_path, src_val.shape, + src_val.shape[:scan_axis] + src_val.shape[scan_axis + 1:], + ) + if hasattr(jnp, 'squeeze') and not isinstance(src_val, np.ndarray): + return jnp.squeeze(src_val, axis=scan_axis) + return np.squeeze(src_val, axis=scan_axis) + + def _bulk_align_and_unstack( arr: jax.Array | np.ndarray, scan_axis: int, @@ -1239,6 +1273,85 @@ def _bulk_align_and_unstack( return tuple(jnp.unstack(aligned, axis=scan_axis)) +def _resolve_scanned_chunk( + layer_idx: int, + candidates: List[Tuple[int, int]], +) -> Optional[Tuple[int, int]]: + """Maps a global target layer index to (chunk_key, within_chunk_idx). + + MaxText exposes two scanned-chunk layouts that share the `layers_K` naming + but differ in what `K` and the scan-axis index mean: + + * Consecutive (canonical MaxText `scanned_blocks`): keys are spaced + by `chunk_size`, i.e. {0, csz, 2*csz, ...}. Chunk `K` holds global + layers [K, K+csz). Within-chunk index = `layer_idx - K`. + + * Interleaved (Gemma4 in MaxText): keys are slot indices spanning + {0, 1, ..., n_slots-1} with `csz > 1`. Chunk `K` holds global layers + {K, K + n_slots, K + 2*n_slots, ...}, i.e. target layer + `rep * n_slots + slot` sits at slot=K, scan-axis index=rep. + + The two layouts are distinguishable by the key spacing whenever `csz > 1`. + When `csz == 1` (or only one chunk is present) both interpretations + coincide so the consecutive logic also returns the right answer. + + Args: + layer_idx: The global target layer index. + candidates: All sibling chunks visible at this prefix, each as + `(chunk_key_int, chunk_size_int)`. + + Returns: + `(chunk_key, within_chunk_idx)` if a chunk covers `layer_idx`, else + `None`. The caller is responsible for looking up the full source key + that corresponds to `chunk_key`. + """ + if not candidates: + return None + + # Rule 1: single chunk or every chunk has csz == 1 -> consecutive logic + # captures both interpretations (degenerate). + all_size_one = all(csz == 1 for _, csz in candidates) + if len(candidates) == 1 or all_size_one: + for k, csz in candidates: + if k <= layer_idx < k + csz: + return (k, layer_idx - k) + return None + + keys = sorted({k for k, _ in candidates}) + sizes = {csz for _, csz in candidates} + + # Common chunk size lets us reason about regular spacings. + if len(sizes) == 1: + csz = next(iter(sizes)) + n = len(keys) + # Rule 2: keys == [0, csz, 2*csz, ...] -> consecutive. + if keys == [i * csz for i in range(n)]: + target_key = (layer_idx // csz) * csz + if target_key in keys: + return (target_key, layer_idx - target_key) + return None + # Rule 3: keys == [0, 1, ..., n-1] with csz > 1 -> interleaved (Gemma4). 
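+    # e.g. keys=[0, 1, 2] with csz=2 cover six layers; target layer 4
+    # resolves to slot 4 % 3 == 1, scan-axis index 4 // 3 == 1.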
+ if csz > 1 and keys == list(range(n)): + slot = layer_idx % n + rep = layer_idx // n + if slot in keys and 0 <= rep < csz: + return (slot, rep) + return None + + # Rule 4: irregular/mixed layout. Defensive consecutive fallback. + logging.warning( + 'Irregular scanned-chunk layout: keys=%s sizes=%s; falling back to ' + 'consecutive interpretation.', keys, sorted(sizes), + ) + best_key = -1 + for k, csz in candidates: + if k <= layer_idx < k + csz and k > best_key: + best_key = k + if best_key < 0: + return None + return (best_key, layer_idx - best_key) + + def _scanned_sharding_from_per_layer( per_layer_sharding: Optional[jax.sharding.Sharding], scan_axis: int, @@ -1326,9 +1439,20 @@ def _fuse_moe_weights( wi_1_key = tgt_key[:-1] + ('wi_1',) if wi_0_key not in new_src_flat or wi_1_key not in new_src_flat: continue + tgt_val = tgt_flat[tgt_key] + wi_0_peek = new_src_flat[wi_0_key] + # Skip when the source has an extra dim (e.g. a baked-in scan_axis) that + # the per-layer target lacks. Eager fusion would silently mis-shape the + # result and crash on reshape; the scanned-source path in + # `intersect_trees` (via `_jit_fuse_and_unstack_moe`) handles that case. + if ( + hasattr(wi_0_peek, 'shape') + and hasattr(tgt_val, 'shape') + and len(wi_0_peek.shape) != len(tgt_val.shape) + ): + continue wi_0 = new_src_flat.pop(wi_0_key) wi_1 = new_src_flat.pop(wi_1_key) - tgt_val = tgt_flat[tgt_key] # Pick the fused axis as the last axis where src and tgt differ. For the # canonical wi_0/wi_1 -> wi case this is the last axis (the mlp dim). mismatched_axes = [ @@ -1348,6 +1472,73 @@ def _fuse_moe_weights( return new_src_flat +@jax.jit +def _multiply_wo_by_scale( + wo: jax.Array | np.ndarray, + per_expert_scale: jax.Array | np.ndarray, +) -> jax.Array | np.ndarray: + """`wo * per_expert_scale` with two trailing size-1 dims broadcast in. + + Works for any rank: per_expert_scale is reshaped to add (1, 1) at the end + so its shape matches `wo` in the leading dims and broadcasts across the + final two. For unscanned MoE: wo `(E, d_in, d_out)` * scale `(E,)` -> + scale reshaped `(E, 1, 1)`. For scanned: wo `(E, L, d_in, d_out)` * scale + `(E, L)` -> scale reshaped `(E, L, 1, 1)`. + """ + scale = jnp.reshape(per_expert_scale, per_expert_scale.shape + (1, 1)) + return (wo * scale.astype(wo.dtype)).astype(wo.dtype) + + +def _fuse_per_expert_scale_into_wo( + src_flat: Dict[Tuple[str, ...], jax.Array | np.ndarray], +) -> Dict[Tuple[str, ...], jax.Array | np.ndarray]: + """Pre-multiplies `wo` by its sibling `per_expert_scale` on the source side. + + Mirrors MaxText's `fuse_expert_scales=True` init-time optimization: when + inference targets bake `per_expert_scale` into `wo` at init time, the + trained source state must arrive already pre-fused, otherwise the sync + overwrites the target's fused product with raw `wo`. + + For each `(prefix, 'wo')` in `src_flat` with a sibling + `(prefix, 'per_expert_scale')`: + * Skip if shapes are incompatible (`wo.ndim - per_expert_scale.ndim` + must equal 2 — the two trailing MLP/embed dims broadcast over). + * Replace `src_flat[(prefix, 'wo')]` with the fused product. + * Leave `(prefix, 'per_expert_scale')` in place so it still transfers + to its (now-vestigial) target slot. + + Args: + src_flat: Flat dict of source key tuples to JAX arrays. + + Returns: + A new flat dict with `wo` entries pre-multiplied by `per_expert_scale` + wherever a matching sibling exists. 
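+
+  Example (shapes only): wo (E, d_in, d_out) with sibling scale (E,) fuses
+  via an (E, 1, 1) broadcast; scanned wo (E, L, d_in, d_out) needs scale
+  (E, L). Pairs whose rank gap is not 2 or whose leading dims disagree are
+  skipped.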
+ """ + new_src_flat = dict(src_flat) + for src_key in list(new_src_flat.keys()): + if not src_key or src_key[-1] != 'wo': + continue + scale_key = src_key[:-1] + ('per_expert_scale',) + if scale_key not in new_src_flat: + continue + wo = new_src_flat[src_key] + scale = new_src_flat[scale_key] + if ( + not hasattr(wo, 'shape') + or not hasattr(scale, 'shape') + or wo.ndim - scale.ndim != 2 + or wo.shape[: scale.ndim] != scale.shape + ): + continue + logging.info( + 'Fusing per_expert_scale into %s: wo=%s, scale=%s', + '.'.join(str(k) for k in src_key), + wo.shape, scale.shape, + ) + new_src_flat[src_key] = _multiply_wo_by_scale(wo, scale) + return new_src_flat + + def _collect_src_buffer_ids( src_flat: Mapping[Tuple[str, ...], jax.Array | np.ndarray | nnx.Variable], ) -> Optional[set[int]]: @@ -1492,6 +1683,7 @@ def transfer_state_directly( scan_axis: int = 1, delete_dst_buffers: bool = False, reshard_chunk_size: Optional[int] = None, + fuse_expert_scales: bool = False, ) -> None: """Transfers state directly by matching structure, stripping wrappers. @@ -1519,6 +1711,13 @@ def transfer_state_directly( start with roughly `10 * num_layers` for a dense transformer and tune downward if you still see fragmentation. When None (default) the original single-call reshard behavior is preserved. + fuse_expert_scales: When True, pre-multiplies MoE `wo` source tensors + by their sibling `per_expert_scale` before matching against the + target. Mirrors MaxText's `fuse_expert_scales=True` inference + optimization (the target's `wo` slot expects the product baked in). + `per_expert_scale` is still transferred normally to its target slot. + Off by default — set to True only when the target was initialized + with `model_call_mode=="inference"` and `fuse_expert_scales=True`. """ def safe_has_key(obj: Mapping[str, Any], key: str) -> bool: if isinstance(obj, dict): @@ -1576,6 +1775,25 @@ def intersect_trees( src_flat = traverse_util.flatten_dict(src) tgt_flat = traverse_util.flatten_dict(tgt_spec) + # Gemma4 (MaxText) wraps its per-layer chunks in a 'scanned_blocks' + # container that doesn't exist on the destination side. If the target + # tree never mentions 'scanned_blocks', lift its contents up so the + # downstream direct-match / scanned-layer logic sees a flat layer view. + tgt_has_scanned_blocks = any( + 'scanned_blocks' in key for key in tgt_flat + ) + if not tgt_has_scanned_blocks and any( + 'scanned_blocks' in key for key in src_flat + ): + logging.info("Lifting 'scanned_blocks' contents from source paths.") + src_flat = { + tuple(p for p in key if p != 'scanned_blocks'): val + for key, val in src_flat.items() + } + + if fuse_expert_scales: + src_flat = _fuse_per_expert_scale_into_wo(src_flat) + src_flat = _fuse_moe_weights(src_flat, tgt_flat) filtered_src_flat = {} @@ -1588,10 +1806,43 @@ def intersect_trees( for key_tuple, tgt_val in tgt_flat.items(): path_str = '.'.join(str(k) for k in key_tuple) + + # Locate which part of the path is 'layers_X' up-front so both the + # direct-match and scanned-candidate branches can use the layer index. 
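+      # e.g. key_tuple ('decoder', 'layers_3', 'mlp', 'wo') yields
+      # layer_idx=3, match_index=1.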
+ layer_idx = -1 + match_index = -1 + for i, part in enumerate(key_tuple): + # Optimization: Only check strings that look like layers + if isinstance(part, str) and part.startswith('layers_'): + m = layer_pattern.match(part) + if m: + layer_idx = int(m.group(1)) + match_index = i + break + # Try Direct Match if key_tuple in src_flat: src_val = src_flat[key_tuple] src_val = _apply_dtype_cast(src_val, tgt_val.dtype, path_str) + src_val = _maybe_squeeze_scan_axis( + src_val, tgt_val, scan_axis, path_str + ) + # Gemma4's scanned_blocks layout splits layers into chunks named + # `layers_K` where each chunk holds `chunk_size` consecutive layers + # with the scan dim baked in at `scan_axis`. A literal direct match + # on `layers_X.` therefore lands on a chunk whose first scan + # slot IS global layer X — i.e. within-chunk index 0. Slice that + # off before shape alignment. + if ( + match_index != -1 + and hasattr(src_val, 'shape') + and hasattr(tgt_val, 'shape') + and len(src_val.shape) == len(tgt_val.shape) + 1 + and 0 <= scan_axis < len(src_val.shape) + ): + idx = [slice(None)] * src_val.ndim + idx[scan_axis] = 0 + src_val = src_val[tuple(idx)] src_val = _align_to_model_shape(src_val, tgt_val, path_str) filtered_src_flat[key_tuple] = src_val filtered_tgt_flat[key_tuple] = tgt_val @@ -1601,19 +1852,6 @@ def intersect_trees( # We look for 'layers_X' in the path and try to map it to 'layers' (MaxText) # or remove it (GPT-OSS / implicit stack). - # Locate which part of the path is 'layers_X' - layer_idx = -1 - match_index = -1 - - for i, part in enumerate(key_tuple): - # Optimization: Only check strings that look like layers - if isinstance(part, str) and part.startswith('layers_'): - m = layer_pattern.match(part) - if m: - layer_idx = int(m.group(1)) - match_index = i - break - if match_index != -1: # Check different candidate path formats for scanned layers # Candidate A: Replace 'layers_X' with 'layers' (Standard MaxText) @@ -1624,12 +1862,57 @@ def intersect_trees( candidate_b = list(key_tuple) candidate_b.pop(match_index) + # Candidates A/B refer to a single scanned tensor that holds every + # layer in order; the within-chunk index equals the global layer + # index. Candidate C below (Gemma4 scanned_blocks chunks) resolves + # `(chunk_key, within_chunk_idx)` via `_resolve_scanned_chunk`. found_candidate = None + within_chunk_idx = layer_idx for cand in [tuple(candidate_a), tuple(candidate_b)]: if cand in src_flat: found_candidate = cand break + # Candidate C: Per-slot `layers_K` chunks. Two MaxText layouts share + # this naming and are auto-disambiguated by `_resolve_scanned_chunk`: + # - Consecutive (canonical scanned_blocks): chunk_size consecutive + # layers stacked at `scan_axis`; keys spaced by chunk_size. + # - Interleaved (Gemma4): chunk K holds layers {K, K+n, K+2n, ...} + # where n = num slot chunks; scan-axis index = rep. 
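+        # e.g. target 'layers_3' against sibling chunks {layers_0, layers_2}
+        # of scan size 2 resolves (consecutive) to chunk 2, slot 1; against
+        # {layers_0, layers_1, layers_2} of size 2 it resolves (interleaved)
+        # to slot 0, rep 1.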
+ if found_candidate is None: + prefix = key_tuple[:match_index] + suffix = key_tuple[match_index + 1:] + chunk_candidates: List[Tuple[int, int]] = [] + chunk_keys: Dict[int, Tuple[str, ...]] = {} + for s_key in src_flat: + if ( + len(s_key) != len(key_tuple) + or s_key[:match_index] != prefix + or s_key[match_index + 1:] != suffix + or not isinstance(s_key[match_index], str) + ): + continue + m = layer_pattern.match(s_key[match_index]) + if not m: + continue + s_val = src_flat[s_key] + if ( + not hasattr(s_val, 'shape') + or not hasattr(tgt_val, 'shape') + or len(s_val.shape) != len(tgt_val.shape) + 1 + or scan_axis < 0 + or scan_axis >= len(s_val.shape) + ): + continue + cand_chunk_key = int(m.group(1)) + chunk_size = s_val.shape[scan_axis] + chunk_candidates.append((cand_chunk_key, chunk_size)) + chunk_keys[cand_chunk_key] = s_key + resolved = _resolve_scanned_chunk(layer_idx, chunk_candidates) + if resolved is not None: + chunk_key, within_chunk_idx = resolved + found_candidate = chunk_keys[chunk_key] + if found_candidate: # Cache key includes per-layer target shape so distinct unrolled # targets with different padded shapes don't collide on the same @@ -1660,26 +1943,110 @@ def intersect_trees( src_val, scan_axis, tgt_val, candidate_path ) - # Extract the layer_idx-th element from the unstacked cache. - sliced_val = unstacked_cache[cache_key][layer_idx] + # Extract the within-chunk slot from the unstacked cache. + # `within_chunk_idx` was set to `layer_idx` for candidates A/B + # (single scanned tensor) or to the helper-resolved index for + # candidate C (per-slot chunks). + sliced_val = unstacked_cache[cache_key][within_chunk_idx] sliced_val = _align_to_model_shape(sliced_val, tgt_val, path_str) filtered_src_flat[key_tuple] = sliced_val filtered_tgt_flat[key_tuple] = tgt_val continue # MoE fusion case: target has 'layers_X/.../wi' but source has scanned - # 'layers/.../wi_0' and 'layers/.../wi_1'. Fuse the full stacked + # wi_0/wi_1 with one extra dim (the scan axis). Fuse the full stacked # tensors first, then unstack once via a JIT-compiled helper — avoids # N per-layer jnp.concatenate dispatches and 2N intermediate device # allocations that cause compilation pressure and memory fragmentation. + # + # MaxText exposes scanned MoE state under several prefixes depending + # on model config, so we probe in two tiers (tier 2 first because a + # source slot can incidentally share the target's `layers_K` name — + # we need chunk-key spacing to disambiguate): + # + # Tier 2 — per-slot `layers_K` chunks (Gemma4 MaxText): each chunk + # holds num_reps layers at `scan_axis`. Layout (consecutive vs + # interleaved) is auto-detected via `_resolve_scanned_chunk`, which + # also returns the correct within-chunk index. + # + # Tier 1 — canonical single scanned tensor (fallback): + # A. Replace 'layers_X' with 'layers' — canonical MaxText scanned. + # B. Drop 'layers_X' entirely — implicit container. + # Both tier-1 forms hold ALL layers in one tensor, so + # within_chunk_idx == layer_idx. if key_tuple and key_tuple[-1] == 'wi': - scanned_prefix = ( - key_tuple[:match_index] + ('layers',) + key_tuple[match_index + 1:-1] - ) - wi_0_key = scanned_prefix + ('wi_0',) - wi_1_key = scanned_prefix + ('wi_1',) - - if wi_0_key in src_flat and wi_1_key in src_flat: + scanned_prefix: Optional[Tuple[str, ...]] = None + within_chunk_idx_moe = layer_idx + + # Tier 2: per-slot `layers_K` sibling chunks. 
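+          # e.g. target ('decoder', 'layers_4', 'mlp', 'wi') gathers every
+          # src ('decoder', 'layers_K', 'mlp', 'wi_0') that has a wi_1
+          # sibling and one extra dim at scan_axis.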
+ moe_prefix = key_tuple[:match_index] + moe_suffix = key_tuple[match_index + 1 : -1] + moe_candidates: List[Tuple[int, int]] = [] + moe_prefix_map: Dict[int, Tuple[str, ...]] = {} + for s_key in src_flat: + if ( + len(s_key) != len(key_tuple) + or s_key[:match_index] != moe_prefix + or s_key[match_index + 1 : -1] != moe_suffix + or s_key[-1] != 'wi_0' + or not isinstance(s_key[match_index], str) + ): + continue + m = layer_pattern.match(s_key[match_index]) + if not m: + continue + cand_chunk_key = int(m.group(1)) + cand_wi_1 = s_key[:-1] + ('wi_1',) + if cand_wi_1 not in src_flat: + continue + s_val = src_flat[s_key] + if ( + not hasattr(s_val, 'shape') + or not hasattr(tgt_val, 'shape') + or len(s_val.shape) != len(tgt_val.shape) + 1 + or scan_axis < 0 + or scan_axis >= len(s_val.shape) + ): + continue + chunk_size = s_val.shape[scan_axis] + moe_candidates.append((cand_chunk_key, chunk_size)) + moe_prefix_map[cand_chunk_key] = s_key[:-1] + resolved = _resolve_scanned_chunk(layer_idx, moe_candidates) + if resolved is not None: + chunk_key, within_chunk_idx_moe = resolved + scanned_prefix = moe_prefix_map[chunk_key] + + # Tier 1: canonical single-tensor candidates (fallback). + if scanned_prefix is None: + tier1_replacements: Tuple[Tuple[str, ...], ...] = ( + ('layers',), + (), + ) + seen_prefixes: set[Tuple[str, ...]] = set() + for replacement in tier1_replacements: + candidate_prefix = ( + key_tuple[:match_index] + + replacement + + key_tuple[match_index + 1 : -1] + ) + if candidate_prefix in seen_prefixes: + continue + seen_prefixes.add(candidate_prefix) + cand_wi_0 = candidate_prefix + ('wi_0',) + cand_wi_1 = candidate_prefix + ('wi_1',) + if cand_wi_0 in src_flat and cand_wi_1 in src_flat: + cand_val = src_flat[cand_wi_0] + if ( + hasattr(cand_val, 'shape') + and hasattr(tgt_val, 'shape') + and len(cand_val.shape) == len(tgt_val.shape) + 1 + ): + scanned_prefix = candidate_prefix + break + + if scanned_prefix is not None: + wi_0_key = scanned_prefix + ('wi_0',) + wi_1_key = scanned_prefix + ('wi_1',) fused_scanned_key = scanned_prefix + ('wi_fused',) if fused_scanned_key not in unstacked_cache: scanned_prefix_path = '.'.join(str(k) for k in scanned_prefix) @@ -1708,7 +2075,10 @@ def intersect_trees( ) del wi_0_full, wi_1_full - sliced_val = unstacked_cache[fused_scanned_key][layer_idx] + # Tier 1 sets within_chunk_idx_moe == layer_idx (single tensor + # holds all layers); tier 2 sets it to the resolved chunk-local + # index from `_resolve_scanned_chunk`. + sliced_val = unstacked_cache[fused_scanned_key][within_chunk_idx_moe] sliced_val = _align_to_model_shape(sliced_val, tgt_val, path_str) filtered_src_flat[key_tuple] = sliced_val diff --git a/tunix/generate/vllm_sampler.py b/tunix/generate/vllm_sampler.py index 3986fb310..fbf9a3bb4 100644 --- a/tunix/generate/vllm_sampler.py +++ b/tunix/generate/vllm_sampler.py @@ -67,6 +67,7 @@ class VllmConfig: tensor_parallel_size: int = -1 expert_parallel_size: int = 1 reshard_chunk_size: Optional[int] = None + fuse_expert_scales: bool = False # vLLM engine args that can be directly passed in without additional processing, e.g. max_model_len, async_scheduling, etc. 
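+  # e.g. engine_kwargs={'max_model_len': 4096, 'async_scheduling': True}
+  # (values here are illustrative only).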
  engine_kwargs: dataclasses.InitVar[Optional[Dict[str, Any]]] = None
@@ -240,6 +241,7 @@ def update_params(
        reshard_fn=reshard.reshard_pytree,
        delete_dst_buffers=True,  # Ensure old weights are deleted to free up HBM memory
        reshard_chunk_size=self.config.reshard_chunk_size,
+       fuse_expert_scales=self.config.fuse_expert_scales,
    )
    if self.llm is not None:
diff --git a/tunix/rl/rollout/base_rollout.py b/tunix/rl/rollout/base_rollout.py
index a91a2ea94..2d7fc7a3a 100644
--- a/tunix/rl/rollout/base_rollout.py
+++ b/tunix/rl/rollout/base_rollout.py
@@ -166,6 +166,10 @@ class RolloutConfig:
  # Set to a smaller value to reduce peak HBM pressure on large models.
  rollout_vllm_reshard_chunk_size: Optional[int] = None

+  # Whether to fuse MoE per-expert scales into the out-projection (`wo`) weights during weight synchronization.
+  # This can improve decode performance, at the cost of extra fusion work on each sync.
+  rollout_vllm_fuse_expert_scales: bool = False
+
  # Additional keyword arguments forwarded directly to the vLLM engine constructor.
  rollout_vllm_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
diff --git a/tunix/rl/rollout/vllm_rollout.py b/tunix/rl/rollout/vllm_rollout.py
index acdbb1181..04d212647 100644
--- a/tunix/rl/rollout/vllm_rollout.py
+++ b/tunix/rl/rollout/vllm_rollout.py
@@ -57,6 +57,7 @@ def __init__(
        data_parallel_size=rollout_config.data_parallel_size,
        expert_parallel_size=rollout_config.expert_parallel_size,
        reshard_chunk_size=rollout_config.rollout_vllm_reshard_chunk_size,
+       fuse_expert_scales=rollout_config.rollout_vllm_fuse_expert_scales,
        engine_kwargs={
            "model": rollout_config.rollout_vllm_model_version,
            "max_model_len": cache_config_or_size,