diff --git a/experiments/grug/moe/README.md b/experiments/grug/moe/README.md index 2c99253070..41339dcd2f 100644 --- a/experiments/grug/moe/README.md +++ b/experiments/grug/moe/README.md @@ -37,8 +37,9 @@ z-loss only). The architecture choices are hardcoded in others use half. Specifically, layer `i` uses the long mask iff `i % 4 == 3`. - **Fp32 router path**: router logits cast to fp32 before top-k, softmax, and QB statistics. -- **Expert parallelism**: `ragged_all_to_all` or ring-based via - `levanter.grug.grug_moe.moe_mlp` (default: ring). Default capacity factor 1.0. +- **Expert parallelism**: ring, plain-XLA `assigned_token`, or DeepEP-backed + assigned-token transport via `levanter.grug.grug_moe.moe_mlp` (default: ring). + Default capacity factor 1.0. ## Scaling heuristic diff --git a/experiments/grug/moe/model.py b/experiments/grug/moe/model.py index 55e4cd5dea..77058a95a8 100644 --- a/experiments/grug/moe/model.py +++ b/experiments/grug/moe/model.py @@ -516,9 +516,7 @@ def __call__( hidden = self.token_embed.at[token_ids].get(out_sharding=batch_spec) hidden = self.embed_gated_norm(self.embed_norm(hidden)) - segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None - short_mask = AttentionMask(is_causal=True, sliding_window=cfg.sliding_window // 2, segment_ids=segment_ids) - long_mask = AttentionMask(is_causal=True, sliding_window=cfg.sliding_window, segment_ids=segment_ids) + short_mask, long_mask = _model_sliding_attention_masks(mask, cfg) moe_router_stats: list[dict[str, jax.Array]] = [] for i, block in enumerate(self.blocks): @@ -583,6 +581,27 @@ def next_token_loss( return loss +def _model_sliding_attention_masks( + mask: AttentionMask | jax.Array, + cfg: GrugModelConfig, +) -> tuple[AttentionMask, AttentionMask]: + segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None + thd_segment_metadata = mask.thd_segment_metadata if isinstance(mask, AttentionMask) else None + short_mask = AttentionMask( + is_causal=True, + sliding_window=cfg.sliding_window // 2, + segment_ids=segment_ids, + thd_segment_metadata=thd_segment_metadata, + ) + long_mask = AttentionMask( + is_causal=True, + sliding_window=cfg.sliding_window, + segment_ids=segment_ids, + thd_segment_metadata=thd_segment_metadata, + ) + return short_mask, long_mask + + def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float[Array, "..."]: return std * random.truncated_normal(key, -3, 3, shape) diff --git a/experiments/grug/moe/test_model.py b/experiments/grug/moe/test_model.py new file mode 100644 index 0000000000..4fd9b54c22 --- /dev/null +++ b/experiments/grug/moe/test_model.py @@ -0,0 +1,25 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +import jax.numpy as jnp +from levanter.grug.attention import AttentionMask, ThdSegmentMetadata + +from experiments.grug.moe.model import GrugModelConfig, _model_sliding_attention_masks + + +def test_model_sliding_attention_masks_preserve_thd_metadata(): + cfg = GrugModelConfig(vocab_size=16, hidden_dim=8, num_layers=2, num_heads=2, num_kv_heads=1, sliding_window=8) + metadata = ThdSegmentMetadata( + segment_lengths=jnp.array([[4, 4]], dtype=jnp.int32), + num_segments=jnp.array([2], dtype=jnp.int32), + ) + mask = AttentionMask(is_causal=True, thd_segment_metadata=metadata) + + short_mask, long_mask = _model_sliding_attention_masks(mask, cfg) + + assert short_mask.sliding_window == 4 + assert long_mask.sliding_window == 8 + assert short_mask.thd_segment_metadata is metadata + assert long_mask.thd_segment_metadata is metadata + assert short_mask.segment_ids is None + assert long_mask.segment_ids is None diff --git a/lib/levanter/pyproject.toml b/lib/levanter/pyproject.toml index 52137defa2..dccceda489 100644 --- a/lib/levanter/pyproject.toml +++ b/lib/levanter/pyproject.toml @@ -217,5 +217,6 @@ filterwarnings = [ markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "torch: mark tests that use Torch (deselect with '-m \"not torch\"')", + "timeout: override the default per-test timeout", ] asyncio_default_fixture_loop_scope = "function" diff --git a/lib/levanter/scripts/bench/bench_grug_moe_ep.py b/lib/levanter/scripts/bench/bench_grug_moe_ep.py new file mode 100644 index 0000000000..c7c1889188 --- /dev/null +++ b/lib/levanter/scripts/bench/bench_grug_moe_ep.py @@ -0,0 +1,552 @@ +# Copyright The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Perf smoke for Grug MoE expert-parallel transport backends.""" + +import argparse +import json +import statistics +import time +from dataclasses import dataclass +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +from haliax.nn.ragged_dot import ragged_dot +from jax import shard_map +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P + +from levanter.grug._moe.common import split_moe_w13_output +from levanter.grug._moe.ep_common import _prefix_cap_counts +from levanter.grug.grug_moe import moe_mlp +from levanter.kernels.deepep import ( + deepep_collapse_local_assignments, + deepep_combine_intranode, + deepep_dispatch_intranode_with_assignments, + deepep_get_dispatch_layout, + transport_ffi, +) +from levanter.utils.activation import ActivationFunctionEnum + + +@dataclass(frozen=True) +class BenchResult: + implementation: str + compile_seconds: float + median_seconds: float + mean_seconds: float + tokens_per_second: float + + +@dataclass(frozen=True) +class DeepEPConfigOverride: + dispatch: transport_ffi.IntranodeConfig + combine: transport_ffi.IntranodeConfig + + +_DEEPEP_COMPONENT_STAGES = ( + "deepep_dispatch", + "deepep_dispatch_w13", + "deepep_dispatch_w13_w2", + "deepep_dispatch_w13_w2_collapse", + "deepep_full", +) + + +def _ep_mesh(expert_axis_size: int) -> Mesh: + devices = np.asarray(jax.devices()) + if devices.size % expert_axis_size != 0: + raise ValueError(f"device count {devices.size} must be divisible by expert axis size {expert_axis_size}") + data_axis_size = devices.size // expert_axis_size + return Mesh( + devices.reshape(data_axis_size, expert_axis_size, 1), + axis_names=("data", "expert", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Explicit), + ) + + +def _balanced_topk_assignments(tokens: int, *, topk: int, num_experts: int) -> jax.Array: + token_ids = jnp.arange(tokens, dtype=jnp.int32)[:, None] + topk_offsets = jnp.arange(topk, dtype=jnp.int32)[None, :] + return (token_ids * topk + topk_offsets) % num_experts + + +def _make_inputs( + *, + tokens: int, + hidden_dim: int, + intermediate_dim: int, + num_experts: int, + topk: int, + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + k_x, k_weights, k_w13, k_w2 = jax.random.split(jax.random.key(6215), 4) + x = jax.random.normal(k_x, (tokens, hidden_dim), dtype=dtype) + selected_experts = _balanced_topk_assignments(tokens, topk=topk, num_experts=num_experts) + combine_weights = jax.nn.sigmoid(jax.random.normal(k_weights, (tokens, topk), dtype=jnp.float32)).astype(dtype) + w_up_gate = jax.random.normal(k_w13, (num_experts, hidden_dim, 2 * intermediate_dim), dtype=dtype) + w_down = jax.random.normal(k_w2, (num_experts, intermediate_dim, hidden_dim), dtype=dtype) + return x, selected_experts, combine_weights, w_up_gate, w_down + + +def _shard_inputs( + mesh: Mesh, + inputs: tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array], +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + x, selected_experts, combine_weights, w_up_gate, w_down = inputs + batch_sharding = NamedSharding(mesh, P(("data", "expert"), None)) + expert_sharding = NamedSharding(mesh, P("expert", None, None)) + return ( + jax.sharding.reshard(x, batch_sharding), + jax.sharding.reshard(selected_experts, batch_sharding), + jax.sharding.reshard(combine_weights, batch_sharding), + jax.sharding.reshard(w_up_gate, expert_sharding), + jax.sharding.reshard(w_down, expert_sharding), + ) + + +def _forward_fn(implementation: str, mesh: Mesh, *, capacity_factor: float): + def run(x, selected_experts, combine_weights, w_up_gate, w_down): + return moe_mlp( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + activation=ActivationFunctionEnum.silu, + implementation=implementation, + mesh=mesh, + capacity_factor=capacity_factor, + report_capacity_overflow=True, + ) + + return jax.jit(run) + + +def _forward_backward_fn(implementation: str, mesh: Mesh, *, capacity_factor: float): + def loss_fn(x, selected_experts, combine_weights, w_up_gate, w_down): + out, _dropped = moe_mlp( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + activation=ActivationFunctionEnum.silu, + implementation=implementation, + mesh=mesh, + capacity_factor=capacity_factor, + report_capacity_overflow=True, + ) + return jnp.sum(out.astype(jnp.float32)) + + return jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 2, 3, 4))) + + +def _deepep_component_local( + x_local, + selected_experts_local, + combine_weights_local, + moe_w13_local, + moe_w2_local, + *, + num_experts: int, + capacity_factor: float, + stage: str, + dispatch_config: transport_ffi.IntranodeConfig | None, + combine_config: transport_ffi.IntranodeConfig | None, +): + local_experts = moe_w13_local.shape[0] + ep_size = num_experts // local_experts + topk = selected_experts_local.shape[1] + local_capacity = int(np.ceil(capacity_factor * x_local.shape[0] * topk)) + local_capacity = max(local_experts, local_capacity) + max_recv_tokens = x_local.shape[0] * ep_size + + num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout( + selected_experts_local, + num_ranks=ep_size, + num_experts=num_experts, + ) + ( + recv_x, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + local_group_sizes, + num_recv_tokens, + x_dispatch, + assignment_weights, + recv_token_indices, + assignment_destinations, + ) = deepep_dispatch_intranode_with_assignments( + x_local, + selected_experts_local, + combine_weights_local, + num_tokens_per_rank, + num_tokens_per_expert, + is_token_in_rank, + num_experts=num_experts, + dispatch_config=dispatch_config, + combine_config=combine_config, + max_recv_tokens=max_recv_tokens, + ) + accepted_group_sizes = _prefix_cap_counts(local_group_sizes, capacity=local_capacity) + x_dispatch = x_dispatch[:local_capacity] + assignment_weights = assignment_weights[:local_capacity] + recv_token_indices = recv_token_indices[:local_capacity] + if stage == "deepep_dispatch": + local_value = jnp.sum(x_dispatch.astype(jnp.float32)) + jnp.sum(assignment_weights.astype(jnp.float32)) + return jax.lax.psum(local_value, ("data", "expert")) + + w13_out = ragged_dot(x_dispatch, moe_w13_local, accepted_group_sizes) + if stage == "deepep_dispatch_w13": + return jax.lax.psum(jnp.sum(w13_out.astype(jnp.float32)), ("data", "expert")) + + moe_dim = moe_w2_local.shape[1] + gate, up = split_moe_w13_output(w13_out, intermediate_dim=moe_dim, interleaved=False) + out_dispatch = ragged_dot(jax.nn.silu(gate) * up, moe_w2_local, accepted_group_sizes) + if stage == "deepep_dispatch_w13_w2": + return jax.lax.psum(jnp.sum(out_dispatch.astype(jnp.float32)), ("data", "expert")) + + recv_out = deepep_collapse_local_assignments( + out_dispatch, + assignment_weights, + recv_token_indices, + assignment_destinations, + accepted_group_sizes, + num_recv_tokens, + recv_capacity=recv_x.shape[0], + ) + if stage == "deepep_dispatch_w13_w2_collapse": + return jax.lax.psum(jnp.sum(recv_out.astype(jnp.float32)), ("data", "expert")) + + if stage != "deepep_full": + raise AssertionError(f"Unhandled DeepEP component stage {stage!r}") + out_local, _ = deepep_combine_intranode( + recv_out, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + num_recv_tokens, + is_token_in_rank, + ) + return jax.lax.psum(jnp.sum(out_local.astype(jnp.float32)), ("data", "expert")) + + +def _deepep_component_fn( + stage: str, + mesh: Mesh, + *, + num_experts: int, + capacity_factor: float, + deepep_config: DeepEPConfigOverride | None, +): + batch_spec = P(("data", "expert"), None) + w_spec = P("expert", None, None) + shard_fn = shard_map( + partial( + _deepep_component_local, + num_experts=num_experts, + capacity_factor=capacity_factor, + stage=stage, + dispatch_config=deepep_config.dispatch if deepep_config is not None else None, + combine_config=deepep_config.combine if deepep_config is not None else None, + ), + mesh=mesh, + in_specs=(batch_spec, batch_spec, batch_spec, w_spec, w_spec), + out_specs=P(), + check_vma=False, + ) + return jax.jit(shard_fn) + + +def _block_until_ready(value): + jax.block_until_ready(value) + return value + + +def _time( + fn, args: tuple[jax.Array, ...], *, tokens: int, warmup: int, steps: int, implementation: str +) -> BenchResult: + start = time.perf_counter() + _block_until_ready(fn(*args)) + compile_seconds = time.perf_counter() - start + + for _ in range(warmup): + _block_until_ready(fn(*args)) + + samples: list[float] = [] + for _ in range(steps): + start = time.perf_counter() + _block_until_ready(fn(*args)) + samples.append(time.perf_counter() - start) + + median_seconds = statistics.median(samples) + mean_seconds = statistics.fmean(samples) + return BenchResult( + implementation=implementation, + compile_seconds=compile_seconds, + median_seconds=median_seconds, + mean_seconds=mean_seconds, + tokens_per_second=tokens / median_seconds, + ) + + +def _as_json(result: BenchResult) -> dict[str, float | str]: + return { + "implementation": result.implementation, + "compile_seconds": result.compile_seconds, + "median_seconds": result.median_seconds, + "mean_seconds": result.mean_seconds, + "tokens_per_second": result.tokens_per_second, + } + + +def _max_tree_abs_diff(reference, candidate) -> float: + max_diff = 0.0 + for reference_leaf, candidate_leaf in zip( + jax.tree_util.tree_leaves(reference), + jax.tree_util.tree_leaves(candidate), + strict=True, + ): + diff = jnp.max(jnp.abs(reference_leaf.astype(jnp.float32) - candidate_leaf.astype(jnp.float32))) + max_diff = max(max_diff, float(diff)) + return max_diff + + +def _deepep_config_override(args: argparse.Namespace) -> DeepEPConfigOverride | None: + dispatch_values = ( + args.deepep_dispatch_sms, + args.deepep_dispatch_max_send_tokens, + args.deepep_dispatch_max_recv_tokens, + ) + combine_values = ( + args.deepep_combine_sms, + args.deepep_combine_max_send_tokens, + args.deepep_combine_max_recv_tokens, + ) + if all(value is None for value in (*dispatch_values, *combine_values)): + return None + if any(value is None for value in (*dispatch_values, *combine_values)): + raise ValueError("DeepEP config override requires all dispatch and combine config fields") + + dispatch = transport_ffi.IntranodeConfig( + num_sms=args.deepep_dispatch_sms, + num_max_send_tokens=args.deepep_dispatch_max_send_tokens, + num_max_recv_tokens=args.deepep_dispatch_max_recv_tokens, + ) + combine = transport_ffi.IntranodeConfig( + num_sms=args.deepep_combine_sms, + num_max_send_tokens=args.deepep_combine_max_send_tokens, + num_max_recv_tokens=args.deepep_combine_max_recv_tokens, + ) + return DeepEPConfigOverride(dispatch=dispatch, combine=combine) + + +def _deepep_config_payload(config: DeepEPConfigOverride | None) -> dict[str, dict[str, int]] | None: + if config is None: + return None + return { + "dispatch": { + "num_sms": config.dispatch.num_sms, + "num_max_send_tokens": config.dispatch.num_max_send_tokens, + "num_max_recv_tokens": config.dispatch.num_max_recv_tokens, + }, + "combine": { + "num_sms": config.combine.num_sms, + "num_max_send_tokens": config.combine.num_max_send_tokens, + "num_max_recv_tokens": config.combine.num_max_recv_tokens, + }, + } + + +def _run_benchmark( + args: argparse.Namespace, + expert_axis_size: int, + deepep_config: DeepEPConfigOverride | None, +) -> None: + if deepep_config is not None and args.mode != "deepep_components": + raise ValueError("DeepEP config override is only supported with --mode deepep_components") + + tokens = args.batch_size * args.seq_len + dtype = jnp.bfloat16 if args.dtype == "bfloat16" else jnp.float32 + mesh = _ep_mesh(expert_axis_size) + inputs = _shard_inputs( + mesh, + _make_inputs( + tokens=tokens, + hidden_dim=args.hidden_dim, + intermediate_dim=args.intermediate_dim, + num_experts=args.num_experts, + topk=args.topk, + dtype=dtype, + ), + ) + + with jax.set_mesh(mesh): + if args.mode == "deepep_components": + results = [ + _time( + _deepep_component_fn( + stage, + mesh, + num_experts=args.num_experts, + capacity_factor=args.capacity_factor, + deepep_config=deepep_config, + ), + inputs, + tokens=tokens, + warmup=args.warmup, + steps=args.steps, + implementation=stage, + ) + for stage in _DEEPEP_COMPONENT_STAGES + ] + payload = { + "shape": { + "tokens": tokens, + "batch_size": args.batch_size, + "seq_len": args.seq_len, + "hidden_dim": args.hidden_dim, + "intermediate_dim": args.intermediate_dim, + "num_experts": args.num_experts, + "topk": args.topk, + "capacity_factor": args.capacity_factor, + "dtype": args.dtype, + "mode": args.mode, + }, + "mesh": { + "devices": len(jax.devices()), + "expert_axis_size": expert_axis_size, + }, + "deepep_config": _deepep_config_payload(deepep_config), + "results": [_as_json(result) for result in results], + } + rendered = json.dumps(payload, indent=2, sort_keys=True) + print(rendered) + if args.json_output is not None: + with open(args.json_output, "w", encoding="utf-8") as handle: + handle.write(rendered) + handle.write("\n") + return + + fn_factory = _forward_fn if args.mode == "forward" else _forward_backward_fn + implementations = tuple(args.implementations) + + ring_out, ring_dropped = _block_until_ready( + _forward_fn("ring", mesh, capacity_factor=args.capacity_factor)(*inputs) + ) + ring_forward_backward = None + if args.mode == "forward_backward": + ring_forward_backward = _block_until_ready( + _forward_backward_fn("ring", mesh, capacity_factor=args.capacity_factor)(*inputs) + ) + max_reference_abs = float(jnp.max(jnp.abs(ring_out.astype(jnp.float32)))) + correctness: dict[str, dict[str, float]] = {} + results: list[BenchResult] = [] + for implementation in implementations: + candidate_out, candidate_dropped = _block_until_ready( + _forward_fn(implementation, mesh, capacity_factor=args.capacity_factor)(*inputs) + ) + diff = jnp.abs(ring_out.astype(jnp.float32) - candidate_out.astype(jnp.float32)) + max_abs_diff = float(jnp.max(diff)) + correctness[implementation] = { + "max_abs_diff": max_abs_diff, + "mean_abs_diff": float(jnp.mean(diff)), + "max_relative_diff": max_abs_diff / max(max_reference_abs, 1.0), + "ring_dropped": int(ring_dropped), + "candidate_dropped": int(candidate_dropped), + "dropped_abs_diff": int(jnp.abs(ring_dropped - candidate_dropped)), + } + if ring_forward_backward is not None: + candidate_loss, candidate_grads = _block_until_ready( + _forward_backward_fn(implementation, mesh, capacity_factor=args.capacity_factor)(*inputs) + ) + ring_loss, ring_grads = ring_forward_backward + correctness[implementation]["loss_abs_diff"] = float( + jnp.abs(ring_loss.astype(jnp.float32) - candidate_loss.astype(jnp.float32)) + ) + correctness[implementation]["grad_max_abs_diff"] = _max_tree_abs_diff(ring_grads, candidate_grads) + results.append( + _time( + fn_factory(implementation, mesh, capacity_factor=args.capacity_factor), + inputs, + tokens=tokens, + warmup=args.warmup, + steps=args.steps, + implementation=implementation, + ) + ) + + payload = { + "shape": { + "tokens": tokens, + "batch_size": args.batch_size, + "seq_len": args.seq_len, + "hidden_dim": args.hidden_dim, + "intermediate_dim": args.intermediate_dim, + "num_experts": args.num_experts, + "topk": args.topk, + "capacity_factor": args.capacity_factor, + "dtype": args.dtype, + "mode": args.mode, + }, + "mesh": { + "devices": len(jax.devices()), + "expert_axis_size": expert_axis_size, + }, + "deepep_config": _deepep_config_payload(deepep_config), + "max_reference_abs": max_reference_abs, + "correctness_vs_ring": correctness, + "results": [_as_json(result) for result in results], + } + rendered = json.dumps(payload, indent=2, sort_keys=True) + print(rendered) + if args.json_output is not None: + with open(args.json_output, "w", encoding="utf-8") as handle: + handle.write(rendered) + handle.write("\n") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--hidden-dim", type=int, default=2560) + parser.add_argument("--intermediate-dim", type=int, default=1280) + parser.add_argument("--num-experts", type=int, default=64) + parser.add_argument("--topk", type=int, default=4) + parser.add_argument("--capacity-factor", type=float, default=1.25) + parser.add_argument("--expert-axis-size", type=int, default=None) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--steps", type=int, default=5) + parser.add_argument("--mode", choices=("forward", "forward_backward", "deepep_components"), default="forward") + parser.add_argument("--dtype", choices=("bfloat16", "float32"), default="bfloat16") + parser.add_argument("--implementations", nargs="+", default=("ring", "deepep", "assigned_token")) + parser.add_argument("--json-output", type=str, default=None) + parser.add_argument("--deepep-dispatch-sms", type=int, default=None) + parser.add_argument("--deepep-dispatch-max-send-tokens", type=int, default=None) + parser.add_argument("--deepep-dispatch-max-recv-tokens", type=int, default=None) + parser.add_argument("--deepep-combine-sms", type=int, default=None) + parser.add_argument("--deepep-combine-max-send-tokens", type=int, default=None) + parser.add_argument("--deepep-combine-max-recv-tokens", type=int, default=None) + args = parser.parse_args() + + expert_axis_size = args.expert_axis_size or len(jax.devices()) + if "deepep" in args.implementations and len(jax.devices()) != expert_axis_size: + raise ValueError( + "DeepEP intranode transport currently requires the expert group to span all visible local GPUs; " + f"got visible_devices={len(jax.devices())}, expert_axis_size={expert_axis_size}" + ) + _run_benchmark(args, expert_axis_size, _deepep_config_override(args)) + + +if __name__ == "__main__": + main() diff --git a/lib/levanter/src/levanter/callbacks/labeled_eval.py b/lib/levanter/src/levanter/callbacks/labeled_eval.py index 2a1bea6eae..38fe40c091 100644 --- a/lib/levanter/src/levanter/callbacks/labeled_eval.py +++ b/lib/levanter/src/levanter/callbacks/labeled_eval.py @@ -129,7 +129,7 @@ def cb_labeled_lm_evaluate( prefix: str = "labeled_eval", eval_current: bool = True, eval_model: bool = True, - mp: jmp.Policy = None, + mp: jmp.Policy | None = None, ) -> Callable[[StepInfo], None]: """Build a training callback for periodic labeled LM loss evaluation.""" evaluator = LabeledEvaluator.for_labeled_examples( diff --git a/lib/levanter/src/levanter/grug/_moe/common.py b/lib/levanter/src/levanter/grug/_moe/common.py index b2006eb0b1..c802be5120 100644 --- a/lib/levanter/src/levanter/grug/_moe/common.py +++ b/lib/levanter/src/levanter/grug/_moe/common.py @@ -22,13 +22,13 @@ MoeActivation: TypeAlias = ActivationFunctionEnum | Callable[[jax.Array], jax.Array] MoeImplementation: TypeAlias = Literal[ "ring", # Expert-parallel all-gather + psum-scatter backend. - "ragged_all_to_all", # Expert-parallel ragged all-to-all backend. - "deepep", # Expert-parallel DeepEP intranode dispatch/combine backend. + "assigned_token", # Expert-parallel plain-XLA assigned-token backend. + "deepep", # Expert-parallel DeepEP-backed assigned-token backend. "scatter", # Single-process grouped GMM with scatter-add combine. "sonic", # Single-process raw Sonic Triton gather/combine backend. ] _VALID_MOE_IMPLEMENTATIONS = get_args(MoeImplementation) -_EP_MOE_IMPLEMENTATIONS = ("ring", "ragged_all_to_all", "deepep") +_EP_MOE_IMPLEMENTATIONS = ("ring", "assigned_token", "deepep") # Local means no collectives over an expert axis. These backends can still run # under ordinary data/model sharding through the no-EP shard_map path. _LOCAL_MOE_IMPLEMENTATIONS = ( diff --git a/lib/levanter/src/levanter/grug/_moe/ep_assigned_token.py b/lib/levanter/src/levanter/grug/_moe/ep_assigned_token.py new file mode 100644 index 0000000000..ccafab17d9 --- /dev/null +++ b/lib/levanter/src/levanter/grug/_moe/ep_assigned_token.py @@ -0,0 +1,213 @@ +# Copyright The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Assigned-token expert-parallel Grug MoE backend.""" + +import math +from collections.abc import Callable +from typing import NamedTuple + +import jax +import jax.numpy as jnp + +from haliax.jax_utils import tree_checkpoint_name +from haliax.nn.ragged_dot import ragged_dot +from levanter.grug._moe.common import ( + _CHECKPOINT_DISPATCH_INPUT, + _CHECKPOINT_DISPATCH_OUTPUT, + _CHECKPOINT_EXPERT_HIDDEN, + split_moe_w13_output, +) +from levanter.grug._moe.ep_common import ( + _clip_receiver_group_sizes, + _compact_by_keep_mask, + _expand_from_keep_mask, + _expert_prefix_keep_mask, + _local_permute_from_counts, + _permute_by_global_expert, + _shard_a2a_params, + _sort_activations, +) + + +class AssignedTokenDispatch(NamedTuple): + """Receiver-local routed assignments grouped for expert GMM.""" + + x_dispatch: jax.Array + assignment_weights: jax.Array + local_sorted_indices: jax.Array + local_group_sizes: jax.Array + sender_sorted_indices: jax.Array + sender_keep_mask: jax.Array + all_shard_assignment_counts: jax.Array + dropped_local: jax.Array + + +def _unpermute_weighted_from_global_expert( + weighted_assignments: jax.Array, + sorted_indices: jax.Array, + *, + tokens_per_shard: int, + topk: int, +) -> jax.Array: + unsorted = _sort_activations(weighted_assignments, jnp.argsort(sorted_indices)) + return jnp.einsum( + "tkd->td", + unsorted.reshape(tokens_per_shard, topk, -1), + preferred_element_type=jnp.float32, + ) + + +def _assigned_token_dispatch( + x_local: jax.Array, + selected_experts_local: jax.Array, + combine_weights_local: jax.Array, + *, + num_experts: int, + local_experts: int, + capacity_factor: float, +) -> AssignedTokenDispatch: + if num_experts % local_experts != 0: + raise ValueError( + f"num_experts={num_experts} must be divisible by local expert count={local_experts} in EP mode" + ) + + shard_id = jax.lax.axis_index("expert") + ep_size = num_experts // local_experts + tokens_per_shard = x_local.shape[0] + topk = selected_experts_local.shape[1] + assignments_per_shard = tokens_per_shard * topk + local_capacity = int(math.ceil(capacity_factor * assignments_per_shard)) + local_capacity = max(local_experts, local_capacity) + recv_capacity = local_capacity + + sorted_x, sorted_indices, group_sizes = _permute_by_global_expert( + x_local, + selected_experts_local, + num_experts=num_experts, + ) + sorted_weights = _sort_activations( + combine_weights_local.reshape(-1)[:, None].astype(x_local.dtype), sorted_indices + ) + all_group_sizes = jax.lax.all_gather(group_sizes.astype(jnp.int32), "expert") + clipped_group_sizes = _clip_receiver_group_sizes( + all_group_sizes, + local_expert_size=local_experts, + receiver_capacity=local_capacity, + ) + sender_group_sizes = clipped_group_sizes[shard_id] + keep_mask = _expert_prefix_keep_mask( + group_sizes.astype(jnp.int32), + sender_group_sizes, + total_size=assignments_per_shard, + ) + sorted_x = _compact_by_keep_mask(sorted_x, keep_mask) + sorted_weights = _compact_by_keep_mask(sorted_weights, keep_mask) + + all_shard_counts = jnp.sum(clipped_group_sizes.reshape(ep_size, ep_size, local_experts), axis=2) + input_offsets, send_sizes, output_offsets, recv_sizes = _shard_a2a_params(all_shard_counts, shard_id) + dispatch_out_shape = jnp.zeros((recv_capacity, x_local.shape[1]), dtype=x_local.dtype) + x_dispatched = jax.lax.ragged_all_to_all( + sorted_x, + dispatch_out_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name="expert", + ) + weight_out_shape = jnp.zeros((recv_capacity, 1), dtype=x_local.dtype) + weights_dispatched = jax.lax.ragged_all_to_all( + sorted_weights, + weight_out_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name="expert", + ) + x_dispatch, local_sorted_indices, local_group_sizes = _local_permute_from_counts( + x_dispatched, + clipped_group_sizes, + local_expert_size=local_experts, + shard_index=shard_id, + ) + assignment_weights = _sort_activations(weights_dispatched, local_sorted_indices).reshape(-1) + recv_valid = jnp.arange(recv_capacity, dtype=jnp.int32) < jnp.sum(recv_sizes, dtype=jnp.int32) + assignment_weights = jnp.where(recv_valid, assignment_weights, 0) + dropped_local = jnp.sum(group_sizes, dtype=jnp.int32) - jnp.sum(sender_group_sizes, dtype=jnp.int32) + return AssignedTokenDispatch( + x_dispatch=tree_checkpoint_name(x_dispatch, _CHECKPOINT_DISPATCH_INPUT), + assignment_weights=assignment_weights, + local_sorted_indices=local_sorted_indices, + local_group_sizes=local_group_sizes, + sender_sorted_indices=sorted_indices, + sender_keep_mask=keep_mask, + all_shard_assignment_counts=all_shard_counts, + dropped_local=dropped_local, + ) + + +def _moe_mlp_ep_assigned_token_local( + x_local: jax.Array, + selected_experts_local: jax.Array, + combine_weights_local: jax.Array, + moe_w13_local: jax.Array, + moe_w2_local: jax.Array, + *, + activation_fn: Callable[[jax.Array], jax.Array], + num_experts: int, + capacity_factor: float, +) -> tuple[jax.Array, jax.Array]: + local_experts = moe_w13_local.shape[0] + shard_id = jax.lax.axis_index("expert") + tokens_per_shard = x_local.shape[0] + topk = selected_experts_local.shape[1] + assignments_per_shard = tokens_per_shard * topk + + with jax.named_scope("dispatch"): + dispatch = _assigned_token_dispatch( + x_local, + selected_experts_local, + combine_weights_local, + num_experts=num_experts, + local_experts=local_experts, + capacity_factor=capacity_factor, + ) + with jax.named_scope("moe_up_down"): + w13_out = tree_checkpoint_name( + ragged_dot(dispatch.x_dispatch, moe_w13_local, dispatch.local_group_sizes), + _CHECKPOINT_EXPERT_HIDDEN, + ) + moe_dim = moe_w2_local.shape[1] + gate, up = split_moe_w13_output(w13_out, intermediate_dim=moe_dim, interleaved=False) + out_dispatch = tree_checkpoint_name( + ragged_dot(activation_fn(gate) * up, moe_w2_local, dispatch.local_group_sizes), + _CHECKPOINT_DISPATCH_OUTPUT, + ) + + with jax.named_scope("combine"): + weighted_dispatch = out_dispatch * dispatch.assignment_weights[:, None] + local_output = _sort_activations(weighted_dispatch, jnp.argsort(dispatch.local_sorted_indices)) + return_out_shape = jnp.zeros((assignments_per_shard, x_local.shape[1]), dtype=local_output.dtype) + return_input_offsets, return_send_sizes, return_output_offsets, return_recv_sizes = _shard_a2a_params( + dispatch.all_shard_assignment_counts.T, shard_id + ) + returned = jax.lax.ragged_all_to_all( + local_output, + return_out_shape, + return_input_offsets, + return_send_sizes, + return_output_offsets, + return_recv_sizes, + axis_name="expert", + ) + returned = _expand_from_keep_mask(returned, dispatch.sender_keep_mask) + out_local = _unpermute_weighted_from_global_expert( + returned, + dispatch.sender_sorted_indices, + tokens_per_shard=tokens_per_shard, + topk=topk, + ).astype(x_local.dtype) + dropped_total = jax.lax.psum(dispatch.dropped_local, ("data", "expert")) + return out_local, dropped_total diff --git a/lib/levanter/src/levanter/grug/_moe/ep_deepep.py b/lib/levanter/src/levanter/grug/_moe/ep_deepep.py index d6b4cfc391..0a9afd5362 100644 --- a/lib/levanter/src/levanter/grug/_moe/ep_deepep.py +++ b/lib/levanter/src/levanter/grug/_moe/ep_deepep.py @@ -6,6 +6,7 @@ DeepEP source: https://github.com/deepseek-ai/DeepEP """ +import math from collections.abc import Callable from typing import NamedTuple @@ -21,7 +22,13 @@ _CHECKPOINT_EXPERT_HIDDEN, split_moe_w13_output, ) -from levanter.kernels.deepep import deepep_combine_intranode, deepep_dispatch_intranode, deepep_get_dispatch_layout +from levanter.grug._moe.ep_common import _prefix_cap_counts +from levanter.kernels.deepep import ( + deepep_collapse_local_assignments, + deepep_combine_intranode, + deepep_dispatch_intranode_with_assignments, + deepep_get_dispatch_layout, +) class DeepEPLocalAssignments(NamedTuple): @@ -40,63 +47,6 @@ class DeepEPLocalAssignments(NamedTuple): local_group_sizes: Int[Array, "EL"] -def _pack_deepep_local_assignments( - recv_x: Float[Array, "TR D"], - recv_topk_idx: Int[Array, "TR K"], - recv_topk_weights: Float[Array, "TR K"], - *, - local_experts: int, - num_recv_tokens: Int[Array, ""], -) -> DeepEPLocalAssignments: - with jax.named_scope("deepep_pack_local_assignments"): - max_recv_tokens, topk = recv_topk_idx.shape - total_assignments = max_recv_tokens * topk - - recv_token_indices = jnp.repeat(jnp.arange(max_recv_tokens, dtype=jnp.int32), topk) - expert_flat = recv_topk_idx.reshape(-1).astype(jnp.int32) - recv_valid = jnp.arange(max_recv_tokens, dtype=jnp.int32) < num_recv_tokens - local_mask = recv_valid[:, None] & (recv_topk_idx >= 0) & (recv_topk_idx < local_experts) - local_mask_flat = local_mask.reshape(-1) - local_bucket = jnp.where(local_mask_flat, expert_flat, local_experts) - local_group_sizes = jnp.bincount(local_bucket, length=local_experts + 1).astype(jnp.int32)[:-1] - total_valid = jnp.sum(local_group_sizes, dtype=jnp.int32) - - flat_positions = jnp.arange(total_assignments, dtype=jnp.int32) - order_key = local_bucket * total_assignments + flat_positions - max_order_key = (local_experts + 1) * total_assignments - selection_key = jnp.where(local_mask_flat, max_order_key - order_key, -1) - _, sorted_assignment_indices = jax.lax.top_k(selection_key, total_assignments) - - recv_token_indices = jnp.take(recv_token_indices, sorted_assignment_indices, axis=0) - x_dispatch = jnp.take(recv_x, recv_token_indices, axis=0) - assignment_weights = jnp.take(recv_topk_weights.reshape(-1), sorted_assignment_indices, axis=0).astype( - recv_x.dtype - ) - valid_sorted = jnp.arange(total_assignments, dtype=jnp.int32) < total_valid - x_dispatch = jnp.where(valid_sorted[:, None], x_dispatch, 0) - assignment_weights = jnp.where(valid_sorted, assignment_weights, 0) - return DeepEPLocalAssignments(x_dispatch, assignment_weights, recv_token_indices, local_group_sizes) - - -def _collapse_deepep_local_assignments( - out_dispatch: Float[Array, "TK D"], - assignment_weights: Float[Array, "TK"], - recv_token_indices: Int[Array, "TK"], - *, - recv_capacity: int, - num_recv_tokens: Int[Array, ""], -) -> Float[Array, "TR D"]: - with jax.named_scope("deepep_collapse_local_assignments"): - recv_out = jax.ops.segment_sum( - out_dispatch * assignment_weights[:, None], - recv_token_indices, - num_segments=recv_capacity, - indices_are_sorted=False, - ) - recv_valid = jnp.arange(recv_capacity, dtype=jnp.int32) < num_recv_tokens - return jnp.where(recv_valid[:, None], recv_out, 0) - - def _moe_mlp_ep_deepep_local( x_local: Float[Array, "TL D"], selected_experts_local: Int[Array, "TL K"], @@ -109,7 +59,6 @@ def _moe_mlp_ep_deepep_local( capacity_factor: float, ) -> tuple[Float[Array, "TL D"], Int[Array, ""]]: """DeepEP dispatch/combine path for an intranode expert mesh.""" - del capacity_factor local_experts = moe_w13_local.shape[0] if num_experts % local_experts != 0: raise ValueError( @@ -119,46 +68,54 @@ def _moe_mlp_ep_deepep_local( raise ValueError(f"DeepEP transport requires hidden % 8 == 0, got hidden={x_local.shape[1]}") ep_size = num_experts // local_experts + topk = selected_experts_local.shape[1] + local_capacity = int(math.ceil(capacity_factor * x_local.shape[0] * topk)) + local_capacity = max(local_experts, local_capacity) max_recv_tokens = x_local.shape[0] * ep_size with jax.named_scope("dispatch"): - with jax.named_scope("deepep_layout"): - num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout( - selected_experts_local, - num_ranks=ep_size, - num_experts=num_experts, - ) - with jax.named_scope("deepep_dispatch_transport"): - ( - recv_x, - recv_topk_idx, - recv_topk_weights, - recv_src_idx, - rank_prefix_matrix, - channel_prefix_matrix, - recv_channel_prefix_matrix, - send_head, - _local_expert_counts, - num_recv_tokens, - ) = deepep_dispatch_intranode( - x_local, - selected_experts_local, - combine_weights_local, - num_tokens_per_rank, - num_tokens_per_expert, - is_token_in_rank, - num_experts=num_experts, - max_recv_tokens=max_recv_tokens, - ) - num_recv_tokens_scalar = jnp.squeeze(num_recv_tokens, axis=0) - local_assignments = _pack_deepep_local_assignments( + num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout( + selected_experts_local, + num_ranks=ep_size, + num_experts=num_experts, + ) + ( recv_x, - recv_topk_idx, recv_topk_weights, - local_experts=local_experts, - num_recv_tokens=num_recv_tokens_scalar, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + local_group_sizes, + num_recv_tokens, + x_dispatch, + assignment_weights, + recv_token_indices, + assignment_destinations, + ) = deepep_dispatch_intranode_with_assignments( + x_local, + selected_experts_local, + combine_weights_local, + num_tokens_per_rank, + num_tokens_per_expert, + is_token_in_rank, + num_experts=num_experts, + max_recv_tokens=max_recv_tokens, ) - x_dispatch = tree_checkpoint_name(local_assignments.x_dispatch, _CHECKPOINT_DISPATCH_INPUT) + accepted_group_sizes = _prefix_cap_counts(local_group_sizes, capacity=local_capacity) + accepted_total = jnp.sum(accepted_group_sizes, dtype=jnp.int32) + dropped_local = jnp.sum(local_group_sizes, dtype=jnp.int32) - accepted_total + x_dispatch = x_dispatch[:local_capacity] + assignment_weights = assignment_weights[:local_capacity] + recv_token_indices = recv_token_indices[:local_capacity] + local_assignments = DeepEPLocalAssignments( + x_dispatch, + assignment_weights, + recv_token_indices, + accepted_group_sizes, + ) + x_dispatch = tree_checkpoint_name(x_dispatch, _CHECKPOINT_DISPATCH_INPUT) with jax.named_scope("moe_up_down"): w13_out = tree_checkpoint_name( @@ -172,24 +129,25 @@ def _moe_mlp_ep_deepep_local( ) with jax.named_scope("combine"): - recv_out = _collapse_deepep_local_assignments( + recv_out = deepep_collapse_local_assignments( out_dispatch, local_assignments.assignment_weights, local_assignments.recv_token_indices, + assignment_destinations, + local_assignments.local_group_sizes, + num_recv_tokens, recv_capacity=recv_x.shape[0], - num_recv_tokens=num_recv_tokens_scalar, ) - with jax.named_scope("deepep_combine_transport"): - out_local, _ = deepep_combine_intranode( - recv_out, - recv_topk_weights, - recv_src_idx, - rank_prefix_matrix, - channel_prefix_matrix, - recv_channel_prefix_matrix, - send_head, - num_recv_tokens, - is_token_in_rank, - ) - dropped_total = jnp.array(0, dtype=jnp.int32) + out_local, _ = deepep_combine_intranode( + recv_out, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + num_recv_tokens, + is_token_in_rank, + ) + dropped_total = jax.lax.psum(dropped_local, ("data", "expert")) return out_local.astype(x_local.dtype), dropped_total diff --git a/lib/levanter/src/levanter/grug/_moe/ep_ragged_all_to_all.py b/lib/levanter/src/levanter/grug/_moe/ep_ragged_all_to_all.py deleted file mode 100644 index ca29dc5b74..0000000000 --- a/lib/levanter/src/levanter/grug/_moe/ep_ragged_all_to_all.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright The Levanter Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Ragged all-to-all expert-parallel Grug MoE backend.""" - -import math -from collections.abc import Callable - -import jax -import jax.numpy as jnp - -from haliax.nn.ragged_dot import ragged_dot -from levanter.grug._moe.ep_common import ( - _clip_receiver_group_sizes, - _compact_by_keep_mask, - _expand_from_keep_mask, - _expert_prefix_keep_mask, - _local_permute_from_counts, - _permute_by_global_expert, - _shard_a2a_params, - _sort_activations, - _unpermute_from_global_expert, -) -from levanter.grug.sharding import _batch_axes - - -def _moe_mlp_ep_ragged_a2a_local( - x_local: jax.Array, - selected_experts_local: jax.Array, - combine_weights_local: jax.Array, - moe_w13_local: jax.Array, - moe_w2_local: jax.Array, - *, - activation_fn: Callable[[jax.Array], jax.Array], - num_experts: int, - capacity_factor: float, -) -> tuple[jax.Array, jax.Array]: - local_experts = moe_w13_local.shape[0] - if num_experts % local_experts != 0: - raise ValueError( - f"num_experts={num_experts} must be divisible by local expert count={local_experts} in EP mode" - ) - - shard_id = jax.lax.axis_index("expert") - ep_size = num_experts // local_experts - tokens_per_shard = x_local.shape[0] - topk = selected_experts_local.shape[1] - assignments_per_shard = tokens_per_shard * topk - local_capacity = int(math.ceil(capacity_factor * assignments_per_shard)) - local_capacity = max(local_experts, local_capacity) - recv_capacity = local_capacity - - with jax.named_scope("dispatch"): - sorted_x, sorted_indices, group_sizes = _permute_by_global_expert( - x_local, - selected_experts_local, - num_experts=num_experts, - ) - all_group_sizes = jax.lax.all_gather(group_sizes.astype(jnp.int32), "expert") - clipped_group_sizes = _clip_receiver_group_sizes( - all_group_sizes, - local_expert_size=local_experts, - receiver_capacity=local_capacity, - ) - sender_group_sizes = clipped_group_sizes[shard_id] - keep_mask = _expert_prefix_keep_mask( - group_sizes.astype(jnp.int32), - sender_group_sizes, - total_size=assignments_per_shard, - ) - sorted_x = _compact_by_keep_mask(sorted_x, keep_mask) - - all_shard_counts = jnp.sum(clipped_group_sizes.reshape(ep_size, ep_size, local_experts), axis=2) - input_offsets, send_sizes, output_offsets, recv_sizes = _shard_a2a_params(all_shard_counts, shard_id) - dispatch_out_shape = jnp.zeros((recv_capacity, x_local.shape[1]), dtype=x_local.dtype) - x_dispatched = jax.lax.ragged_all_to_all( - sorted_x, - dispatch_out_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name="expert", - ) - x_dispatch, local_sorted_indices, local_group_sizes = _local_permute_from_counts( - x_dispatched, - clipped_group_sizes, - local_expert_size=local_experts, - shard_index=shard_id, - ) - - with jax.named_scope("moe_up_down"): - w13_out = ragged_dot(x_dispatch, moe_w13_local, local_group_sizes) - moe_dim = moe_w2_local.shape[1] - gate, up = jnp.split(w13_out, [moe_dim], axis=-1) - out_dispatch = ragged_dot(activation_fn(gate) * up, moe_w2_local, local_group_sizes) - - with jax.named_scope("combine"): - local_output = _sort_activations(out_dispatch, jnp.argsort(local_sorted_indices)) - return_out_shape = jnp.zeros((assignments_per_shard, x_local.shape[1]), dtype=local_output.dtype) - return_input_offsets, return_send_sizes, return_output_offsets, return_recv_sizes = _shard_a2a_params( - all_shard_counts.T, shard_id - ) - returned = jax.lax.ragged_all_to_all( - local_output, - return_out_shape, - return_input_offsets, - return_send_sizes, - return_output_offsets, - return_recv_sizes, - axis_name="expert", - ) - returned = _expand_from_keep_mask(returned, keep_mask) - out_local = _unpermute_from_global_expert( - returned, - sorted_indices, - combine_weights_local, - tokens_per_shard=tokens_per_shard, - topk=topk, - ).astype(x_local.dtype) - dropped_local = jnp.sum(group_sizes, dtype=jnp.int32) - jnp.sum(sender_group_sizes, dtype=jnp.int32) - dropped_total = jax.lax.psum(dropped_local, _batch_axes(jax.sharding.get_abstract_mesh())) - return out_local, dropped_total diff --git a/lib/levanter/src/levanter/grug/attention/_fa4_cute.py b/lib/levanter/src/levanter/grug/attention/_fa4_cute.py index c8c67129de..b686e0d4fe 100644 --- a/lib/levanter/src/levanter/grug/attention/_fa4_cute.py +++ b/lib/levanter/src/levanter/grug/attention/_fa4_cute.py @@ -76,6 +76,25 @@ def _packed_segment_causal_lower_bounds( return jnp.where(valid, lower_bounds, seq_len), valid +def _simple_causal_lower_bounds( + *, + batch_size: int, + seq_len: int, + sliding_window: int | None, +) -> tuple[Int[Array, "B S"], Bool[Array, "B S"]]: + if sliding_window is not None and sliding_window <= 0: + raise ValueError(f"sliding_window must be positive, got {sliding_window}") + + positions = jnp.arange(seq_len, dtype=jnp.int32)[None, :] + if sliding_window is None: + lower_bounds = jnp.zeros((1, seq_len), dtype=jnp.int32) + else: + lower_bounds = jnp.maximum(positions - (sliding_window - 1), 0) + lower_bounds = jnp.broadcast_to(lower_bounds, (batch_size, seq_len)) + valid = jnp.ones((batch_size, seq_len), dtype=jnp.bool_) + return lower_bounds, valid + + def _packed_self_attention_segment_ids( q: jax.Array, k: jax.Array, @@ -109,6 +128,40 @@ def _packed_self_attention_segment_ids( return q_segment_ids +def _self_attention_lower_bounds( + q: jax.Array, + k: jax.Array, + mask: AttentionMask | Bool[Array, "B Q K"] | Float[Array, "B Q K"] | None, + *, + backend_name: str, +) -> tuple[Int[Array, "B S"], Bool[Array, "B S"]]: + if isinstance(mask, jax.Array): + raise NotImplementedError(f"{backend_name} does not support dense masks.") + if not isinstance(mask, AttentionMask): + raise NotImplementedError(f"{backend_name} requires an AttentionMask.") + if not mask.is_causal: + raise NotImplementedError(f"{backend_name} currently supports only causal self-attention.") + if q.shape[0] != k.shape[0]: + raise NotImplementedError(f"{backend_name} requires matching q/kv batch sizes.") + if q.shape[1] != k.shape[1]: + raise NotImplementedError(f"{backend_name} requires self-attention q_len == k_len.") + + if mask.segment_ids is None: + return _simple_causal_lower_bounds( + batch_size=q.shape[0], + seq_len=q.shape[1], + sliding_window=mask.sliding_window, + ) + + q_segment_ids = _packed_self_attention_segment_ids(q, k, mask, backend_name=backend_name) + return _packed_segment_causal_lower_bounds( + q_segment_ids, + batch_size=q.shape[0], + seq_len=q.shape[1], + sliding_window=mask.sliding_window, + ) + + def _validate_head_layout(q: jax.Array, k: jax.Array, *, backend_name: str) -> None: if q.shape[2] % k.shape[2] != 0: raise ValueError(f"{backend_name} requires Hq divisible by Hkv, got q={q.shape}, k={k.shape}") @@ -153,13 +206,11 @@ def gpu_fa4_cute_attention( raise RuntimeError("gpu_fa4_cute_attention requires the JAX GPU backend.") _validate_head_layout(q, k, backend_name="gpu_fa4_cute_attention") - q_segment_ids = _packed_self_attention_segment_ids(q, k, mask, backend_name="gpu_fa4_cute_attention") - assert isinstance(mask, AttentionMask) - lower_bounds, valid = _packed_segment_causal_lower_bounds( - q_segment_ids, - batch_size=q.shape[0], - seq_len=q.shape[1], - sliding_window=mask.sliding_window, + lower_bounds, valid = _self_attention_lower_bounds( + q, + k, + mask, + backend_name="gpu_fa4_cute_attention", ) kernel_config = _segmented_kernel_config(q.shape[-1]) diff --git a/lib/levanter/src/levanter/grug/attention/_fa4_thd.py b/lib/levanter/src/levanter/grug/attention/_fa4_thd.py index 4a087327ae..873f1a3692 100644 --- a/lib/levanter/src/levanter/grug/attention/_fa4_thd.py +++ b/lib/levanter/src/levanter/grug/attention/_fa4_thd.py @@ -88,7 +88,8 @@ def _validate_simple_causal_self_attention( if not mask.is_causal: raise NotImplementedError(f"{backend_name} supports only causal self-attention.") if mask.sliding_window is not None: - raise NotImplementedError(f"{backend_name} does not support sliding-window attention.") + if mask.sliding_window <= 0: + raise ValueError(f"sliding_window must be positive, got {mask.sliding_window}") if len(q.shape) != 4 or len(k.shape) != 4 or len(v.shape) != 4: raise ValueError( @@ -218,6 +219,7 @@ def _upstream_fa4_thd_forward_launcher( head_dim_v: int, qhead_per_kvhead: int, kernel_config: Flash4CuteKernelConfig, + sliding_window: int | None, ) -> Any: cutlass = modules.cutlass cute = modules.cute @@ -226,8 +228,8 @@ def _upstream_fa4_thd_forward_launcher( head_dim, head_dim_v, qhead_per_kvhead=qhead_per_kvhead, - is_causal=True, - is_local=False, + is_causal=sliding_window is None, + is_local=sliding_window is not None, is_split_kv=False, pack_gqa=qhead_per_kvhead > 1, m_block_size=kernel_config.forward_tile[0], @@ -268,8 +270,8 @@ def _launch_upstream_fa4_thd_forward( None, None, None, - None, - None, + None if sliding_window is None else sliding_window - 1, + None if sliding_window is None else 0, None, None, None, @@ -288,6 +290,7 @@ def _upstream_fa4_thd_backward_launcher( head_dim_v: int, qhead_per_kvhead: int, kernel_config: Flash4CuteKernelConfig, + sliding_window: int | None, ) -> Any: cutlass = modules.cutlass cute = modules.cute @@ -311,8 +314,8 @@ def _upstream_fa4_thd_backward_launcher( backward = modules.FlashAttentionBackwardSm100( head_dim, head_dim_v, - is_causal=True, - is_local=False, + is_causal=sliding_window is None, + is_local=sliding_window is not None, qhead_per_kvhead=qhead_per_kvhead, tile_m=tile_m, tile_n=tile_n, @@ -417,8 +420,8 @@ def _launch_upstream_fa4_thd_backward( cu_seqlens, None, None, - None, - None, + None if sliding_window is None else sliding_window - 1, + None if sliding_window is None else 0, None, None, None, @@ -441,6 +444,7 @@ def fa4_thd_attention_forward( *, softmax_scale: float, kernel_config: Flash4CuteKernelConfig, + sliding_window: int | None, ) -> tuple[Float[Array, "T Hq D"], Float[Array, "Hq T"]]: _validate_thd_inputs(q, k, v, cu_seqlens, softmax_scale=softmax_scale) try: @@ -454,6 +458,7 @@ def fa4_thd_attention_forward( head_dim_v=v.shape[-1], qhead_per_kvhead=q.shape[1] // k.shape[1], kernel_config=kernel_config, + sliding_window=sliding_window, ) input_spec, output_spec = _cutlass_thd_forward_specs(modules) out_shape_dtype = jax.ShapeDtypeStruct(q.shape, q.dtype) @@ -480,6 +485,7 @@ def fa4_thd_attention_backward( *, softmax_scale: float, kernel_config: Flash4CuteKernelConfig, + sliding_window: int | None, ) -> tuple[Float[Array, "T Hq D"], Float[Array, "T Hkv D"], Float[Array, "T Hkv D"]]: _validate_thd_inputs(q, k, v, cu_seqlens, softmax_scale=softmax_scale) if q.shape[1] == k.shape[1]: @@ -496,6 +502,7 @@ def fa4_thd_attention_backward( head_dim_v=v.shape[-1], qhead_per_kvhead=q.shape[1] // k.shape[1], kernel_config=kernel_config, + sliding_window=sliding_window, ) input_spec, output_spec = _cutlass_thd_backward_specs(modules) output_shape_dtype = _cutlass_thd_backward_output_shapes(q, k, v, cu_seqlens, kernel_config.backward_tile) @@ -563,7 +570,7 @@ def _num_thd_sequences(*, cu_seqlens_shape: int, cu_seqlens_rank: int) -> int: return cu_seqlens_shape - 1 -@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) +@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6)) def _jax_fa4_thd_attention( q: Float[Array, "T Hq D"], k: Float[Array, "T Hkv D"], @@ -571,6 +578,7 @@ def _jax_fa4_thd_attention( cu_seqlens: Int[Array, "N"], softmax_scale: float, kernel_config: Flash4CuteKernelConfig, + sliding_window: int | None, ) -> Float[Array, "T Hq D"]: out, _ = fa4_thd_attention_forward( q, @@ -579,6 +587,7 @@ def _jax_fa4_thd_attention( cu_seqlens, softmax_scale=softmax_scale, kernel_config=kernel_config, + sliding_window=sliding_window, ) return out @@ -590,6 +599,7 @@ def _jax_fa4_thd_attention_fwd( cu_seqlens: Int[Array, "N"], softmax_scale: float, kernel_config: Flash4CuteKernelConfig, + sliding_window: int | None, ) -> tuple[ Float[Array, "T Hq D"], tuple[ @@ -608,6 +618,7 @@ def _jax_fa4_thd_attention_fwd( cu_seqlens, softmax_scale=softmax_scale, kernel_config=kernel_config, + sliding_window=sliding_window, ) return out, (q, k, v, out, lse, cu_seqlens) @@ -615,6 +626,7 @@ def _jax_fa4_thd_attention_fwd( def _jax_fa4_thd_attention_bwd( softmax_scale: float, kernel_config: Flash4CuteKernelConfig, + sliding_window: int | None, residuals: tuple[ Float[Array, "T Hq D"], Float[Array, "T Hkv D"], @@ -638,6 +650,7 @@ def _jax_fa4_thd_attention_bwd( cu_seqlens, softmax_scale=softmax_scale, kernel_config=kernel_config, + sliding_window=sliding_window, ) return dq, dk, dv, None @@ -708,6 +721,7 @@ def gpu_fa4_thd_attention( cu_seqlens, 1.0 / math.sqrt(head_dim), kernel_config, + mask.sliding_window, ) return out.reshape(batch_size, seq_len, q.shape[2], head_dim) diff --git a/lib/levanter/src/levanter/grug/grug_moe.py b/lib/levanter/src/levanter/grug/grug_moe.py index 312be458aa..1a293b8cde 100644 --- a/lib/levanter/src/levanter/grug/grug_moe.py +++ b/lib/levanter/src/levanter/grug/grug_moe.py @@ -7,9 +7,9 @@ - Routing keeps the argsort-grouped dispatch path that emerged as the stable default from https://github.com/marin-community/marin/issues/2704 and commit 89318a910 (and its parent). -- Expert parallelism keeps the ring-style strategy from - https://github.com/marin-community/marin/issues/2710: token-sharded - `all_gather` for dispatch, then `psum_scatter` for collection. +- Expert parallelism keeps ring as a comparator, a plain-XLA assigned-token + reference path, and a DeepEP-backed assigned-token transport path for GPU + intranode runs. - Backend bodies live in the private `levanter.grug._moe` package; this module keeps the stable public API used by Grug model code and benchmarks. """ @@ -43,7 +43,7 @@ _shard_a2a_params as _shard_a2a_params, ) from levanter.grug._moe.ep_deepep import _moe_mlp_ep_deepep_local -from levanter.grug._moe.ep_ragged_all_to_all import _moe_mlp_ep_ragged_a2a_local +from levanter.grug._moe.ep_assigned_token import _moe_mlp_ep_assigned_token_local from levanter.grug._moe.ep_ring import _moe_mlp_ep_ring_local from levanter.grug._moe.local import _moe_mlp_local from levanter.grug.sharding import ( @@ -207,8 +207,8 @@ def moe_mlp( if resolved_implementation == "ring": shard_local_fn = _moe_mlp_ep_ring_local - elif resolved_implementation == "ragged_all_to_all": - shard_local_fn = _moe_mlp_ep_ragged_a2a_local + elif resolved_implementation == "assigned_token": + shard_local_fn = _moe_mlp_ep_assigned_token_local elif resolved_implementation == "deepep": shard_local_fn = _moe_mlp_ep_deepep_local else: diff --git a/lib/levanter/src/levanter/kernels/deepep/__init__.py b/lib/levanter/src/levanter/kernels/deepep/__init__.py index af1c4bdd98..6ad777db5c 100644 --- a/lib/levanter/src/levanter/kernels/deepep/__init__.py +++ b/lib/levanter/src/levanter/kernels/deepep/__init__.py @@ -3,23 +3,15 @@ """DeepEP-backed JAX kernel helpers.""" -from .availability import deepep_install_help, deepep_preflight_status -from .layout_ffi import deepep_get_dispatch_layout +from .availability import deepep_install_help as deepep_install_help +from .availability import deepep_preflight_status as deepep_preflight_status +from .layout_ffi import deepep_get_dispatch_layout as deepep_get_dispatch_layout from .transport_ffi import ( - deepep_combine_intranode, - deepep_dispatch_intranode, - ensure_intranode_runtime, - run_host_dispatch_round, - shutdown_intranode_runtime, + deepep_collapse_local_assignments as deepep_collapse_local_assignments, + deepep_combine_intranode as deepep_combine_intranode, + deepep_dispatch_intranode as deepep_dispatch_intranode, + deepep_dispatch_intranode_with_assignments as deepep_dispatch_intranode_with_assignments, + ensure_intranode_runtime as ensure_intranode_runtime, + run_host_dispatch_round as run_host_dispatch_round, + shutdown_intranode_runtime as shutdown_intranode_runtime, ) - -__all__ = [ - "deepep_combine_intranode", - "deepep_dispatch_intranode", - "deepep_get_dispatch_layout", - "deepep_install_help", - "deepep_preflight_status", - "ensure_intranode_runtime", - "run_host_dispatch_round", - "shutdown_intranode_runtime", -] diff --git a/lib/levanter/src/levanter/kernels/deepep/csrc/deepep_transport_ffi.cu b/lib/levanter/src/levanter/kernels/deepep/csrc/deepep_transport_ffi.cu index a77df92875..588a8120f0 100644 --- a/lib/levanter/src/levanter/kernels/deepep/csrc/deepep_transport_ffi.cu +++ b/lib/levanter/src/levanter/kernels/deepep/csrc/deepep_transport_ffi.cu @@ -16,12 +16,58 @@ #include #include +#include #include #include "config.hpp" #include "kernels/api.cuh" #include "xla/ffi/api/ffi.h" +namespace deep_ep::intranode { + +void dispatch_assignments( + void* recv_x, + float* recv_x_scales, + float* recv_x_sf_scale_for_nvfp4, + int* recv_src_idx, + int64_t* recv_topk_idx, + float* recv_topk_weights, + int* recv_channel_offset, + void* x_dispatch, + nv_bfloat16* assignment_weights, + int* recv_token_indices, + int* local_group_cursors, + int* recv_assignment_indices, + int* assignment_destinations, + int* send_head, + const void* x, + const float* x_scales, + const float* sf_scale_for_nvfp4, + const int64_t* topk_idx, + const float* topk_weights, + const bool* is_token_in_rank, + const int* channel_prefix_matrix, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_topk, + int num_experts, + int num_scales, + int num_sf_scales_for_nvfp4, + int scale_token_stride, + int scale_hidden_stride, + int sf_scale_for_nvfp4_token_stride, + int sf_scale_for_nvfp4_hidden_stride, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); + +} // namespace deep_ep::intranode + namespace ffi = xla::ffi; namespace { @@ -155,6 +201,17 @@ int ReadDeviceScalarInt(cudaStream_t stream, const int* value, const char* conte return host_value; } +int ReadRecvCount(DeviceRuntime& runtime, cudaStream_t stream, const int* value, int recv_capacity, const char* context) { + const int host_value = static_cast(*runtime.moe_recv_counter); + if (host_value >= 0) { + if (host_value > recv_capacity) { + throw std::runtime_error("DeepEP intranode receive count exceeds receive buffer capacity"); + } + return host_value; + } + return ReadDeviceScalarInt(stream, value, context); +} + __global__ void CastInt32ToInt64Kernel(const int* src, int64_t* dst, size_t count) { const size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; if (idx < count) { @@ -183,6 +240,331 @@ void LaunchCastInt64ToInt32(const int64_t* src, int* dst, size_t count, cudaStre ThrowOnCuda(cudaGetLastError(), "CastInt64ToInt32Kernel"); } +__global__ void CountLocalAssignmentsKernel( + const int* recv_topk_idx, + const int* num_recv_tokens, + int* local_group_sizes, + int recv_capacity, + int num_topk, + int local_experts) { + const int assignment = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const int total_assignments = recv_capacity * num_topk; + if (assignment >= total_assignments) { + return; + } + const int token = assignment / num_topk; + if (token >= num_recv_tokens[0]) { + return; + } + const int local_expert = recv_topk_idx[assignment]; + if (local_expert >= 0 && local_expert < local_experts) { + atomicAdd(&local_group_sizes[local_expert], 1); + } +} + +__global__ void PrefixLocalAssignmentCursorsKernel( + const int* local_group_sizes, + int* local_group_cursors, + int local_experts) { + if (blockIdx.x != 0 || threadIdx.x != 0) { + return; + } + int prefix = 0; + for (int expert = 0; expert < local_experts; ++expert) { + const int group_size = local_group_sizes[expert]; + local_group_cursors[expert] = prefix; + prefix += group_size; + } +} + +template +__global__ void AssignLocalAssignmentDestinationsKernel( + const ExpertIndexT* recv_topk_idx, + const float* recv_topk_weights, + const int* num_recv_tokens, + int* local_group_cursors, + nv_bfloat16* assignment_weights, + int* recv_token_indices, + int* recv_assignment_indices, + int* assignment_destinations, + int recv_capacity, + int num_topk, + int local_experts) { + const int assignment = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const int total_assignments = recv_capacity * num_topk; + if (assignment >= total_assignments) { + return; + } + const int token = assignment / num_topk; + if (token >= num_recv_tokens[0]) { + return; + } + const int local_expert = static_cast(recv_topk_idx[assignment]); + if (local_expert < 0 || local_expert >= local_experts) { + return; + } + + const int destination = atomicAdd(&local_group_cursors[local_expert], 1); + assignment_destinations[assignment] = destination; + recv_token_indices[destination] = token; + recv_assignment_indices[destination] = assignment; + assignment_weights[destination] = __float2bfloat16(recv_topk_weights[assignment]); +} + +__global__ void PackLocalAssignmentRowsKernel( + const nv_bfloat16* recv_x, + const int* recv_token_indices, + nv_bfloat16* x_dispatch, + int total_valid_assignments, + int hidden, + int hidden_int4) { + const size_t element = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t total_elements = static_cast(total_valid_assignments) * hidden_int4; + if (element >= total_elements) { + return; + } + const int col = static_cast(element % hidden_int4); + const int destination = static_cast(element / hidden_int4); + const int token = recv_token_indices[destination]; + const int4* src_row = reinterpret_cast(recv_x + static_cast(token) * hidden); + int4* dst_row = reinterpret_cast(x_dispatch + static_cast(destination) * hidden); + dst_row[col] = src_row[col]; +} + +__global__ void CollapseLocalAssignmentsKernel( + const nv_bfloat16* out_dispatch, + const nv_bfloat16* assignment_weights, + const int* assignment_destinations, + const int* accepted_total_assignments, + const int* num_recv_tokens, + nv_bfloat16* recv_out, + int recv_capacity, + int num_topk, + int hidden) { + const size_t element = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t total_elements = static_cast(recv_capacity) * hidden; + if (element >= total_elements) { + return; + } + const int recv_token = static_cast(element / hidden); + const int col = static_cast(element - static_cast(recv_token) * hidden); + if (recv_token >= num_recv_tokens[0]) { + recv_out[element] = __float2bfloat16(0.0f); + return; + } + + const int total_valid_assignments = accepted_total_assignments[0]; + + float value = 0.0f; + for (int topk = 0; topk < num_topk; ++topk) { + const int assignment = recv_token * num_topk + topk; + const int destination = assignment_destinations[assignment]; + if (destination < 0 || destination >= total_valid_assignments) { + continue; + } + value += __bfloat162float(out_dispatch[static_cast(destination) * hidden + col]) * + __bfloat162float(assignment_weights[destination]); + } + recv_out[element] = __float2bfloat16(value); +} + +void LaunchPackLocalAssignments( + const nv_bfloat16* recv_x, + const int* recv_topk_idx, + const float* recv_topk_weights, + const int* num_recv_tokens, + nv_bfloat16* x_dispatch, + nv_bfloat16* assignment_weights, + int* recv_token_indices, + int* local_group_sizes, + int* local_group_cursors, + int* recv_assignment_indices, + int* assignment_destinations, + int recv_capacity, + int hidden, + int num_topk, + int local_experts, + cudaStream_t stream) { + constexpr int kThreads = 256; + const int total_assignments = recv_capacity * num_topk; + ThrowOnCuda( + cudaMemsetAsync( + assignment_weights, + 0, + static_cast(total_assignments) * sizeof(nv_bfloat16), + stream), + "cudaMemsetAsync(pack assignment_weights)"); + ThrowOnCuda( + cudaMemsetAsync( + recv_token_indices, + 0, + static_cast(total_assignments) * sizeof(int), + stream), + "cudaMemsetAsync(pack recv_token_indices)"); + ThrowOnCuda( + cudaMemsetAsync( + recv_assignment_indices, + 0, + static_cast(total_assignments) * sizeof(int), + stream), + "cudaMemsetAsync(pack recv_assignment_indices)"); + ThrowOnCuda( + cudaMemsetAsync( + assignment_destinations, + 0xff, + static_cast(total_assignments) * sizeof(int), + stream), + "cudaMemsetAsync(pack assignment_destinations)"); + ThrowOnCuda( + cudaMemsetAsync(local_group_sizes, 0, static_cast(local_experts) * sizeof(int), stream), + "cudaMemsetAsync(pack local_group_sizes)"); + ThrowOnCuda( + cudaMemsetAsync(local_group_cursors, 0, static_cast(local_experts) * sizeof(int), stream), + "cudaMemsetAsync(pack local_group_cursors)"); + + const int assignment_blocks = (total_assignments + kThreads - 1) / kThreads; + CountLocalAssignmentsKernel<<>>( + recv_topk_idx, + num_recv_tokens, + local_group_sizes, + recv_capacity, + num_topk, + local_experts); + ThrowOnCuda(cudaGetLastError(), "CountLocalAssignmentsKernel"); + PrefixLocalAssignmentCursorsKernel<<<1, 1, 0, stream>>>(local_group_sizes, local_group_cursors, local_experts); + ThrowOnCuda(cudaGetLastError(), "PrefixLocalAssignmentCursorsKernel"); + AssignLocalAssignmentDestinationsKernel<<>>( + recv_topk_idx, + recv_topk_weights, + num_recv_tokens, + local_group_cursors, + assignment_weights, + recv_token_indices, + recv_assignment_indices, + assignment_destinations, + recv_capacity, + num_topk, + local_experts); + ThrowOnCuda(cudaGetLastError(), "AssignLocalAssignmentDestinationsKernel"); + + const int hidden_int4 = hidden * static_cast(sizeof(nv_bfloat16)) / static_cast(sizeof(int4)); + const size_t copy_elements = static_cast(total_assignments) * hidden_int4; + const int copy_blocks = static_cast((copy_elements + kThreads - 1) / kThreads); + PackLocalAssignmentRowsKernel<<>>( + recv_x, + recv_token_indices, + x_dispatch, + total_assignments, + hidden, + hidden_int4); + ThrowOnCuda(cudaGetLastError(), "PackLocalAssignmentRowsKernel"); +} + +template +void LaunchPackLocalAssignmentsFromCounts( + const nv_bfloat16* recv_x, + const ExpertIndexT* recv_topk_idx, + const float* recv_topk_weights, + const int* num_recv_tokens, + const int* local_group_sizes, + nv_bfloat16* x_dispatch, + nv_bfloat16* assignment_weights, + int* recv_token_indices, + int* local_group_cursors, + int* recv_assignment_indices, + int* assignment_destinations, + int recv_capacity, + int hidden, + int num_topk, + int local_experts, + int active_recv_tokens, + int total_valid_assignments, + cudaStream_t stream) { + constexpr int kThreads = 256; + const int total_assignments = recv_capacity * num_topk; + if (active_recv_tokens < 0 || active_recv_tokens > recv_capacity) { + throw std::runtime_error("DeepEP count-seeded pack active receive token count is out of range"); + } + if (total_valid_assignments < 0 || total_valid_assignments > total_assignments) { + throw std::runtime_error("DeepEP count-seeded pack active assignment count is out of range"); + } + const int active_assignments = active_recv_tokens * num_topk; + ThrowOnCuda( + cudaMemsetAsync( + assignment_destinations, + 0xff, + static_cast(active_assignments) * sizeof(int), + stream), + "cudaMemsetAsync(pack-counts assignment_destinations)"); + ThrowOnCuda( + cudaMemsetAsync(local_group_cursors, 0, static_cast(local_experts) * sizeof(int), stream), + "cudaMemsetAsync(pack-counts local_group_cursors)"); + + PrefixLocalAssignmentCursorsKernel<<<1, 1, 0, stream>>>(local_group_sizes, local_group_cursors, local_experts); + ThrowOnCuda(cudaGetLastError(), "PrefixLocalAssignmentCursorsKernel(counts)"); + + const int assignment_blocks = (active_assignments + kThreads - 1) / kThreads; + if (assignment_blocks > 0) { + AssignLocalAssignmentDestinationsKernel<<>>( + recv_topk_idx, + recv_topk_weights, + num_recv_tokens, + local_group_cursors, + assignment_weights, + recv_token_indices, + recv_assignment_indices, + assignment_destinations, + active_recv_tokens, + num_topk, + local_experts); + ThrowOnCuda(cudaGetLastError(), "AssignLocalAssignmentDestinationsKernel(counts)"); + } + + const int hidden_int4 = hidden * static_cast(sizeof(nv_bfloat16)) / static_cast(sizeof(int4)); + const size_t copy_elements = static_cast(total_valid_assignments) * hidden_int4; + const int copy_blocks = static_cast((copy_elements + kThreads - 1) / kThreads); + if (copy_blocks > 0) { + PackLocalAssignmentRowsKernel<<>>( + recv_x, + recv_token_indices, + x_dispatch, + total_valid_assignments, + hidden, + hidden_int4); + ThrowOnCuda(cudaGetLastError(), "PackLocalAssignmentRowsKernel(counts)"); + } +} + +void LaunchCollapseLocalAssignments( + const nv_bfloat16* out_dispatch, + const nv_bfloat16* assignment_weights, + const int* assignment_destinations, + const int* accepted_total_assignments, + const int* num_recv_tokens, + nv_bfloat16* recv_out, + int recv_capacity, + int active_recv_tokens, + int num_topk, + int hidden, + cudaStream_t stream) { + constexpr int kThreads = 256; + const size_t recv_elements = static_cast(active_recv_tokens) * hidden; + const int assignment_blocks = static_cast((recv_elements + kThreads - 1) / kThreads); + if (assignment_blocks > 0) { + CollapseLocalAssignmentsKernel<<>>( + out_dispatch, + assignment_weights, + assignment_destinations, + accepted_total_assignments, + num_recv_tokens, + recv_out, + recv_capacity, + num_topk, + hidden); + ThrowOnCuda(cudaGetLastError(), "CollapseLocalAssignmentsKernel"); + } +} + void EnablePeerAccess(int peer_device_id) { cudaError_t status = cudaDeviceEnablePeerAccess(peer_device_id, 0); if (status == cudaSuccess) { @@ -570,6 +952,16 @@ void DispatchOnCurrentDevice( if (num_recv_tokens_host_out != nullptr) { *num_recv_tokens_host_out = num_recv_tokens; } + if (num_recv_tokens_device_out != nullptr) { + ThrowOnCuda( + cudaMemcpyAsync( + num_recv_tokens_device_out, + &num_recv_tokens, + sizeof(int), + cudaMemcpyHostToDevice, + stream), + "cudaMemcpyAsync(num_recv_tokens_device)"); + } } else { if (num_recv_tokens_device_out == nullptr) { throw std::runtime_error("DeepEP intranode async dispatch requires a device receive-count output"); @@ -678,6 +1070,203 @@ void DispatchOnCurrentDevice( num_recv_tokens); } +void DispatchAssignmentsOnCurrentDevice( + DeviceRuntime& runtime, + cudaStream_t stream, + const nv_bfloat16* x, + const int64_t* topk_idx, + const float* topk_weights, + const int* num_tokens_per_rank, + const int* num_tokens_per_expert, + const bool* is_token_in_rank, + int num_tokens, + int hidden, + int num_topk, + int num_experts, + nv_bfloat16* recv_x, + int64_t* recv_topk_idx, + float* recv_topk_weights, + int* recv_src_idx, + int* rank_prefix_matrix, + int* channel_prefix_matrix, + int* recv_channel_prefix_matrix, + int* send_head, + int* local_expert_counts, + int* num_recv_tokens_host_out, + int* num_recv_tokens_device_out, + nv_bfloat16* x_dispatch, + nv_bfloat16* assignment_weights, + int* recv_token_indices, + int* local_group_cursors, + int* recv_assignment_indices, + int* assignment_destinations, + int max_recv_tokens) { + if (hidden <= 0 || (hidden * static_cast(sizeof(nv_bfloat16))) % sizeof(int4) != 0) { + throw std::runtime_error("DeepEP assignment dispatch requires hidden*element_size divisible by int4"); + } + if (num_experts % runtime.num_ranks != 0) { + throw std::runtime_error("DeepEP assignment dispatch requires num_experts divisible by num_ranks"); + } + const int num_local_experts = num_experts / runtime.num_ranks; + + ResetRecvCounters(runtime, num_local_experts); + LogHostDispatchStage(runtime.rank, "before_notify_assignment_dispatch", num_tokens, hidden, num_experts, num_topk); + const int num_memset_int = runtime.dispatch_num_channels() * runtime.num_ranks * 4; + deep_ep::intranode::notify_dispatch( + num_tokens_per_rank, + runtime.moe_recv_counter_mapped, + runtime.num_ranks, + num_tokens_per_expert, + runtime.moe_recv_expert_counter_mapped, + num_experts, + num_tokens, + is_token_in_rank, + channel_prefix_matrix, + rank_prefix_matrix, + num_memset_int, + 1, + runtime.buffer_ptrs_gpu, + runtime.barrier_signal_ptrs_gpu, + runtime.rank, + stream, + runtime.dispatch_num_channels()); + + int num_recv_tokens = max_recv_tokens; + WaitForRecvCounts(runtime, num_local_experts, &num_recv_tokens); + LogHostDispatchStage( + runtime.rank, + "after_wait_for_assignment_recv_counts", + num_tokens, + hidden, + num_experts, + num_topk, + num_recv_tokens); + if (num_recv_tokens > max_recv_tokens) { + throw std::runtime_error("DeepEP assignment dispatch recv buffer is smaller than actual recv tokens"); + } + + int total_local_assignments = 0; + for (int expert = 0; expert < num_local_experts; ++expert) { + total_local_assignments += static_cast(runtime.moe_recv_expert_counter[expert]); + } + const int active_assignments = num_recv_tokens * num_topk; + const int total_assignments = max_recv_tokens * num_topk; + if (active_assignments > total_assignments || total_local_assignments > total_assignments) { + throw std::runtime_error("DeepEP assignment dispatch assignment count exceeds output capacity"); + } + + ThrowOnCuda( + cudaMemcpyAsync( + local_expert_counts, + const_cast(runtime.moe_recv_expert_counter), + sizeof(int) * num_local_experts, + cudaMemcpyHostToDevice, + stream), + "cudaMemcpyAsync(assignment local_expert_counts)"); + if (num_recv_tokens_host_out != nullptr) { + *num_recv_tokens_host_out = num_recv_tokens; + } + if (num_recv_tokens_device_out != nullptr) { + ThrowOnCuda( + cudaMemcpyAsync( + num_recv_tokens_device_out, + &num_recv_tokens, + sizeof(int), + cudaMemcpyHostToDevice, + stream), + "cudaMemcpyAsync(assignment num_recv_tokens_device)"); + } + if (active_assignments > 0) { + ThrowOnCuda( + cudaMemsetAsync( + assignment_destinations, + 0xff, + static_cast(active_assignments) * sizeof(int), + stream), + "cudaMemsetAsync(assignment destinations)"); + } + if (total_local_assignments > 0) { + ThrowOnCuda( + cudaMemsetAsync( + assignment_weights, + 0, + static_cast(total_local_assignments) * sizeof(nv_bfloat16), + stream), + "cudaMemsetAsync(assignment weights)"); + ThrowOnCuda( + cudaMemsetAsync( + recv_token_indices, + 0, + static_cast(total_local_assignments) * sizeof(int), + stream), + "cudaMemsetAsync(assignment recv_token_indices)"); + ThrowOnCuda( + cudaMemsetAsync( + recv_assignment_indices, + 0, + static_cast(total_local_assignments) * sizeof(int), + stream), + "cudaMemsetAsync(assignment recv_assignment_indices)"); + } + ThrowOnCuda( + cudaMemsetAsync(local_group_cursors, 0, sizeof(int) * num_local_experts, stream), + "cudaMemsetAsync(assignment local_group_cursors)"); + PrefixLocalAssignmentCursorsKernel<<<1, 1, 0, stream>>>( + local_expert_counts, + local_group_cursors, + num_local_experts); + ThrowOnCuda(cudaGetLastError(), "PrefixLocalAssignmentCursorsKernel(assignment)"); + + deep_ep::intranode::dispatch_assignments( + recv_x, + nullptr, + nullptr, + recv_src_idx, + recv_topk_idx, + recv_topk_weights, + recv_channel_prefix_matrix, + x_dispatch, + assignment_weights, + recv_token_indices, + local_group_cursors, + recv_assignment_indices, + assignment_destinations, + send_head, + x, + nullptr, + nullptr, + topk_idx, + topk_weights, + is_token_in_rank, + channel_prefix_matrix, + num_tokens, + 0, + hidden * static_cast(sizeof(nv_bfloat16)) / sizeof(int4), + num_topk, + num_experts, + 0, + 0, + 0, + 0, + 0, + 0, + runtime.buffer_ptrs_gpu, + runtime.rank, + runtime.num_ranks, + stream, + runtime.dispatch_config.num_sms, + runtime.dispatch_config.num_max_send_tokens, + runtime.dispatch_config.num_max_recv_tokens); + LogHostDispatchStage( + runtime.rank, + "after_assignment_dispatch_launch", + num_tokens, + hidden, + num_experts, + num_topk, + num_recv_tokens); +} + ffi::Error DispatchIntranode( cudaStream_t stream, ffi::Buffer x, @@ -815,6 +1404,161 @@ ffi::Error DispatchIntranode( } } +ffi::Error DispatchIntranodeWithAssignments( + cudaStream_t stream, + ffi::Buffer x, + ffi::Buffer topk_idx, + ffi::Buffer topk_weights, + ffi::Buffer num_tokens_per_rank, + ffi::Buffer num_tokens_per_expert, + ffi::Buffer is_token_in_rank, + int32_t num_experts, + ffi::Result> recv_x, + ffi::Result> recv_topk_weights, + ffi::Result> recv_src_idx, + ffi::Result> rank_prefix_matrix, + ffi::Result> channel_prefix_matrix, + ffi::Result> recv_channel_prefix_matrix, + ffi::Result> send_head, + ffi::Result> local_expert_counts, + ffi::Result> num_recv_tokens_buffer, + ffi::Result> topk_idx_s64_scratch, + ffi::Result> recv_topk_idx_s64_scratch, + ffi::Result> x_dispatch, + ffi::Result> assignment_weights, + ffi::Result> recv_token_indices, + ffi::Result> local_group_cursors, + ffi::Result> recv_assignment_indices, + ffi::Result> assignment_destinations) { + try { + DeviceRuntime& runtime = RuntimeManager::Instance().RuntimeForCurrentDevice(); + const auto x_dims = x.dimensions(); + const auto topk_dims = topk_idx.dimensions(); + const auto rank_dims = num_tokens_per_rank.dimensions(); + const auto expert_dims = num_tokens_per_expert.dimensions(); + const auto token_rank_dims = is_token_in_rank.dimensions(); + const auto topk_scratch_dims = topk_idx_s64_scratch->dimensions(); + const auto recv_topk_scratch_dims = recv_topk_idx_s64_scratch->dimensions(); + if (x_dims.size() != 2 || topk_dims.size() != 2) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch expects rank-2 x and topk_idx"); + } + if (rank_dims.size() != 1 || expert_dims.size() != 1 || token_rank_dims.size() != 2) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch metadata ranks are invalid"); + } + const int num_tokens = static_cast(x_dims[0]); + const int hidden = static_cast(x_dims[1]); + const int num_topk = static_cast(topk_dims[1]); + if (topk_dims[0] != num_tokens || topk_weights.dimensions()[0] != num_tokens || + topk_weights.dimensions()[1] != num_topk) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch top-k tensors must match x"); + } + if (rank_dims[0] != runtime.num_ranks || token_rank_dims[0] != num_tokens || + token_rank_dims[1] != runtime.num_ranks) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch rank metadata shape mismatch"); + } + if (expert_dims[0] != num_experts) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch expert metadata shape mismatch"); + } + if (hidden <= 0 || (hidden * static_cast(ffi::ByteWidth(ffi::BF16))) % sizeof(int4) != 0) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch requires hidden*element_size divisible by int4"); + } + if (num_experts % runtime.num_ranks != 0) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch requires num_experts divisible by num_ranks"); + } + const int num_local_experts = num_experts / runtime.num_ranks; + const int recv_capacity = static_cast(recv_x->dimensions()[0]); + const int total_assignments = recv_capacity * num_topk; + if (local_expert_counts->dimensions().size() != 1 || local_expert_counts->dimensions()[0] != num_local_experts) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch local_expert_counts shape mismatch"); + } + if (num_recv_tokens_buffer->dimensions().size() != 1 || num_recv_tokens_buffer->dimensions()[0] != 1) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch num_recv_tokens buffer must have shape [1]"); + } + + const int num_channels = runtime.dispatch_num_channels(); + if (rank_prefix_matrix->dimensions().size() != 2 || + rank_prefix_matrix->dimensions()[0] != runtime.num_ranks || + rank_prefix_matrix->dimensions()[1] != runtime.num_ranks || + channel_prefix_matrix->dimensions().size() != 2 || + channel_prefix_matrix->dimensions()[0] != runtime.num_ranks || + channel_prefix_matrix->dimensions()[1] != num_channels || + recv_channel_prefix_matrix->dimensions().size() != 2 || + recv_channel_prefix_matrix->dimensions()[0] != runtime.num_ranks || + recv_channel_prefix_matrix->dimensions()[1] != num_channels || + send_head->dimensions().size() != 2 || + send_head->dimensions()[0] != num_tokens || + send_head->dimensions()[1] != runtime.num_ranks) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch handle tensor shapes are invalid"); + } + if (recv_x->dimensions().size() != 2 || recv_x->dimensions()[1] != hidden || + recv_src_idx->dimensions().size() != 1 || recv_src_idx->dimensions()[0] != recv_capacity || + recv_topk_weights->dimensions().size() != 2 || recv_topk_weights->dimensions()[0] != recv_capacity || + recv_topk_weights->dimensions()[1] != num_topk) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch recv tensor shapes are invalid"); + } + if (topk_scratch_dims.size() != 2 || topk_scratch_dims[0] != num_tokens || + topk_scratch_dims[1] != num_topk * 2 || + recv_topk_scratch_dims.size() != 2 || recv_topk_scratch_dims[0] != recv_capacity || + recv_topk_scratch_dims[1] != num_topk * 2) { + return ffi::Error::InvalidArgument("DeepEP fused dispatch int64 scratch tensor shapes are invalid"); + } + + if (x_dispatch->dimensions().size() != 2 || x_dispatch->dimensions()[0] != total_assignments || + x_dispatch->dimensions()[1] != hidden || + assignment_weights->dimensions().size() != 1 || assignment_weights->dimensions()[0] != total_assignments || + recv_token_indices->dimensions().size() != 1 || recv_token_indices->dimensions()[0] != total_assignments || + local_group_cursors->dimensions().size() != 1 || + local_group_cursors->dimensions()[0] != num_local_experts || + recv_assignment_indices->dimensions().size() != 1 || + recv_assignment_indices->dimensions()[0] != total_assignments || + assignment_destinations->dimensions().size() != 1 || + assignment_destinations->dimensions()[0] != total_assignments) { + return ffi::Error::InvalidArgument("DeepEP fused assignment output shapes are invalid"); + } + + const size_t topk_count = static_cast(num_tokens) * num_topk; + auto* topk_idx_s64 = reinterpret_cast(topk_idx_s64_scratch->typed_data()); + auto* recv_topk_idx_s64 = reinterpret_cast(recv_topk_idx_s64_scratch->typed_data()); + LaunchCastInt32ToInt64(topk_idx.typed_data(), topk_idx_s64, topk_count, stream); + + int num_recv_tokens_host = recv_capacity; + DispatchAssignmentsOnCurrentDevice( + runtime, + stream, + reinterpret_cast(x.typed_data()), + topk_idx_s64, + topk_weights.typed_data(), + num_tokens_per_rank.typed_data(), + num_tokens_per_expert.typed_data(), + is_token_in_rank.typed_data(), + num_tokens, + hidden, + num_topk, + num_experts, + reinterpret_cast(recv_x->typed_data()), + recv_topk_idx_s64, + recv_topk_weights->typed_data(), + recv_src_idx->typed_data(), + rank_prefix_matrix->typed_data(), + channel_prefix_matrix->typed_data(), + recv_channel_prefix_matrix->typed_data(), + send_head->typed_data(), + local_expert_counts->typed_data(), + &num_recv_tokens_host, + num_recv_tokens_buffer->typed_data(), + reinterpret_cast(x_dispatch->typed_data()), + reinterpret_cast(assignment_weights->typed_data()), + recv_token_indices->typed_data(), + local_group_cursors->typed_data(), + recv_assignment_indices->typed_data(), + assignment_destinations->typed_data(), + recv_capacity); + return ffi::Error::Success(); + } catch (const std::exception& exc) { + return ffi::Error::Internal(exc.what()); + } +} + ffi::Error DispatchIntranodeCached( cudaStream_t stream, ffi::Buffer x, @@ -986,9 +1730,11 @@ ffi::Error CombineIntranode( } const int hidden = static_cast(recv_x_dims[1]); const int num_topk = static_cast(recv_topk_dims[1]); - const int num_recv_tokens = ReadDeviceScalarInt( + const int num_recv_tokens = ReadRecvCount( + runtime, stream, num_recv_tokens_buffer.typed_data(), + static_cast(recv_x_dims[0]), "cudaMemcpyAsync(read combine num_recv_tokens)"); const int combined_tokens = static_cast(send_head_dims[0]); if (recv_topk_dims[0] != recv_x_dims[0] || src_dims[0] != recv_x_dims[0]) { @@ -1065,6 +1811,241 @@ ffi::Error CombineIntranode( } } +ffi::Error PackLocalAssignments( + cudaStream_t stream, + ffi::Buffer recv_x, + ffi::Buffer recv_topk_idx, + ffi::Buffer recv_topk_weights, + ffi::Buffer num_recv_tokens, + int32_t local_experts, + ffi::Result> x_dispatch, + ffi::Result> assignment_weights, + ffi::Result> recv_token_indices, + ffi::Result> local_group_sizes, + ffi::Result> local_group_cursors, + ffi::Result> recv_assignment_indices, + ffi::Result> assignment_destinations) { + try { + DeviceRuntime& runtime = RuntimeManager::Instance().RuntimeForCurrentDevice(); + const auto recv_x_dims = recv_x.dimensions(); + const auto topk_dims = recv_topk_idx.dimensions(); + const auto weights_dims = recv_topk_weights.dimensions(); + const auto token_dims = num_recv_tokens.dimensions(); + const auto dispatch_dims = x_dispatch->dimensions(); + const auto assignment_weight_dims = assignment_weights->dimensions(); + const auto recv_token_dims = recv_token_indices->dimensions(); + const auto group_dims = local_group_sizes->dimensions(); + const auto cursor_dims = local_group_cursors->dimensions(); + const auto assignment_index_dims = recv_assignment_indices->dimensions(); + const auto assignment_destination_dims = assignment_destinations->dimensions(); + if (recv_x_dims.size() != 2 || topk_dims.size() != 2 || weights_dims.size() != 2 || + token_dims.size() != 1 || dispatch_dims.size() != 2 || assignment_weight_dims.size() != 1 || + recv_token_dims.size() != 1 || group_dims.size() != 1 || cursor_dims.size() != 1 || + assignment_index_dims.size() != 1 || assignment_destination_dims.size() != 1) { + return ffi::Error::InvalidArgument("DeepEP local assignment pack expects rank-1/2 tensors"); + } + if (token_dims[0] != 1) { + return ffi::Error::InvalidArgument("DeepEP local assignment pack expects num_recv_tokens shape [1]"); + } + if (local_experts <= 0) { + return ffi::Error::InvalidArgument("DeepEP local assignment pack requires positive local_experts"); + } + const int recv_capacity = static_cast(recv_x_dims[0]); + const int hidden = static_cast(recv_x_dims[1]); + const int num_topk = static_cast(topk_dims[1]); + const int total_assignments = recv_capacity * num_topk; + if (topk_dims[0] != recv_capacity || weights_dims[0] != recv_capacity || weights_dims[1] != num_topk) { + return ffi::Error::InvalidArgument("DeepEP local assignment pack recv top-k tensors must match recv_x"); + } + if (dispatch_dims[0] != total_assignments || dispatch_dims[1] != hidden || + assignment_weight_dims[0] != total_assignments || + recv_token_dims[0] != total_assignments || + group_dims[0] != local_experts || + cursor_dims[0] != local_experts || + assignment_index_dims[0] != total_assignments || + assignment_destination_dims[0] != total_assignments) { + return ffi::Error::InvalidArgument("DeepEP local assignment pack output shapes are invalid"); + } + if (hidden <= 0 || (hidden * static_cast(ffi::ByteWidth(ffi::BF16))) % sizeof(int4) != 0) { + return ffi::Error::InvalidArgument("DeepEP local assignment pack requires hidden*element_size divisible by int4"); + } + + LaunchPackLocalAssignments( + reinterpret_cast(recv_x.typed_data()), + recv_topk_idx.typed_data(), + recv_topk_weights.typed_data(), + num_recv_tokens.typed_data(), + reinterpret_cast(x_dispatch->typed_data()), + reinterpret_cast(assignment_weights->typed_data()), + recv_token_indices->typed_data(), + local_group_sizes->typed_data(), + local_group_cursors->typed_data(), + recv_assignment_indices->typed_data(), + assignment_destinations->typed_data(), + recv_capacity, + hidden, + num_topk, + local_experts, + stream); + return ffi::Error::Success(); + } catch (const std::exception& exc) { + return ffi::Error::Internal(exc.what()); + } +} + +ffi::Error PackLocalAssignmentsFromCounts( + cudaStream_t stream, + ffi::Buffer recv_x, + ffi::Buffer recv_topk_idx, + ffi::Buffer recv_topk_weights, + ffi::Buffer num_recv_tokens, + ffi::Buffer local_group_sizes, + ffi::Result> x_dispatch, + ffi::Result> assignment_weights, + ffi::Result> recv_token_indices, + ffi::Result> local_group_cursors, + ffi::Result> recv_assignment_indices, + ffi::Result> assignment_destinations) { + try { + DeviceRuntime& runtime = RuntimeManager::Instance().RuntimeForCurrentDevice(); + const auto recv_x_dims = recv_x.dimensions(); + const auto topk_dims = recv_topk_idx.dimensions(); + const auto weights_dims = recv_topk_weights.dimensions(); + const auto token_dims = num_recv_tokens.dimensions(); + const auto group_dims = local_group_sizes.dimensions(); + const auto dispatch_dims = x_dispatch->dimensions(); + const auto assignment_weight_dims = assignment_weights->dimensions(); + const auto recv_token_dims = recv_token_indices->dimensions(); + const auto cursor_dims = local_group_cursors->dimensions(); + const auto assignment_index_dims = recv_assignment_indices->dimensions(); + const auto assignment_destination_dims = assignment_destinations->dimensions(); + if (recv_x_dims.size() != 2 || topk_dims.size() != 2 || weights_dims.size() != 2 || + token_dims.size() != 1 || group_dims.size() != 1 || dispatch_dims.size() != 2 || + assignment_weight_dims.size() != 1 || recv_token_dims.size() != 1 || cursor_dims.size() != 1 || + assignment_index_dims.size() != 1 || assignment_destination_dims.size() != 1) { + return ffi::Error::InvalidArgument("DeepEP local assignment count-seeded pack expects rank-1/2 tensors"); + } + if (token_dims[0] != 1) { + return ffi::Error::InvalidArgument( + "DeepEP local assignment count-seeded pack expects num_recv_tokens shape [1]"); + } + const int recv_capacity = static_cast(recv_x_dims[0]); + const int hidden = static_cast(recv_x_dims[1]); + const int num_topk = static_cast(topk_dims[1]); + const int local_experts = static_cast(group_dims[0]); + const int total_assignments = recv_capacity * num_topk; + if (local_experts <= 0) { + return ffi::Error::InvalidArgument("DeepEP local assignment count-seeded pack needs local experts"); + } + if (topk_dims[0] != recv_capacity || weights_dims[0] != recv_capacity || weights_dims[1] != num_topk) { + return ffi::Error::InvalidArgument( + "DeepEP local assignment count-seeded pack recv top-k tensors must match recv_x"); + } + if (dispatch_dims[0] != total_assignments || dispatch_dims[1] != hidden || + assignment_weight_dims[0] != total_assignments || + recv_token_dims[0] != total_assignments || + cursor_dims[0] != local_experts || + assignment_index_dims[0] != total_assignments || + assignment_destination_dims[0] != total_assignments) { + return ffi::Error::InvalidArgument("DeepEP local assignment count-seeded pack output shapes are invalid"); + } + if (hidden <= 0 || (hidden * static_cast(ffi::ByteWidth(ffi::BF16))) % sizeof(int4) != 0) { + return ffi::Error::InvalidArgument( + "DeepEP local assignment count-seeded pack requires hidden*element_size divisible by int4"); + } + + LaunchPackLocalAssignmentsFromCounts( + reinterpret_cast(recv_x.typed_data()), + recv_topk_idx.typed_data(), + recv_topk_weights.typed_data(), + num_recv_tokens.typed_data(), + local_group_sizes.typed_data(), + reinterpret_cast(x_dispatch->typed_data()), + reinterpret_cast(assignment_weights->typed_data()), + recv_token_indices->typed_data(), + local_group_cursors->typed_data(), + recv_assignment_indices->typed_data(), + assignment_destinations->typed_data(), + recv_capacity, + hidden, + num_topk, + local_experts, + recv_capacity, + total_assignments, + stream); + return ffi::Error::Success(); + } catch (const std::exception& exc) { + return ffi::Error::Internal(exc.what()); + } +} + +ffi::Error CollapseLocalAssignments( + cudaStream_t stream, + ffi::Buffer out_dispatch, + ffi::Buffer assignment_weights, + ffi::Buffer assignment_destinations, + ffi::Buffer accepted_total_assignments, + ffi::Buffer num_recv_tokens, + ffi::Result> recv_out) { + try { + DeviceRuntime& runtime = RuntimeManager::Instance().RuntimeForCurrentDevice(); + const auto dispatch_dims = out_dispatch.dimensions(); + const auto weight_dims = assignment_weights.dimensions(); + const auto destination_dims = assignment_destinations.dimensions(); + const auto accepted_total_dims = accepted_total_assignments.dimensions(); + const auto token_dims = num_recv_tokens.dimensions(); + const auto out_dims = recv_out->dimensions(); + if (dispatch_dims.size() != 2 || weight_dims.size() != 1 || destination_dims.size() != 1 || + accepted_total_dims.size() != 1 || token_dims.size() != 1 || out_dims.size() != 2) { + return ffi::Error::InvalidArgument("DeepEP local assignment collapse expects rank-1/2 tensors"); + } + if (accepted_total_dims[0] != 1) { + return ffi::Error::InvalidArgument("DeepEP local assignment collapse expects accepted total shape [1]"); + } + if (token_dims[0] != 1) { + return ffi::Error::InvalidArgument("DeepEP local assignment collapse expects num_recv_tokens shape [1]"); + } + const int total_assignments = static_cast(dispatch_dims[0]); + const int hidden = static_cast(dispatch_dims[1]); + const int recv_capacity = static_cast(out_dims[0]); + if (weight_dims[0] != total_assignments) { + return ffi::Error::InvalidArgument("DeepEP local assignment collapse metadata must match out_dispatch"); + } + if (recv_capacity <= 0 || destination_dims[0] % recv_capacity != 0) { + return ffi::Error::InvalidArgument("DeepEP local assignment collapse destination map shape is invalid"); + } + if (out_dims[1] != hidden) { + return ffi::Error::InvalidArgument("DeepEP local assignment collapse output shapes are invalid"); + } + const int num_topk = static_cast(destination_dims[0]) / recv_capacity; + const int active_recv_tokens = ReadRecvCount( + runtime, + stream, + num_recv_tokens.typed_data(), + recv_capacity, + "cudaMemcpyAsync(read collapse num_recv_tokens)"); + if (active_recv_tokens < 0 || active_recv_tokens > recv_capacity) { + return ffi::Error::InvalidArgument("DeepEP local assignment collapse num_recv_tokens is out of range"); + } + + LaunchCollapseLocalAssignments( + reinterpret_cast(out_dispatch.typed_data()), + reinterpret_cast(assignment_weights.typed_data()), + assignment_destinations.typed_data(), + accepted_total_assignments.typed_data(), + num_recv_tokens.typed_data(), + reinterpret_cast(recv_out->typed_data()), + recv_capacity, + active_recv_tokens, + num_topk, + hidden, + stream); + return ffi::Error::Success(); + } catch (const std::exception& exc) { + return ffi::Error::Internal(exc.what()); + } +} + auto DispatchBinding() { return ffi::Ffi::Bind() .Ctx>() @@ -1089,6 +2070,35 @@ auto DispatchBinding() { .Ret>(); } +auto DispatchWithAssignmentsBinding() { + return ffi::Ffi::Bind() + .Ctx>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Attr("num_experts") + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>(); +} + auto DispatchCachedBinding() { return ffi::Ffi::Bind() .Ctx>() @@ -1118,6 +2128,50 @@ auto CombineBinding() { .Ret>(); } +auto PackLocalAssignmentsBinding() { + return ffi::Ffi::Bind() + .Ctx>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Attr("local_experts") + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>(); +} + +auto PackLocalAssignmentsFromCountsBinding() { + return ffi::Ffi::Bind() + .Ctx>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>() + .Ret>(); +} + +auto CollapseLocalAssignmentsBinding() { + return ffi::Ffi::Bind() + .Ctx>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Ret>(); +} + } // namespace extern "C" int levanter_deepep_init_intranode_runtime( @@ -1497,6 +2551,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( DispatchIntranode, DispatchBinding()); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + levanter_deepep_dispatch_intranode_with_assignments, + DispatchIntranodeWithAssignments, + DispatchWithAssignmentsBinding()); + XLA_FFI_DEFINE_HANDLER_SYMBOL( levanter_deepep_dispatch_intranode_cached, DispatchIntranodeCached, @@ -1506,3 +2565,18 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( levanter_deepep_combine_intranode, CombineIntranode, CombineBinding()); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + levanter_deepep_pack_local_assignments, + PackLocalAssignments, + PackLocalAssignmentsBinding()); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + levanter_deepep_pack_local_assignments_from_counts, + PackLocalAssignmentsFromCounts, + PackLocalAssignmentsFromCountsBinding()); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + levanter_deepep_collapse_local_assignments, + CollapseLocalAssignments, + CollapseLocalAssignmentsBinding()); diff --git a/lib/levanter/src/levanter/kernels/deepep/transport_ffi.py b/lib/levanter/src/levanter/kernels/deepep/transport_ffi.py index 740c2d6a44..f61d17b1de 100644 --- a/lib/levanter/src/levanter/kernels/deepep/transport_ffi.py +++ b/lib/levanter/src/levanter/kernels/deepep/transport_ffi.py @@ -18,6 +18,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path +from typing import NamedTuple import jax import jax.numpy as jnp @@ -38,8 +39,12 @@ ) _DISPATCH_TARGET = "levanter_deepep_dispatch_intranode" +_DISPATCH_WITH_ASSIGNMENTS_TARGET = "levanter_deepep_dispatch_intranode_with_assignments" _DISPATCH_CACHED_TARGET = "levanter_deepep_dispatch_intranode_cached" _COMBINE_TARGET = "levanter_deepep_combine_intranode" +_PACK_LOCAL_ASSIGNMENTS_TARGET = "levanter_deepep_pack_local_assignments" +_PACK_LOCAL_ASSIGNMENTS_FROM_COUNTS_TARGET = "levanter_deepep_pack_local_assignments_from_counts" +_COLLAPSE_LOCAL_ASSIGNMENTS_TARGET = "levanter_deepep_collapse_local_assignments" _INIT_SYMBOL = "levanter_deepep_init_intranode_runtime" _SHUTDOWN_SYMBOL = "levanter_deepep_shutdown_intranode_runtime" _LAST_ERROR_SYMBOL = "levanter_deepep_last_error" @@ -48,7 +53,7 @@ _EXTENDED_INTRNODE_DISPATCH_MACRO = "LEVANTER_DEEPEP_EXTENDED_INTRNODE_DISPATCH" _PYEXT_MODULE_NAME_MACRO = "LEVANTER_DEEPEP_PYEXT_MODULE_NAME" _DISPATCH_THREADS_ENV = "DEEPEP_DISPATCH_NUM_THREADS" -_BUILD_CACHE_SCHEMA_VERSION = "transport_ffi_raw_dlink_v18" +_BUILD_CACHE_SCHEMA_VERSION = "transport_ffi_raw_dlink_v27" _LIBRARY_DLOPEN_MODE = getattr(os, "RTLD_NOW", 0) | getattr(ctypes, "RTLD_GLOBAL", 0) _SM100_TMA_DISPATCH_THREADS = 512 _UPSTREAM_DISPATCH_THREADS = 768 @@ -67,15 +72,44 @@ class BuildArtifact: module_name: str | None +class DeepEPDispatch(NamedTuple): + recv_x: jax.Array + recv_topk_idx: jax.Array + recv_topk_weights: jax.Array + recv_src_idx: jax.Array + rank_prefix_matrix: jax.Array + channel_prefix_matrix: jax.Array + recv_channel_prefix_matrix: jax.Array + send_head: jax.Array + local_expert_counts: jax.Array + num_recv_tokens: jax.Array + + +class DeepEPDispatchWithAssignments(NamedTuple): + recv_x: jax.Array + recv_topk_weights: jax.Array + recv_src_idx: jax.Array + rank_prefix_matrix: jax.Array + channel_prefix_matrix: jax.Array + recv_channel_prefix_matrix: jax.Array + send_head: jax.Array + local_group_sizes: jax.Array + num_recv_tokens: jax.Array + x_dispatch: jax.Array + assignment_weights: jax.Array + recv_token_indices: jax.Array + assignment_destinations: jax.Array + + _DEFAULT_DISPATCH_CONFIGS = { 2: IntranodeConfig(num_sms=20, num_max_send_tokens=24, num_max_recv_tokens=256), - 4: IntranodeConfig(num_sms=20, num_max_send_tokens=6, num_max_recv_tokens=256), + 4: IntranodeConfig(num_sms=120, num_max_send_tokens=12, num_max_recv_tokens=256), 8: IntranodeConfig(num_sms=20, num_max_send_tokens=6, num_max_recv_tokens=256), } _DEFAULT_COMBINE_CONFIGS = { 2: IntranodeConfig(num_sms=20, num_max_send_tokens=10, num_max_recv_tokens=256), - 4: IntranodeConfig(num_sms=20, num_max_send_tokens=9, num_max_recv_tokens=256), + 4: IntranodeConfig(num_sms=120, num_max_send_tokens=18, num_max_recv_tokens=256), 8: IntranodeConfig(num_sms=20, num_max_send_tokens=4, num_max_recv_tokens=256), } @@ -187,6 +221,8 @@ def _dispatch_thread_override() -> int | None: def _intranode_source_bytes(deepep_root: Path) -> bytes: source = _intranode_source(deepep_root) text = source.read_text() + if "#include " not in text: + text = "#include \n" + text dispatch_threads = _dispatch_thread_override() if dispatch_threads is not None and dispatch_threads != _UPSTREAM_DISPATCH_THREADS: dispatch_start = text.find("\nvoid dispatch(") @@ -220,13 +256,135 @@ def _intranode_source_bytes(deepep_root: Path) -> bytes: raise RuntimeError("Could not patch DeepEP intranode TMA launch pattern for this source tree") text = text.replace(old, new, 1) + assignment_dispatch_threads = dispatch_threads or _UPSTREAM_DISPATCH_THREADS + text = _add_assignment_dispatch_source(text, dispatch_threads=assignment_dispatch_threads) return text.encode("utf-8") +def _add_assignment_dispatch_source(text: str, *, dispatch_threads: int = _UPSTREAM_DISPATCH_THREADS) -> str: + kernel_anchor = "__global__ void __launch_bounds__(kNumThreads, 1)\ndispatch(" + kernel_anchor_start = text.find(kernel_anchor) + kernel_start = text.rfind("\ntemplate", 0, kernel_anchor_start) + 1 + host_start = text.find("\nvoid dispatch(", kernel_start) + if kernel_anchor_start < 0 or kernel_start <= 0 or host_start < 0: + raise RuntimeError("Could not find DeepEP intranode dispatch kernel for assignment-native patch") + + assignment_kernel = text[kernel_start:host_start] + assignment_kernel = assignment_kernel.replace( + "dispatch(int4* recv_x, float* recv_x_scales, float* recv_x_sf_scale_for_nvfp4, int* recv_src_idx, " + "int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,", + "dispatch_assignments(int4* recv_x, float* recv_x_scales, float* recv_x_sf_scale_for_nvfp4, " + "int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,\n" + " int4* x_dispatch, nv_bfloat16* assignment_weights, int* recv_token_indices,\n" + " int* local_group_cursors, int* recv_assignment_indices, int* assignment_destinations,", + 1, + ) + + receiver_anchor = " // Workers for receiving and copying into buffer\n" + receiver_start = assignment_kernel.find(receiver_anchor) + receive_start = assignment_kernel.find(" // Copy data\n", receiver_start) + receive_end = assignment_kernel.find(" // Copy `x_scales`\n", receive_start) + if receiver_start < 0 or receive_start < 0 or receive_end < 0: + raise RuntimeError("Could not find DeepEP intranode receive loop for assignment-native patch") + assignment_receive = """ // Copy queue payloads directly into local-expert assignment order. + int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; + for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens; chunk_idx += num_recv_warps_per_rank) { + int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; + int recv_token_idx = total_offset + chunk_idx; + + if (lane_id == 0) + recv_src_idx[recv_token_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + token_idx_in_buffer); + + #pragma unroll + for (int token_topk_idx = 0; token_topk_idx < num_topk; ++ token_topk_idx) { + auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx; + auto recv_idx = static_cast(recv_token_idx) * num_topk + token_topk_idx; + int local_expert = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx); + float weight = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx); + if (lane_id == token_topk_idx) { + recv_topk_idx[recv_idx] = local_expert; + recv_topk_weights[recv_idx] = weight; + } + if (local_expert < 0) + continue; + + int destination = -1; + if (lane_id == 0) { + destination = atomicAdd(local_group_cursors + local_expert, 1); + int assignment_idx = recv_token_idx * num_topk + token_topk_idx; + recv_token_indices[destination] = recv_token_idx; + recv_assignment_indices[destination] = assignment_idx; + assignment_destinations[assignment_idx] = destination; + assignment_weights[destination] = __float2bfloat16(weight); + } + destination = __shfl_sync(0xffffffff, destination, 0); + auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4; + auto shifted_x_dispatch_int4 = x_dispatch + static_cast(destination) * hidden_int4; + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_dispatch_int4, shifted_buffer_x_int4, + ld_nc_global, st_na_global); + } + } + +""" + assignment_kernel = assignment_kernel[:receive_start] + assignment_receive + assignment_kernel[receive_end:] + + combine_start = text.find( + "\ntemplate", host_start + ) + if combine_start < 0: + raise RuntimeError("Could not find DeepEP intranode combine kernel insertion point") + assignment_host = """ +void dispatch_assignments(void* recv_x, float* recv_x_scales, float* recv_x_sf_scale_for_nvfp4, + int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, + int* recv_channel_offset, void* x_dispatch, nv_bfloat16* assignment_weights, + int* recv_token_indices, int* local_group_cursors, + int* recv_assignment_indices, int* assignment_destinations, int* send_head, + const void* x, const float* x_scales, const float* sf_scale_for_nvfp4, + const int64_t* topk_idx, const float* topk_weights, + const bool* is_token_in_rank, const int* channel_prefix_matrix, + int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, + int num_scales, int num_sf_scales_for_nvfp4, int scale_token_stride, + int scale_hidden_stride, int sf_scale_for_nvfp4_token_stride, + int sf_scale_for_nvfp4_hidden_stride, void** buffer_ptrs, int rank, int num_ranks, + cudaStream_t stream, int num_sms, int num_max_send_tokens, + int num_recv_buffer_tokens) { + constexpr int kNumThreads = __DISPATCH_THREADS__; + constexpr int kNumTMABytesPerWarp = 8192; +#ifndef DISABLE_SM90_FEATURES + constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32); +#endif + + EP_HOST_ASSERT(static_cast(num_scales) * scale_hidden_stride < std::numeric_limits::max()); + EP_HOST_ASSERT(static_cast(num_sf_scales_for_nvfp4) * sf_scale_for_nvfp4_hidden_stride < std::numeric_limits::max()); + +#define DISPATCH_ASSIGNMENTS_LAUNCH_CASE(ranks) { \\ + SET_SHARED_MEMORY_FOR_TMA((dispatch_assignments)); \\ + auto kernel = dispatch_assignments; \\ + LAUNCH_KERNEL(&cfg, kernel, \\ + reinterpret_cast(recv_x), recv_x_scales, recv_x_sf_scale_for_nvfp4, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \\ + reinterpret_cast(x_dispatch), assignment_weights, recv_token_indices, local_group_cursors, recv_assignment_indices, assignment_destinations, \\ + send_head, reinterpret_cast(x), x_scales, sf_scale_for_nvfp4, topk_idx, topk_weights, \\ + is_token_in_rank, channel_prefix_matrix, \\ + num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, num_sf_scales_for_nvfp4, \\ + scale_token_stride, scale_hidden_stride, sf_scale_for_nvfp4_token_stride, sf_scale_for_nvfp4_hidden_stride, \\ + buffer_ptrs, rank, \\ + num_max_send_tokens, num_recv_buffer_tokens); \\ + } \\ + break + + EP_HOST_ASSERT(num_sms % 2 == 0); + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + SWITCH_RANKS(DISPATCH_ASSIGNMENTS_LAUNCH_CASE); +#undef DISPATCH_ASSIGNMENTS_LAUNCH_CASE +} + +""".replace( + "__DISPATCH_THREADS__", str(dispatch_threads) + ) + return text[:combine_start] + "\n" + assignment_kernel + "\n" + assignment_host + text[combine_start:] + + def _prepare_intranode_source(build_dir: Path, deepep_root: Path) -> Path: - dispatch_threads = _dispatch_thread_override() - if dispatch_threads is None or dispatch_threads == _UPSTREAM_DISPATCH_THREADS: - return _intranode_source(deepep_root) patched_source = build_dir / "generated" / "intranode.cu" patched_source.parent.mkdir(parents=True, exist_ok=True) patched_source.write_bytes(_intranode_source_bytes(deepep_root)) @@ -658,7 +816,15 @@ def _register_targets() -> None: if getattr(_register_targets, "_done", False): return library = _load_library() - for target in (_DISPATCH_TARGET, _DISPATCH_CACHED_TARGET, _COMBINE_TARGET): + for target in ( + _DISPATCH_TARGET, + _DISPATCH_WITH_ASSIGNMENTS_TARGET, + _DISPATCH_CACHED_TARGET, + _COMBINE_TARGET, + _PACK_LOCAL_ASSIGNMENTS_TARGET, + _PACK_LOCAL_ASSIGNMENTS_FROM_COUNTS_TARGET, + _COLLAPSE_LOCAL_ASSIGNMENTS_TARGET, + ): handler = getattr(library, target) handler.restype = ctypes.c_void_p jax.ffi.register_ffi_target( @@ -916,6 +1082,125 @@ def _dispatch_intranode_impl( ) +def _dispatch_intranode_with_assignments_impl( + x: jax.Array, + topk_idx: jax.Array, + topk_weights: jax.Array, + num_tokens_per_rank: jax.Array, + num_tokens_per_expert: jax.Array, + is_token_in_rank: jax.Array, + *, + num_experts: int, + dispatch_config: IntranodeConfig | None, + combine_config: IntranodeConfig | None, + max_recv_tokens: int | None, +) -> tuple[ + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, +]: + _register_targets() + num_ranks = int(num_tokens_per_rank.shape[0]) + resolved_dispatch_config = _resolve_runtime( + x=x, + num_ranks=num_ranks, + dispatch_config=dispatch_config, + combine_config=combine_config, + ) + + x_bf16 = jnp.asarray(x, dtype=jnp.bfloat16) + topk_idx_i32 = jnp.asarray(topk_idx, dtype=jnp.int32) + topk_weights_f32 = jnp.asarray(topk_weights, dtype=jnp.float32) + num_tokens_per_rank_i32 = jnp.asarray(num_tokens_per_rank, dtype=jnp.int32) + num_tokens_per_expert_i32 = jnp.asarray(num_tokens_per_expert, dtype=jnp.int32) + local_experts = num_experts // num_ranks + if max_recv_tokens is None: + max_recv_tokens = x_bf16.shape[0] * num_ranks + elif max_recv_tokens <= 0: + raise ValueError(f"max_recv_tokens must be positive, got {max_recv_tokens}") + num_channels = resolved_dispatch_config.num_sms // 2 + topk = topk_idx_i32.shape[1] + max_assignments = max_recv_tokens * topk + result_shape_dtypes = ( + jax.ShapeDtypeStruct((max_recv_tokens, x_bf16.shape[1]), x_bf16.dtype), + jax.ShapeDtypeStruct((max_recv_tokens, topk), jnp.float32), + jax.ShapeDtypeStruct((max_recv_tokens,), jnp.int32), + jax.ShapeDtypeStruct((num_ranks, num_ranks), jnp.int32), + jax.ShapeDtypeStruct((num_ranks, num_channels), jnp.int32), + jax.ShapeDtypeStruct((num_ranks, num_channels), jnp.int32), + jax.ShapeDtypeStruct((x_bf16.shape[0], num_ranks), jnp.int32), + jax.ShapeDtypeStruct((local_experts,), jnp.int32), + jax.ShapeDtypeStruct((1,), jnp.int32), + jax.ShapeDtypeStruct((x_bf16.shape[0], topk * 2), jnp.int32), + jax.ShapeDtypeStruct((max_recv_tokens, topk * 2), jnp.int32), + jax.ShapeDtypeStruct((max_assignments, x_bf16.shape[1]), x_bf16.dtype), + jax.ShapeDtypeStruct((max_assignments,), x_bf16.dtype), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + jax.ShapeDtypeStruct((local_experts,), jnp.int32), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + ) + ( + recv_x, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + local_group_sizes, + num_recv_tokens, + _topk_idx_s64_scratch, + _recv_topk_idx_s64_scratch, + x_dispatch, + assignment_weights, + recv_token_indices, + _local_group_cursors, + recv_assignment_indices, + assignment_destinations, + ) = jax.ffi.ffi_call( + _DISPATCH_WITH_ASSIGNMENTS_TARGET, + result_shape_dtypes, + has_side_effect=True, + vmap_method="broadcast_all", + )( + x_bf16, + topk_idx_i32, + topk_weights_f32, + num_tokens_per_rank_i32, + num_tokens_per_expert_i32, + is_token_in_rank, + num_experts=np.int32(num_experts), + ) + return ( + recv_x, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + local_group_sizes, + num_recv_tokens, + x_dispatch, + assignment_weights, + recv_token_indices, + recv_assignment_indices, + assignment_destinations, + ) + + def _dispatch_intranode_cached_impl( x: jax.Array, is_token_in_rank: jax.Array, @@ -1007,6 +1292,419 @@ def _combine_intranode_impl( return combined_x, combined_topk_weights +def deepep_pack_local_assignments( + recv_x: jax.Array, + recv_topk_idx: jax.Array, + recv_topk_weights: jax.Array, + num_recv_tokens: jax.Array, + *, + local_experts: int, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + x_dispatch, assignment_weights, recv_token_indices, local_group_sizes, _ = _pack_local_assignments_with_vjp( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens, + local_experts, + ) + return x_dispatch, assignment_weights, recv_token_indices, local_group_sizes + + +def deepep_pack_local_assignments_from_counts( + recv_x: jax.Array, + recv_topk_idx: jax.Array, + recv_topk_weights: jax.Array, + num_recv_tokens: jax.Array, + local_group_sizes: jax.Array, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + x_dispatch, assignment_weights, recv_token_indices, recv_assignment_indices, assignment_destinations = ( + _pack_local_assignments_from_counts_with_vjp( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens, + local_group_sizes, + ) + ) + del recv_assignment_indices + return x_dispatch, assignment_weights, recv_token_indices, local_group_sizes, assignment_destinations + + +def _pack_local_assignments_impl( + recv_x: jax.Array, + recv_topk_idx: jax.Array, + recv_topk_weights: jax.Array, + num_recv_tokens: jax.Array, + *, + local_experts: int, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + _register_targets() + if local_experts <= 0: + raise ValueError(f"local_experts must be positive, got {local_experts}") + recv_x_bf16 = jnp.asarray(recv_x, dtype=jnp.bfloat16) + recv_topk_idx_i32 = jnp.asarray(recv_topk_idx, dtype=jnp.int32) + recv_topk_weights_f32 = jnp.asarray(recv_topk_weights, dtype=jnp.float32) + num_recv_tokens_i32 = jnp.asarray(num_recv_tokens, dtype=jnp.int32) + if num_recv_tokens_i32.ndim == 0: + num_recv_tokens_i32 = jnp.reshape(num_recv_tokens_i32, (1,)) + + recv_capacity, hidden = recv_x_bf16.shape + topk = recv_topk_idx_i32.shape[1] + max_assignments = recv_capacity * topk + result_shape_dtypes = ( + jax.ShapeDtypeStruct((max_assignments, hidden), recv_x_bf16.dtype), + jax.ShapeDtypeStruct((max_assignments,), recv_x_bf16.dtype), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + jax.ShapeDtypeStruct((local_experts,), jnp.int32), + jax.ShapeDtypeStruct((local_experts,), jnp.int32), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + ) + ( + x_dispatch, + assignment_weights, + recv_token_indices, + local_group_sizes, + _, + recv_assignment_indices, + _, + ) = jax.ffi.ffi_call( + _PACK_LOCAL_ASSIGNMENTS_TARGET, + result_shape_dtypes, + has_side_effect=True, + vmap_method="broadcast_all", + )( + recv_x_bf16, + recv_topk_idx_i32, + recv_topk_weights_f32, + num_recv_tokens_i32, + local_experts=np.int32(local_experts), + ) + return x_dispatch, assignment_weights, recv_token_indices, local_group_sizes, recv_assignment_indices + + +def _pack_local_assignments_from_counts_impl( + recv_x: jax.Array, + recv_topk_idx: jax.Array, + recv_topk_weights: jax.Array, + num_recv_tokens: jax.Array, + local_group_sizes: jax.Array, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + _register_targets() + recv_x_bf16 = jnp.asarray(recv_x, dtype=jnp.bfloat16) + recv_topk_idx_i32 = jnp.asarray(recv_topk_idx, dtype=jnp.int32) + recv_topk_weights_f32 = jnp.asarray(recv_topk_weights, dtype=jnp.float32) + num_recv_tokens_i32 = jnp.asarray(num_recv_tokens, dtype=jnp.int32) + local_group_sizes_i32 = jnp.asarray(local_group_sizes, dtype=jnp.int32) + if num_recv_tokens_i32.ndim == 0: + num_recv_tokens_i32 = jnp.reshape(num_recv_tokens_i32, (1,)) + + recv_capacity, hidden = recv_x_bf16.shape + topk = recv_topk_idx_i32.shape[1] + max_assignments = recv_capacity * topk + result_shape_dtypes = ( + jax.ShapeDtypeStruct((max_assignments, hidden), recv_x_bf16.dtype), + jax.ShapeDtypeStruct((max_assignments,), recv_x_bf16.dtype), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + jax.ShapeDtypeStruct((local_group_sizes_i32.shape[0],), jnp.int32), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + jax.ShapeDtypeStruct((max_assignments,), jnp.int32), + ) + ( + x_dispatch, + assignment_weights, + recv_token_indices, + _, + recv_assignment_indices, + assignment_destinations, + ) = jax.ffi.ffi_call( + _PACK_LOCAL_ASSIGNMENTS_FROM_COUNTS_TARGET, + result_shape_dtypes, + has_side_effect=True, + vmap_method="broadcast_all", + )( + recv_x_bf16, + recv_topk_idx_i32, + recv_topk_weights_f32, + num_recv_tokens_i32, + local_group_sizes_i32, + ) + return x_dispatch, assignment_weights, recv_token_indices, recv_assignment_indices, assignment_destinations + + +@partial(jax.custom_vjp, nondiff_argnums=(4,)) +def _pack_local_assignments_with_vjp( + recv_x: jax.Array, + recv_topk_idx: jax.Array, + recv_topk_weights: jax.Array, + num_recv_tokens: jax.Array, + local_experts: int, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + return _pack_local_assignments_impl( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens, + local_experts=local_experts, + ) + + +def _pack_local_assignments_with_vjp_fwd( + recv_x: jax.Array, + recv_topk_idx: jax.Array, + recv_topk_weights: jax.Array, + num_recv_tokens: jax.Array, + local_experts: int, +): + outputs = _pack_local_assignments_impl( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens, + local_experts=local_experts, + ) + residuals = (outputs[2], outputs[3], outputs[4], recv_x.shape, recv_topk_weights.shape) + return outputs, residuals + + +def _pack_assignment_gradients( + *, + recv_token_indices: jax.Array, + recv_assignment_indices: jax.Array, + local_group_sizes: jax.Array, + recv_x_shape: tuple[int, ...], + recv_topk_weights_shape: tuple[int, ...], + cotangents, +) -> tuple[jax.Array, jax.Array]: + valid_assignments = jnp.arange(recv_assignment_indices.shape[0], dtype=jnp.int32) < jnp.sum(local_group_sizes) + safe_recv_token_indices = jnp.where(valid_assignments, recv_token_indices, 0) + safe_recv_assignment_indices = jnp.where(valid_assignments, recv_assignment_indices, 0) + grad_x_dispatch = _materialize_cotangent( + cotangents[0], + dtype=jnp.bfloat16, + shape=(recv_assignment_indices.shape[0], recv_x_shape[1]), + ) + grad_x_dispatch = jnp.where(valid_assignments[:, None], grad_x_dispatch, 0) + grad_assignment_weights = _materialize_cotangent( + cotangents[1], + dtype=jnp.float32, + shape=(recv_assignment_indices.shape[0],), + ) + grad_assignment_weights = jnp.where(valid_assignments, grad_assignment_weights, 0) + grad_recv_x = jax.ops.segment_sum( + grad_x_dispatch, + safe_recv_token_indices, + num_segments=recv_x_shape[0], + indices_are_sorted=False, + ) + grad_recv_topk_weights = jax.ops.segment_sum( + grad_assignment_weights.astype(jnp.float32), + safe_recv_assignment_indices, + num_segments=recv_topk_weights_shape[0] * recv_topk_weights_shape[1], + indices_are_sorted=False, + ).reshape(recv_topk_weights_shape) + return grad_recv_x, grad_recv_topk_weights + + +def _pack_local_assignments_with_vjp_bwd(local_experts: int, residuals, cotangents): + del local_experts + recv_token_indices, local_group_sizes, recv_assignment_indices, recv_x_shape, recv_topk_weights_shape = residuals + grad_recv_x, grad_recv_topk_weights = _pack_assignment_gradients( + recv_token_indices=recv_token_indices, + recv_assignment_indices=recv_assignment_indices, + local_group_sizes=local_group_sizes, + recv_x_shape=recv_x_shape, + recv_topk_weights_shape=recv_topk_weights_shape, + cotangents=cotangents, + ) + return grad_recv_x, None, grad_recv_topk_weights, None + + +_pack_local_assignments_with_vjp.defvjp( + _pack_local_assignments_with_vjp_fwd, + _pack_local_assignments_with_vjp_bwd, +) + + +@jax.custom_vjp +def _pack_local_assignments_from_counts_with_vjp( + recv_x: jax.Array, + recv_topk_idx: jax.Array, + recv_topk_weights: jax.Array, + num_recv_tokens: jax.Array, + local_group_sizes: jax.Array, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + return _pack_local_assignments_from_counts_impl( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens, + local_group_sizes, + ) + + +def _pack_local_assignments_from_counts_with_vjp_fwd( + recv_x: jax.Array, + recv_topk_idx: jax.Array, + recv_topk_weights: jax.Array, + num_recv_tokens: jax.Array, + local_group_sizes: jax.Array, +): + outputs = _pack_local_assignments_from_counts_impl( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens, + local_group_sizes, + ) + residuals = (outputs[2], outputs[3], local_group_sizes, recv_x.shape, recv_topk_weights.shape) + return outputs, residuals + + +def _pack_local_assignments_from_counts_with_vjp_bwd(residuals, cotangents): + recv_token_indices, recv_assignment_indices, local_group_sizes, recv_x_shape, recv_topk_weights_shape = residuals + grad_recv_x, grad_recv_topk_weights = _pack_assignment_gradients( + recv_token_indices=recv_token_indices, + recv_assignment_indices=recv_assignment_indices, + local_group_sizes=local_group_sizes, + recv_x_shape=recv_x_shape, + recv_topk_weights_shape=recv_topk_weights_shape, + cotangents=cotangents, + ) + return grad_recv_x, None, grad_recv_topk_weights, None, None + + +_pack_local_assignments_from_counts_with_vjp.defvjp( + _pack_local_assignments_from_counts_with_vjp_fwd, + _pack_local_assignments_from_counts_with_vjp_bwd, +) + + +def deepep_collapse_local_assignments( + out_dispatch: jax.Array, + assignment_weights: jax.Array, + recv_token_indices: jax.Array, + assignment_destinations: jax.Array, + local_group_sizes: jax.Array, + num_recv_tokens: jax.Array, + *, + recv_capacity: int, +) -> jax.Array: + return _collapse_local_assignments_with_vjp( + out_dispatch, + assignment_weights, + recv_token_indices, + assignment_destinations, + local_group_sizes, + num_recv_tokens, + recv_capacity, + ) + + +def _collapse_local_assignments_impl( + out_dispatch: jax.Array, + assignment_weights: jax.Array, + recv_token_indices: jax.Array, + assignment_destinations: jax.Array, + local_group_sizes: jax.Array, + num_recv_tokens: jax.Array, + *, + recv_capacity: int, +) -> jax.Array: + _register_targets() + if recv_capacity <= 0: + raise ValueError(f"recv_capacity must be positive, got {recv_capacity}") + out_dispatch_bf16 = jnp.asarray(out_dispatch, dtype=jnp.bfloat16) + assignment_weights_bf16 = jnp.asarray(assignment_weights, dtype=jnp.bfloat16) + del recv_token_indices + assignment_destinations_i32 = jnp.asarray(assignment_destinations, dtype=jnp.int32) + local_group_sizes_i32 = jnp.asarray(local_group_sizes, dtype=jnp.int32) + accepted_total_i32 = jnp.reshape(jnp.sum(local_group_sizes_i32, dtype=jnp.int32), (1,)) + num_recv_tokens_i32 = jnp.asarray(num_recv_tokens, dtype=jnp.int32) + if num_recv_tokens_i32.ndim == 0: + num_recv_tokens_i32 = jnp.reshape(num_recv_tokens_i32, (1,)) + + result_shape_dtype = jax.ShapeDtypeStruct((recv_capacity, out_dispatch_bf16.shape[1]), out_dispatch_bf16.dtype) + recv_out = jax.ffi.ffi_call( + _COLLAPSE_LOCAL_ASSIGNMENTS_TARGET, + result_shape_dtype, + has_side_effect=True, + vmap_method="broadcast_all", + )( + out_dispatch_bf16, + assignment_weights_bf16, + assignment_destinations_i32, + accepted_total_i32, + num_recv_tokens_i32, + ) + return recv_out + + +@partial(jax.custom_vjp, nondiff_argnums=(6,)) +def _collapse_local_assignments_with_vjp( + out_dispatch: jax.Array, + assignment_weights: jax.Array, + recv_token_indices: jax.Array, + assignment_destinations: jax.Array, + local_group_sizes: jax.Array, + num_recv_tokens: jax.Array, + recv_capacity: int, +) -> jax.Array: + return _collapse_local_assignments_impl( + out_dispatch, + assignment_weights, + recv_token_indices, + assignment_destinations, + local_group_sizes, + num_recv_tokens, + recv_capacity=recv_capacity, + ) + + +def _collapse_local_assignments_with_vjp_fwd( + out_dispatch: jax.Array, + assignment_weights: jax.Array, + recv_token_indices: jax.Array, + assignment_destinations: jax.Array, + local_group_sizes: jax.Array, + num_recv_tokens: jax.Array, + recv_capacity: int, +): + output = _collapse_local_assignments_impl( + out_dispatch, + assignment_weights, + recv_token_indices, + assignment_destinations, + local_group_sizes, + num_recv_tokens, + recv_capacity=recv_capacity, + ) + return output, (out_dispatch, assignment_weights, recv_token_indices, local_group_sizes) + + +def _collapse_local_assignments_with_vjp_bwd(recv_capacity: int, residuals, cotangent): + out_dispatch, assignment_weights, recv_token_indices, local_group_sizes = residuals + valid_assignments = jnp.arange(assignment_weights.shape[0], dtype=jnp.int32) < jnp.sum(local_group_sizes) + safe_recv_token_indices = jnp.where(valid_assignments, recv_token_indices, 0) + grad_recv_out = _materialize_cotangent( + cotangent, + dtype=jnp.bfloat16, + shape=(recv_capacity, out_dispatch.shape[1]), + ) + gathered_grad = jnp.take(grad_recv_out, safe_recv_token_indices, axis=0) + gathered_grad = jnp.where(valid_assignments[:, None], gathered_grad, 0) + out_dispatch = jnp.where(valid_assignments[:, None], out_dispatch, 0) + grad_out_dispatch = gathered_grad * assignment_weights[:, None].astype(gathered_grad.dtype) + grad_assignment_weights = jnp.sum(gathered_grad.astype(jnp.float32) * out_dispatch.astype(jnp.float32), axis=1) + return grad_out_dispatch, grad_assignment_weights, None, None, None, None + + +_collapse_local_assignments_with_vjp.defvjp( + _collapse_local_assignments_with_vjp_fwd, + _collapse_local_assignments_with_vjp_bwd, +) + + @partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9)) def _dispatch_intranode_with_vjp( x: jax.Array, @@ -1062,7 +1760,7 @@ def _dispatch_intranode_with_vjp_fwd( ) ( recv_x, - _, + _recv_topk_idx, recv_topk_weights, recv_src_idx, rank_prefix_matrix, @@ -1125,6 +1823,177 @@ def _dispatch_intranode_with_vjp_bwd( ) +@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9)) +def _dispatch_intranode_with_assignments_with_vjp( + x: jax.Array, + topk_idx: jax.Array, + topk_weights: jax.Array, + num_tokens_per_rank: jax.Array, + num_tokens_per_expert: jax.Array, + is_token_in_rank: jax.Array, + num_experts: int, + dispatch_config: IntranodeConfig | None, + combine_config: IntranodeConfig | None, + max_recv_tokens: int | None, +) -> tuple[ + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, + jax.Array, +]: + return _dispatch_intranode_with_assignments_impl( + x, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_expert, + is_token_in_rank, + num_experts=num_experts, + dispatch_config=dispatch_config, + combine_config=combine_config, + max_recv_tokens=max_recv_tokens, + ) + + +def _dispatch_intranode_with_assignments_with_vjp_fwd( + x: jax.Array, + topk_idx: jax.Array, + topk_weights: jax.Array, + num_tokens_per_rank: jax.Array, + num_tokens_per_expert: jax.Array, + is_token_in_rank: jax.Array, + num_experts: int, + dispatch_config: IntranodeConfig | None, + combine_config: IntranodeConfig | None, + max_recv_tokens: int | None, +): + outputs = _dispatch_intranode_with_assignments_impl( + x, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_expert, + is_token_in_rank, + num_experts=num_experts, + dispatch_config=dispatch_config, + combine_config=combine_config, + max_recv_tokens=max_recv_tokens, + ) + ( + recv_x, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + _, + recv_channel_prefix_matrix, + send_head, + local_group_sizes, + num_recv_tokens, + _, + _, + recv_token_indices, + recv_assignment_indices, + _assignment_destinations, + ) = outputs + residuals = ( + recv_x, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + local_group_sizes, + num_recv_tokens, + recv_token_indices, + recv_assignment_indices, + ) + return outputs, residuals + + +def _dispatch_intranode_with_assignments_with_vjp_bwd( + num_experts: int, + dispatch_config: IntranodeConfig | None, + combine_config: IntranodeConfig | None, + max_recv_tokens: int | None, + residuals, + cotangents, +): + del num_experts, dispatch_config, combine_config, max_recv_tokens + ( + recv_x, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + local_group_sizes, + num_recv_tokens, + recv_token_indices, + recv_assignment_indices, + ) = residuals + valid_assignments = jnp.arange(recv_assignment_indices.shape[0], dtype=jnp.int32) < jnp.sum(local_group_sizes) + + grad_recv_x = _materialize_cotangent(cotangents[0], dtype=recv_x.dtype, reference=recv_x) + grad_recv_topk_weights = _materialize_cotangent( + cotangents[1], + dtype=recv_topk_weights.dtype, + reference=recv_topk_weights, + ) + + grad_x_dispatch = _materialize_cotangent( + cotangents[9], + dtype=recv_x.dtype, + shape=(recv_assignment_indices.shape[0], recv_x.shape[1]), + ) + grad_x_dispatch = jnp.where(valid_assignments[:, None], grad_x_dispatch, 0) + grad_assignment_weights = _materialize_cotangent( + cotangents[10], + dtype=jnp.float32, + shape=(recv_assignment_indices.shape[0],), + ) + grad_assignment_weights = jnp.where(valid_assignments, grad_assignment_weights, 0) + + grad_recv_x += jax.ops.segment_sum( + grad_x_dispatch, + recv_token_indices, + num_segments=recv_x.shape[0], + indices_are_sorted=False, + ) + grad_recv_topk_weights += jax.ops.segment_sum( + grad_assignment_weights.astype(jnp.float32), + recv_assignment_indices, + num_segments=recv_topk_weights.shape[0] * recv_topk_weights.shape[1], + indices_are_sorted=False, + ).reshape(recv_topk_weights.shape) + + grad_x, grad_topk_weights = _combine_intranode_impl( + grad_recv_x, + grad_recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + num_recv_tokens, + ) + return grad_x, None, grad_topk_weights, None, None, None + + +_dispatch_intranode_with_assignments_with_vjp.defvjp( + _dispatch_intranode_with_assignments_with_vjp_fwd, + _dispatch_intranode_with_assignments_with_vjp_bwd, +) + + @jax.custom_vjp def _combine_intranode_with_vjp( recv_x: jax.Array, @@ -1236,10 +2105,52 @@ def deepep_dispatch_intranode( dispatch_config: IntranodeConfig | None = None, combine_config: IntranodeConfig | None = None, max_recv_tokens: int | None = None, -) -> tuple[ - jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array -]: - return _dispatch_intranode_with_vjp( +) -> DeepEPDispatch: + return DeepEPDispatch( + *_dispatch_intranode_with_vjp( + x, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_expert, + is_token_in_rank, + num_experts, + dispatch_config, + combine_config, + max_recv_tokens, + ) + ) + + +def deepep_dispatch_intranode_with_assignments( + x: jax.Array, + topk_idx: jax.Array, + topk_weights: jax.Array, + num_tokens_per_rank: jax.Array, + num_tokens_per_expert: jax.Array, + is_token_in_rank: jax.Array, + *, + num_experts: int, + dispatch_config: IntranodeConfig | None = None, + combine_config: IntranodeConfig | None = None, + max_recv_tokens: int | None = None, +) -> DeepEPDispatchWithAssignments: + ( + recv_x, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + local_group_sizes, + num_recv_tokens, + x_dispatch, + assignment_weights, + recv_token_indices, + _recv_assignment_indices, + assignment_destinations, + ) = _dispatch_intranode_with_assignments_with_vjp( x, topk_idx, topk_weights, @@ -1251,6 +2162,21 @@ def deepep_dispatch_intranode( combine_config, max_recv_tokens, ) + return DeepEPDispatchWithAssignments( + recv_x=recv_x, + recv_topk_weights=recv_topk_weights, + recv_src_idx=recv_src_idx, + rank_prefix_matrix=rank_prefix_matrix, + channel_prefix_matrix=channel_prefix_matrix, + recv_channel_prefix_matrix=recv_channel_prefix_matrix, + send_head=send_head, + local_group_sizes=local_group_sizes, + num_recv_tokens=num_recv_tokens, + x_dispatch=x_dispatch, + assignment_weights=assignment_weights, + recv_token_indices=recv_token_indices, + assignment_destinations=assignment_destinations, + ) def deepep_combine_intranode( diff --git a/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/tuned_block_sizes.py b/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/tuned_block_sizes.py index e894625dea..baf290efc0 100644 --- a/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/tuned_block_sizes.py +++ b/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/tuned_block_sizes.py @@ -50,8 +50,10 @@ def matches(self, b: int, h: int, v: int) -> bool: ("float32", "llama3-ish"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=1024), ("bfloat16", "large-batch-small-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), ("float32", "large-batch-small-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), - ("bfloat16", "medium-batch-medium-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), - ("float32", "medium-batch-medium-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), + ("bfloat16", "medium-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("float32", "medium-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("bfloat16", "large-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("float32", "large-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), }, "NVIDIA GB10": { ("bfloat16", "tiny"): BlockSizes(b_block_size=128, h_block_size=64, v_block_size=128), @@ -74,8 +76,10 @@ def matches(self, b: int, h: int, v: int) -> bool: ("float32", "llama3-ish"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=1024), ("bfloat16", "large-batch-small-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), ("float32", "large-batch-small-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), - ("bfloat16", "medium-batch-medium-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), - ("float32", "medium-batch-medium-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), + ("bfloat16", "medium-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("float32", "medium-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("bfloat16", "large-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("float32", "large-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), }, "NVIDIA A100": { ("bfloat16", "tiny"): BlockSizes(b_block_size=128, h_block_size=64, v_block_size=128), @@ -86,8 +90,10 @@ def matches(self, b: int, h: int, v: int) -> bool: ("float32", "llama3-ish"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=1024), ("bfloat16", "large-batch-small-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), ("float32", "large-batch-small-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), - ("bfloat16", "medium-batch-medium-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), - ("float32", "medium-batch-medium-h"): BlockSizes(b_block_size=256, h_block_size=256, v_block_size=2048), + ("bfloat16", "medium-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("float32", "medium-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("bfloat16", "large-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), + ("float32", "large-batch-medium-h"): BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256), }, "TPU v5e": { ("bfloat16", "small-vocab"): BlockSizes(b_block_size=1024, h_block_size=256, v_block_size=1024), diff --git a/lib/levanter/src/levanter/main/train_lm.py b/lib/levanter/src/levanter/main/train_lm.py index 8ffd1fe2fa..3e2b4c99a8 100644 --- a/lib/levanter/src/levanter/main/train_lm.py +++ b/lib/levanter/src/levanter/main/train_lm.py @@ -4,6 +4,7 @@ import dataclasses import gc import logging +import os from dataclasses import dataclass, field from typing import Optional diff --git a/lib/levanter/src/levanter/tokenizers.py b/lib/levanter/src/levanter/tokenizers.py index fd4cb2c357..70e327f206 100644 --- a/lib/levanter/src/levanter/tokenizers.py +++ b/lib/levanter/src/levanter/tokenizers.py @@ -25,7 +25,7 @@ import tempfile import time from enum import StrEnum -from typing import Any, Protocol, runtime_checkable +from typing import Any, Protocol, cast, runtime_checkable import fsspec import jinja2 @@ -198,8 +198,9 @@ def _make_jinja_env(extensions: list[type]) -> jinja2.Environment: lstrip_blocks=True, extensions=extensions, ) - env.globals["raise_exception"] = _raise_chat_template_exception - env.globals["strftime_now"] = lambda fmt: time.strftime(fmt) + env_globals = cast(dict[str, Any], env.globals) + env_globals["raise_exception"] = _raise_chat_template_exception + env_globals["strftime_now"] = lambda fmt: time.strftime(fmt) return env diff --git a/lib/levanter/tests/grug/test_attention.py b/lib/levanter/tests/grug/test_attention.py index 152a1dc941..cf1f524ee2 100644 --- a/lib/levanter/tests/grug/test_attention.py +++ b/lib/levanter/tests/grug/test_attention.py @@ -1,6 +1,8 @@ # Copyright The Levanter Authors # SPDX-License-Identifier: Apache-2.0 +from types import SimpleNamespace + import equinox as eqx import jax import jax.numpy as jnp @@ -111,12 +113,16 @@ def build_mask(q_ids, kv_ids): def test_gpu_fa4_thd_registered_backend_jits_with_cutlass_boundary(monkeypatch): - def fake_fwd(q, k, v, cu_seqlens, *, softmax_scale, kernel_config): + seen_sliding_windows = [] + + def fake_fwd(q, k, v, cu_seqlens, *, softmax_scale, kernel_config, sliding_window): del k, v, cu_seqlens, softmax_scale, kernel_config + seen_sliding_windows.append(sliding_window) return q * jnp.asarray(2, dtype=q.dtype), jnp.zeros((q.shape[1], q.shape[0]), dtype=jnp.float32) - def fake_bwd(q, k, v, out, dout, lse, cu_seqlens, *, softmax_scale, kernel_config): + def fake_bwd(q, k, v, out, dout, lse, cu_seqlens, *, softmax_scale, kernel_config, sliding_window): del out, lse, cu_seqlens, softmax_scale, kernel_config + seen_sliding_windows.append(sliding_window) return ( dout * jnp.asarray(2, dtype=dout.dtype), jnp.zeros_like(k), @@ -140,13 +146,57 @@ def fake_bwd(q, k, v, out, dout, lse, cu_seqlens, *, softmax_scale, kernel_confi k = jnp.ones((2, 4, 1, 8), dtype=jnp.float32) v = jnp.ones((2, 4, 1, 8), dtype=jnp.float32) segment_ids = jnp.array([[0, 0, 1, 1], [2, 2, 3, 3]], dtype=jnp.int32) - mask = AttentionMask.causal().with_segment_ids(segment_ids, max_segments=2) + mask = AttentionMask.causal(sliding_window=3).with_segment_ids(segment_ids, max_segments=2) out = jax.jit(lambda q_arg: attention(q_arg, k, v, mask, implementation="gpu_fa4_thd"))(q) np.testing.assert_array_equal(out, jnp.full_like(q, 2)) grad = jax.jit(jax.grad(lambda q_arg: jnp.sum(attention(q_arg, k, v, mask, implementation="gpu_fa4_thd"))))(q) np.testing.assert_array_equal(grad, jnp.full_like(q, 2)) + assert seen_sliding_windows + assert all(sliding_window == 3 for sliding_window in seen_sliding_windows) + + +def test_gpu_fa4_thd_forward_launcher_threads_local_window_arguments(): + calls = {} + + class FakeFlashForward: + def __init__(self, *args, **kwargs): + calls["init_args"] = args + calls["init_kwargs"] = kwargs + + def __call__(self, *args): + calls["call_args"] = args + + modules = SimpleNamespace( + cutlass=SimpleNamespace(Float32=float), + cute=SimpleNamespace(Tensor=object, jit=lambda fn: fn), + cuda=SimpleNamespace(CUstream=object), + FlashAttentionForwardSm100=FakeFlashForward, + ) + + launcher = fa4_thd._upstream_fa4_thd_forward_launcher( + modules, + head_dim=128, + head_dim_v=128, + qhead_per_kvhead=4, + kernel_config=fa4_thd.Flash4CuteKernelConfig( + forward_tile=(128, 128), + backward_tile=(128, 128), + num_threads=384, + ), + sliding_window=2048, + ) + launcher("stream", "q", "k", "v", "cu", "out", "lse", softmax_scale=1.0) + + assert calls["init_kwargs"]["is_causal"] is False + assert calls["init_kwargs"]["is_local"] is True + call_args = calls["call_args"] + assert len(call_args) == 18 + assert call_args[10] is None + assert call_args[11] == 2047 + assert call_args[12] == 0 + assert call_args[17] == "stream" def test_gpu_fa4_thd_rejects_mha_before_kernel_config(monkeypatch): diff --git a/lib/levanter/tests/grug/test_fa4_cute_attention.py b/lib/levanter/tests/grug/test_fa4_cute_attention.py index e57fb25eea..df0e769539 100644 --- a/lib/levanter/tests/grug/test_fa4_cute_attention.py +++ b/lib/levanter/tests/grug/test_fa4_cute_attention.py @@ -11,6 +11,7 @@ gpu_fa4_cute_attention, reference_attention, ) +from levanter.grug.attention._fa4_cute import _simple_causal_lower_bounds def _make_qkv(*, batch: int = 2, q_len: int = 6, k_len: int = 6, q_heads: int = 4, kv_heads: int = 2): @@ -37,6 +38,29 @@ def test_fa4_frontend_rejects_mismatched_q_kv_segment_ids(): jax.block_until_ready(gpu_fa4_cute_attention(q, k, v, mask)) +def test_simple_causal_lower_bounds_match_sliding_window_semantics(): + lower_bounds, valid = _simple_causal_lower_bounds(batch_size=2, seq_len=6, sliding_window=3) + + np.testing.assert_array_equal( + lower_bounds, + np.array( + [ + [0, 0, 0, 1, 2, 3], + [0, 0, 0, 1, 2, 3], + ], + dtype=np.int32, + ), + ) + np.testing.assert_array_equal(valid, np.ones((2, 6), dtype=np.bool_)) + + +def test_simple_causal_lower_bounds_match_full_causal_semantics(): + lower_bounds, valid = _simple_causal_lower_bounds(batch_size=2, seq_len=4, sliding_window=None) + + np.testing.assert_array_equal(lower_bounds, np.zeros((2, 4), dtype=np.int32)) + np.testing.assert_array_equal(valid, np.ones((2, 4), dtype=np.bool_)) + + @pytest.mark.parametrize(("q_heads", "kv_heads"), [(4, 1), (2, 2)]) def test_real_gpu_fa4_cute_attention_matches_reference_for_valid_dynamic_packed_segments(q_heads, kv_heads): if jax.default_backend() != "gpu": @@ -82,3 +106,38 @@ def fa4_loss(q_arg, k_arg, v_arg): for actual_grad, expected_grad in zip(actual_grads, expected_grads, strict=True): np.testing.assert_allclose(actual_grad, expected_grad, atol=7e-2, rtol=7e-2) + + +def test_real_gpu_fa4_cute_attention_matches_reference_for_simple_sliding_mask(): + if jax.default_backend() != "gpu": + pytest.skip("FA4/CuTe correctness requires a GPU backend.") + pytest.importorskip("cutlass") + pytest.importorskip("cutlass.cute") + pytest.importorskip("flash_attn.cute.flash_bwd_preprocess") + key = jax.random.PRNGKey(5) + q_key, k_key, v_key, cotangent_key = jax.random.split(key, 4) + q = jax.random.normal(q_key, (2, 64, 4, 64), dtype=jnp.bfloat16) + k = jax.random.normal(k_key, (2, 64, 2, 64), dtype=jnp.bfloat16) + v = jax.random.normal(v_key, (2, 64, 2, 64), dtype=jnp.bfloat16) + mask = AttentionMask.causal(sliding_window=7) + + actual = jax.jit(gpu_fa4_cute_attention)(q, k, v, mask) + expected = reference_attention(q, k, v, mask, logits_dtype=jnp.float32) + + np.testing.assert_allclose(actual, expected, atol=7e-2, rtol=7e-2) + + cotangent = jax.random.normal(cotangent_key, q.shape, dtype=jnp.bfloat16) + + def ref_loss(q_arg, k_arg, v_arg): + out = reference_attention(q_arg, k_arg, v_arg, mask, logits_dtype=jnp.float32) + return jnp.sum(out.astype(jnp.float32) * cotangent.astype(jnp.float32)) + + def fa4_loss(q_arg, k_arg, v_arg): + out = gpu_fa4_cute_attention(q_arg, k_arg, v_arg, mask) + return jnp.sum(out.astype(jnp.float32) * cotangent.astype(jnp.float32)) + + actual_grads = jax.jit(jax.grad(fa4_loss, argnums=(0, 1, 2)))(q, k, v) + expected_grads = jax.jit(jax.grad(ref_loss, argnums=(0, 1, 2)))(q, k, v) + + for actual_grad, expected_grad in zip(actual_grads, expected_grads, strict=True): + np.testing.assert_allclose(actual_grad, expected_grad, atol=7e-2, rtol=7e-2) diff --git a/lib/levanter/tests/grug/test_grugformer_moe.py b/lib/levanter/tests/grug/test_grugformer_moe.py index f5534b5178..220bfd5cf2 100644 --- a/lib/levanter/tests/grug/test_grugformer_moe.py +++ b/lib/levanter/tests/grug/test_grugformer_moe.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util +import json +import time import numpy as np import pytest @@ -14,7 +16,6 @@ import levanter.grug.grug_moe as grug_moe from levanter.grug._moe.common import _prepare_moe_dispatch, _prepare_moe_dispatch_indices_with_assignment_ids -from levanter.grug._moe.ep_deepep import _pack_deepep_local_assignments from levanter.grug._moe.sonic import sonic_gather_sum from levanter.grug.grug_moe import ( MoEExpertMlp, @@ -52,6 +53,18 @@ def _make_ep_mesh_or_none() -> Mesh | None: ) +def _make_ep_mesh_for_expert_axis_or_none(expert_axis_size: int) -> Mesh | None: + devices = jax.devices() + if len(devices) < expert_axis_size or len(devices) % expert_axis_size != 0: + return None + mesh_devices = np.array(devices).reshape(len(devices) // expert_axis_size, expert_axis_size, 1) + return Mesh( + mesh_devices, + axis_names=("data", "expert", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Explicit), + ) + + def _make_abstract_moe_mesh(*, data: int, expert: int, model: int) -> AbstractMesh: return AbstractMesh( axis_sizes=(data, expert, model), @@ -172,61 +185,6 @@ def test_moe_mlp_default_matches_explicit_ring_without_ep_axis(): np.testing.assert_allclose(np.asarray(y_default), np.asarray(y_ring), rtol=1e-5, atol=1e-5) -def test_deepep_local_assignment_packing_uses_local_expert_ids(): - recv_x = jnp.array( - [ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0], - ], - dtype=jnp.float32, - ) - recv_topk_idx = jnp.array( - [ - [0, 1], - [1, -1], - [0, 0], - ], - dtype=jnp.int32, - ) - recv_topk_weights = jnp.array( - [ - [0.1, 0.2], - [0.3, 0.0], - [0.4, 0.5], - ], - dtype=jnp.float32, - ) - - local_assignments = _pack_deepep_local_assignments( - recv_x, - recv_topk_idx, - recv_topk_weights, - local_experts=2, - num_recv_tokens=jnp.array(2, dtype=jnp.int32), - ) - - np.testing.assert_array_equal(np.asarray(local_assignments.local_group_sizes), np.array([1, 2], dtype=np.int32)) - np.testing.assert_array_equal( - np.asarray(local_assignments.recv_token_indices[:3]), - np.array([0, 0, 1], dtype=np.int32), - ) - np.testing.assert_allclose( - np.asarray(local_assignments.x_dispatch[:3]), - np.array([[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32), - rtol=0, - atol=0, - ) - np.testing.assert_allclose( - np.asarray(local_assignments.assignment_weights[:3]), - np.array([0.1, 0.2, 0.3], dtype=np.float32), - rtol=1e-6, - atol=1e-6, - ) - np.testing.assert_allclose(np.asarray(local_assignments.x_dispatch[3:]), 0, rtol=0, atol=0) - np.testing.assert_allclose(np.asarray(local_assignments.assignment_weights[3:]), 0, rtol=0, atol=0) - - def test_prepare_moe_dispatch_indices_match_materialized_dispatch(): x, selected_experts, combine_weights, _w_up_gate, _w_down = _make_inputs( key=jax.random.key(28), @@ -430,7 +388,7 @@ def test_moe_expert_mlp_init_uses_logical_weight_pspecs(): assert mlp.w_down.sharding.spec == P(None, "model", "data") -@pytest.mark.parametrize("implementation", ["ring", "ragged_all_to_all"]) +@pytest.mark.parametrize("implementation", ["ring", "assigned_token"]) def test_moe_ep_path_lowers_on_abstract_mesh(implementation: MoeImplementation): mesh = _make_abstract_moe_mesh(data=2, expert=2, model=1) @@ -508,12 +466,12 @@ def test_shard_a2a_params_uses_sender_side_output_offsets(): np.testing.assert_array_equal(np.asarray(output_offsets), np.array([1, 7, 2], dtype=np.int32)) -def test_moe_mlp_ragged_matches_ring_with_ep_axis_when_available(): +def test_moe_mlp_assigned_token_matches_ring_with_ep_axis_when_available(): mesh = _make_ep_mesh_or_none() if mesh is None: pytest.skip("requires an even number of >=2 devices") if jax.devices()[0].platform == "cpu": - pytest.skip("ragged_all_to_all is not implemented on XLA:CPU") + pytest.skip("assigned_token is not implemented on XLA:CPU") tokens = len(jax.devices()) * 8 hidden_dim = 16 @@ -550,20 +508,195 @@ def test_moe_mlp_ragged_matches_ring_with_ep_axis_when_available(): report_capacity_overflow=True, capacity_factor=1.0, ) - ragged_out, ragged_dropped = moe_mlp( + assigned_out, assigned_dropped = moe_mlp( x, selected_experts, combine_weights, w_up_gate, w_down, - implementation="ragged_all_to_all", + implementation="assigned_token", mesh=None, report_capacity_overflow=True, capacity_factor=1.0, ) - np.testing.assert_allclose(np.asarray(ragged_out), np.asarray(ring_out), rtol=1e-5, atol=1e-5) - assert int(ragged_dropped) == int(ring_dropped) + np.testing.assert_allclose(np.asarray(assigned_out), np.asarray(ring_out), rtol=1e-5, atol=1e-5) + assert int(assigned_dropped) == int(ring_dropped) + + +def test_moe_mlp_assigned_token_backward_matches_ring_with_ep_axis_when_available(): + mesh = _make_ep_mesh_or_none() + if mesh is None: + pytest.skip("requires an even number of >=2 devices") + if jax.devices()[0].platform == "cpu": + pytest.skip("assigned_token is not implemented on XLA:CPU") + + tokens = len(jax.devices()) * 8 + # TPU Pallas GMM backward lowering requires valid block multiples for the + # ring reference path; keep this shape small but TPU-lowerable. + hidden_dim = 128 + intermediate_dim = 128 + num_experts = 4 + topk = 2 + + def loss_fn(implementation, x, selected_experts, combine_weights, w_up_gate, w_down): + out = moe_mlp( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + implementation=implementation, + mesh=None, + capacity_factor=1.0, + ) + return jnp.sum(out.astype(jnp.float32)) + + with jax.set_mesh(mesh): + x, selected_experts, combine_weights, w_up_gate, w_down = _make_inputs( + key=jax.random.key(33), + tokens=tokens, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_experts=num_experts, + topk=topk, + ) + + batch_sharding = NamedSharding(mesh, P(("data", "expert"), None)) + expert_sharding = NamedSharding(mesh, P("expert", None, None)) + x = jax.sharding.reshard(x, batch_sharding) + selected_experts = jax.sharding.reshard(selected_experts, batch_sharding) + combine_weights = jax.sharding.reshard(combine_weights, batch_sharding) + w_up_gate = jax.sharding.reshard(w_up_gate, expert_sharding) + w_down = jax.sharding.reshard(w_down, expert_sharding) + + ring_grads = jax.grad(loss_fn, argnums=(1, 3, 4, 5))( + "ring", + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + ) + assigned_grads = jax.grad(loss_fn, argnums=(1, 3, 4, 5))( + "assigned_token", + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + ) + + for assigned_grad, ring_grad in zip(assigned_grads, ring_grads, strict=True): + np.testing.assert_allclose(np.asarray(assigned_grad), np.asarray(ring_grad), rtol=1e-5, atol=1e-5) + + +@pytest.mark.slow +@pytest.mark.timeout(900) +def test_moe_mlp_issue6215_d2560_perf_smoke_when_available(): + mesh = _make_ep_mesh_for_expert_axis_or_none(4) + if mesh is None: + pytest.skip("requires a device count divisible by expert_axis_size=4") + if jax.devices()[0].platform != "gpu": + pytest.skip("issue6215 perf smoke requires GPUs") + if len(jax.devices()) != 4: + pytest.skip("DeepEP issue6215 perf smoke requires exactly 4 visible GPUs") + + batch_size = 8 + seq_len = 4096 + tokens = batch_size * seq_len + hidden_dim = 2560 + intermediate_dim = 1280 + num_experts = 64 + topk = 4 + + token_ids = jnp.arange(tokens, dtype=jnp.int32)[:, None] + topk_offsets = jnp.arange(topk, dtype=jnp.int32)[None, :] + selected_experts = (token_ids * topk + topk_offsets) % num_experts + k_x, k_weights, k_w13, k_w2 = jax.random.split(jax.random.key(6215), 4) + x = jax.random.normal(k_x, (tokens, hidden_dim), dtype=jnp.bfloat16) + combine_weights = jax.nn.sigmoid(jax.random.normal(k_weights, (tokens, topk), dtype=jnp.float32)).astype( + jnp.bfloat16 + ) + w_up_gate = jax.random.normal(k_w13, (num_experts, hidden_dim, 2 * intermediate_dim), dtype=jnp.bfloat16) + w_down = jax.random.normal(k_w2, (num_experts, intermediate_dim, hidden_dim), dtype=jnp.bfloat16) + + with jax.set_mesh(mesh): + batch_sharding = NamedSharding(mesh, P(("data", "expert"), None)) + expert_sharding = NamedSharding(mesh, P("expert", None, None)) + inputs = ( + jax.sharding.reshard(x, batch_sharding), + jax.sharding.reshard(selected_experts, batch_sharding), + jax.sharding.reshard(combine_weights, batch_sharding), + jax.sharding.reshard(w_up_gate, expert_sharding), + jax.sharding.reshard(w_down, expert_sharding), + ) + + def run_impl(implementation: MoeImplementation): + fn = jax.jit( + lambda x, sel, cw, up_gate, down: moe_mlp( + x, + sel, + cw, + up_gate, + down, + activation=ActivationFunctionEnum.silu, + implementation=implementation, + mesh=mesh, + capacity_factor=1.25, + report_capacity_overflow=True, + ) + ) + compiled_out, compiled_dropped = fn(*inputs) + compiled_out.block_until_ready() + compiled_dropped.block_until_ready() + start = time.perf_counter() + out, dropped = fn(*inputs) + out.block_until_ready() + dropped.block_until_ready() + return out, dropped, time.perf_counter() - start + + ring_out, ring_dropped, ring_seconds = run_impl("ring") + deepep_out, deepep_dropped, deepep_seconds = run_impl("deepep") + assigned_out, assigned_dropped, assigned_seconds = run_impl("assigned_token") + + max_reference_abs = float(jnp.max(jnp.abs(ring_out.astype(jnp.float32)))) + deepep_max_abs_diff = float(jnp.max(jnp.abs(ring_out.astype(jnp.float32) - deepep_out.astype(jnp.float32)))) + assigned_max_abs_diff = float(jnp.max(jnp.abs(ring_out.astype(jnp.float32) - assigned_out.astype(jnp.float32)))) + payload = { + "shape": { + "batch_size": batch_size, + "seq_len": seq_len, + "hidden_dim": hidden_dim, + "intermediate_dim": intermediate_dim, + "num_experts": num_experts, + "topk": topk, + }, + "results": { + "ring": { + "seconds": ring_seconds, + "tokens_per_second": tokens / ring_seconds, + }, + "deepep": { + "seconds": deepep_seconds, + "tokens_per_second": tokens / deepep_seconds, + "max_abs_diff": deepep_max_abs_diff, + "dropped": int(deepep_dropped), + }, + "assigned_token": { + "seconds": assigned_seconds, + "tokens_per_second": tokens / assigned_seconds, + "max_abs_diff": assigned_max_abs_diff, + "dropped": int(assigned_dropped), + }, + }, + "max_reference_abs": max_reference_abs, + } + print("ISSUE6215_PERF_SMOKE " + json.dumps(payload, sort_keys=True)) + assert ring_out.shape == deepep_out.shape == assigned_out.shape == (tokens, hidden_dim) + assert int(ring_dropped) == int(deepep_dropped) == int(assigned_dropped) + assert deepep_max_abs_diff / max(max_reference_abs, 1.0) < 0.02 + assert assigned_max_abs_diff / max(max_reference_abs, 1.0) < 0.02 def test_moe_mlp_runs_with_ep_axis_when_available(): @@ -607,18 +740,18 @@ def test_moe_mlp_runs_with_ep_axis_when_available(): assert out.shape == (tokens, hidden_dim) assert jnp.isfinite(out).all() - out_ragged = moe_mlp( + out_assigned = moe_mlp( x, selected_experts, combine_weights, w_up_gate, w_down, activation=ActivationFunctionEnum.silu, - implementation="ragged_all_to_all", + implementation="assigned_token", mesh=None, ) - assert out_ragged.shape == (tokens, hidden_dim) - assert jnp.isfinite(out_ragged).all() + assert out_assigned.shape == (tokens, hidden_dim) + assert jnp.isfinite(out_assigned).all() def test_functional_moe_mlp_accepts_enum_and_callable_activation(): @@ -761,7 +894,7 @@ def test_moe_mlp_reports_positive_drop_count_in_ring_ep_when_over_capacity(): assert int(dropped) > 0 -def test_moe_mlp_reports_positive_drop_count_in_ragged_a2a_when_over_capacity(): +def test_moe_mlp_reports_positive_drop_count_in_assigned_token_when_over_capacity(): mesh = _make_ep_mesh_or_none() if mesh is None: pytest.skip("requires an even number of >=2 devices") @@ -796,7 +929,7 @@ def test_moe_mlp_reports_positive_drop_count_in_ragged_a2a_when_over_capacity(): combine_weights, w_up_gate, w_down, - implementation="ragged_all_to_all", + implementation="assigned_token", mesh=None, report_capacity_overflow=True, ) @@ -806,7 +939,7 @@ def test_moe_mlp_reports_positive_drop_count_in_ragged_a2a_when_over_capacity(): assert int(dropped) > 0 -def test_ragged_a2a_receiver_clipping_respects_capacity(): +def test_assigned_token_receiver_clipping_respects_capacity(): group_sizes = jnp.array( [ [3, 1, 0, 0], diff --git a/lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py b/lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py index 2b6c619e07..29bbeaf7ef 100644 --- a/lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py +++ b/lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py @@ -1631,6 +1631,23 @@ def test_infer_block_sizes_uses_widest_operand_dtype_bucket(monkeypatch: pytest. assert mixed_block_sizes != bf16_block_sizes +@pytest.mark.parametrize("batch", [16_384, 131_072]) +def test_infer_block_sizes_nvidia_issue6044_medium_h_fits_gpu_shared_memory(batch: int): + block_sizes, has_tuned_match = tuned_block_sizes.infer_block_sizes_with_tuned_match( + batch, + 2_560, + 128_256, + dtype=jnp.float32, + x_dtype=jnp.bfloat16, + w_dtype=jnp.bfloat16, + device_kind="NVIDIA B200", + ) + + assert has_tuned_match is True + assert block_sizes == fused_api.BlockSizes(b_block_size=1024, h_block_size=64, v_block_size=256) + assert block_sizes.h_block_size * block_sizes.v_block_size * jnp.dtype(jnp.float32).itemsize <= 101_376 + + def test_shape_bucket_name_large_batch_medium_h_boundary(): assert ( tuned_block_sizes.shape_bucket_name(32_767, 2_048, 128_256, device_kind="TPU v5p") == "medium-batch-medium-h" diff --git a/lib/marin/src/marin/evaluation/trace_labeled_eval.py b/lib/marin/src/marin/evaluation/trace_labeled_eval.py index 0176a6c8eb..a6d43eda40 100644 --- a/lib/marin/src/marin/evaluation/trace_labeled_eval.py +++ b/lib/marin/src/marin/evaluation/trace_labeled_eval.py @@ -433,6 +433,7 @@ def _completed_dataset_metrics(results: dict[str, object]) -> dict[str, float]: for dataset_result in _dataset_results(results).values(): if not _is_completed_dataset_result(dataset_result): continue + assert isinstance(dataset_result, Mapping) dataset_metrics = dataset_result["metrics"] assert isinstance(dataset_metrics, Mapping) for metric_name, metric_value in dataset_metrics.items(): diff --git a/pyproject.toml b/pyproject.toml index f7a5d90ec6..76379dc462 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -365,6 +365,7 @@ markers = [ "integration: mark tests as integration tests that require cluster infrastructure", "data_integration: mark tests that require external data fixtures (e.g. HuggingFace datasets)", "requires_cluster: mark tests that need a running iris cluster", + "timeout: override the default per-test timeout", ] testpaths = ["tests", "experiments"] diff --git a/tests/test_grug_variant_contracts.py b/tests/test_grug_variant_contracts.py index 2985b14ec2..13c47ae43c 100644 --- a/tests/test_grug_variant_contracts.py +++ b/tests/test_grug_variant_contracts.py @@ -166,7 +166,7 @@ def test_grug_moe_variant_threads_moe_implementation_to_kernel(): raise AssertionError("experiments.grug.moe.model must define debug_mesh_and_token_pspec") cfg = _small_model_config(model_module.GrugModelConfig, vocab_size=1024, seq_len=4) - cfg = dataclasses.replace(cfg, moe_implementation="ragged_all_to_all") + cfg = dataclasses.replace(cfg, moe_implementation="assigned_token") optimizer = optax.adam(1e-2) mp = jmp.get_policy("f32") train_step = make_train_step(optimizer, mp, z_loss_weight=0.0, ema_beta=None) @@ -189,7 +189,7 @@ def one_step(): with _reset_abstract_mesh(), use_abstract_mesh(mesh): closed_jaxpr, _, _ = eqx.filter_make_jaxpr(one_step)() - assert "ragged_all_to_all" in str(closed_jaxpr) + assert "all_gather" in str(closed_jaxpr) @pytest.mark.parametrize(