From 330c4046b118cf89d3865f04ec7115bca7a7b163 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 4 Jun 2026 13:51:30 -0700 Subject: [PATCH] Add DeepEP profiling named scopes --- .../src/levanter/grug/_moe/ep_deepep.py | 149 +++++++++--------- 1 file changed, 77 insertions(+), 72 deletions(-) diff --git a/lib/levanter/src/levanter/grug/_moe/ep_deepep.py b/lib/levanter/src/levanter/grug/_moe/ep_deepep.py index 2dd2dd3fab..d6b4cfc391 100644 --- a/lib/levanter/src/levanter/grug/_moe/ep_deepep.py +++ b/lib/levanter/src/levanter/grug/_moe/ep_deepep.py @@ -48,33 +48,34 @@ def _pack_deepep_local_assignments( local_experts: int, num_recv_tokens: Int[Array, ""], ) -> DeepEPLocalAssignments: - max_recv_tokens, topk = recv_topk_idx.shape - total_assignments = max_recv_tokens * topk - - recv_token_indices = jnp.repeat(jnp.arange(max_recv_tokens, dtype=jnp.int32), topk) - expert_flat = recv_topk_idx.reshape(-1).astype(jnp.int32) - recv_valid = jnp.arange(max_recv_tokens, dtype=jnp.int32) < num_recv_tokens - local_mask = recv_valid[:, None] & (recv_topk_idx >= 0) & (recv_topk_idx < local_experts) - local_mask_flat = local_mask.reshape(-1) - local_bucket = jnp.where(local_mask_flat, expert_flat, local_experts) - local_group_sizes = jnp.bincount(local_bucket, length=local_experts + 1).astype(jnp.int32)[:-1] - total_valid = jnp.sum(local_group_sizes, dtype=jnp.int32) - - flat_positions = jnp.arange(total_assignments, dtype=jnp.int32) - order_key = local_bucket * total_assignments + flat_positions - max_order_key = (local_experts + 1) * total_assignments - selection_key = jnp.where(local_mask_flat, max_order_key - order_key, -1) - _, sorted_assignment_indices = jax.lax.top_k(selection_key, total_assignments) - - recv_token_indices = jnp.take(recv_token_indices, sorted_assignment_indices, axis=0) - x_dispatch = jnp.take(recv_x, recv_token_indices, axis=0) - assignment_weights = jnp.take(recv_topk_weights.reshape(-1), sorted_assignment_indices, axis=0).astype( - recv_x.dtype - ) - valid_sorted = jnp.arange(total_assignments, dtype=jnp.int32) < total_valid - x_dispatch = jnp.where(valid_sorted[:, None], x_dispatch, 0) - assignment_weights = jnp.where(valid_sorted, assignment_weights, 0) - return DeepEPLocalAssignments(x_dispatch, assignment_weights, recv_token_indices, local_group_sizes) + with jax.named_scope("deepep_pack_local_assignments"): + max_recv_tokens, topk = recv_topk_idx.shape + total_assignments = max_recv_tokens * topk + + recv_token_indices = jnp.repeat(jnp.arange(max_recv_tokens, dtype=jnp.int32), topk) + expert_flat = recv_topk_idx.reshape(-1).astype(jnp.int32) + recv_valid = jnp.arange(max_recv_tokens, dtype=jnp.int32) < num_recv_tokens + local_mask = recv_valid[:, None] & (recv_topk_idx >= 0) & (recv_topk_idx < local_experts) + local_mask_flat = local_mask.reshape(-1) + local_bucket = jnp.where(local_mask_flat, expert_flat, local_experts) + local_group_sizes = jnp.bincount(local_bucket, length=local_experts + 1).astype(jnp.int32)[:-1] + total_valid = jnp.sum(local_group_sizes, dtype=jnp.int32) + + flat_positions = jnp.arange(total_assignments, dtype=jnp.int32) + order_key = local_bucket * total_assignments + flat_positions + max_order_key = (local_experts + 1) * total_assignments + selection_key = jnp.where(local_mask_flat, max_order_key - order_key, -1) + _, sorted_assignment_indices = jax.lax.top_k(selection_key, total_assignments) + + recv_token_indices = jnp.take(recv_token_indices, sorted_assignment_indices, axis=0) + x_dispatch = jnp.take(recv_x, recv_token_indices, axis=0) + assignment_weights = jnp.take(recv_topk_weights.reshape(-1), sorted_assignment_indices, axis=0).astype( + recv_x.dtype + ) + valid_sorted = jnp.arange(total_assignments, dtype=jnp.int32) < total_valid + x_dispatch = jnp.where(valid_sorted[:, None], x_dispatch, 0) + assignment_weights = jnp.where(valid_sorted, assignment_weights, 0) + return DeepEPLocalAssignments(x_dispatch, assignment_weights, recv_token_indices, local_group_sizes) def _collapse_deepep_local_assignments( @@ -85,14 +86,15 @@ def _collapse_deepep_local_assignments( recv_capacity: int, num_recv_tokens: Int[Array, ""], ) -> Float[Array, "TR D"]: - recv_out = jax.ops.segment_sum( - out_dispatch * assignment_weights[:, None], - recv_token_indices, - num_segments=recv_capacity, - indices_are_sorted=False, - ) - recv_valid = jnp.arange(recv_capacity, dtype=jnp.int32) < num_recv_tokens - return jnp.where(recv_valid[:, None], recv_out, 0) + with jax.named_scope("deepep_collapse_local_assignments"): + recv_out = jax.ops.segment_sum( + out_dispatch * assignment_weights[:, None], + recv_token_indices, + num_segments=recv_capacity, + indices_are_sorted=False, + ) + recv_valid = jnp.arange(recv_capacity, dtype=jnp.int32) < num_recv_tokens + return jnp.where(recv_valid[:, None], recv_out, 0) def _moe_mlp_ep_deepep_local( @@ -120,32 +122,34 @@ def _moe_mlp_ep_deepep_local( max_recv_tokens = x_local.shape[0] * ep_size with jax.named_scope("dispatch"): - num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout( - selected_experts_local, - num_ranks=ep_size, - num_experts=num_experts, - ) - ( - recv_x, - recv_topk_idx, - recv_topk_weights, - recv_src_idx, - rank_prefix_matrix, - channel_prefix_matrix, - recv_channel_prefix_matrix, - send_head, - _local_expert_counts, - num_recv_tokens, - ) = deepep_dispatch_intranode( - x_local, - selected_experts_local, - combine_weights_local, - num_tokens_per_rank, - num_tokens_per_expert, - is_token_in_rank, - num_experts=num_experts, - max_recv_tokens=max_recv_tokens, - ) + with jax.named_scope("deepep_layout"): + num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout( + selected_experts_local, + num_ranks=ep_size, + num_experts=num_experts, + ) + with jax.named_scope("deepep_dispatch_transport"): + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + _local_expert_counts, + num_recv_tokens, + ) = deepep_dispatch_intranode( + x_local, + selected_experts_local, + combine_weights_local, + num_tokens_per_rank, + num_tokens_per_expert, + is_token_in_rank, + num_experts=num_experts, + max_recv_tokens=max_recv_tokens, + ) num_recv_tokens_scalar = jnp.squeeze(num_recv_tokens, axis=0) local_assignments = _pack_deepep_local_assignments( recv_x, @@ -175,16 +179,17 @@ def _moe_mlp_ep_deepep_local( recv_capacity=recv_x.shape[0], num_recv_tokens=num_recv_tokens_scalar, ) - out_local, _ = deepep_combine_intranode( - recv_out, - recv_topk_weights, - recv_src_idx, - rank_prefix_matrix, - channel_prefix_matrix, - recv_channel_prefix_matrix, - send_head, - num_recv_tokens, - is_token_in_rank, - ) + with jax.named_scope("deepep_combine_transport"): + out_local, _ = deepep_combine_intranode( + recv_out, + recv_topk_weights, + recv_src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + num_recv_tokens, + is_token_in_rank, + ) dropped_total = jnp.array(0, dtype=jnp.int32) return out_local.astype(x_local.dtype), dropped_total