Skip to content

Commit 22dfe05

Browse files
authored
[Grug] Add DeepEP profiling named scopes (#6175)
Add `jax.named_scope` annotations around the DeepEP layout, dispatch/combine transport, local packing, and collapse paths so that profiler traces can attribute data movement and local work separately. Part of #6139.
1 parent 7dacd65 commit 22dfe05

1 file changed

Lines changed: 77 additions & 72 deletions

File tree

lib/levanter/src/levanter/grug/_moe/ep_deepep.py

Lines changed: 77 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -48,33 +48,34 @@ def _pack_deepep_local_assignments(
4848
local_experts: int,
4949
num_recv_tokens: Int[Array, ""],
5050
) -> DeepEPLocalAssignments:
51-
max_recv_tokens, topk = recv_topk_idx.shape
52-
total_assignments = max_recv_tokens * topk
53-
54-
recv_token_indices = jnp.repeat(jnp.arange(max_recv_tokens, dtype=jnp.int32), topk)
55-
expert_flat = recv_topk_idx.reshape(-1).astype(jnp.int32)
56-
recv_valid = jnp.arange(max_recv_tokens, dtype=jnp.int32) < num_recv_tokens
57-
local_mask = recv_valid[:, None] & (recv_topk_idx >= 0) & (recv_topk_idx < local_experts)
58-
local_mask_flat = local_mask.reshape(-1)
59-
local_bucket = jnp.where(local_mask_flat, expert_flat, local_experts)
60-
local_group_sizes = jnp.bincount(local_bucket, length=local_experts + 1).astype(jnp.int32)[:-1]
61-
total_valid = jnp.sum(local_group_sizes, dtype=jnp.int32)
62-
63-
flat_positions = jnp.arange(total_assignments, dtype=jnp.int32)
64-
order_key = local_bucket * total_assignments + flat_positions
65-
max_order_key = (local_experts + 1) * total_assignments
66-
selection_key = jnp.where(local_mask_flat, max_order_key - order_key, -1)
67-
_, sorted_assignment_indices = jax.lax.top_k(selection_key, total_assignments)
68-
69-
recv_token_indices = jnp.take(recv_token_indices, sorted_assignment_indices, axis=0)
70-
x_dispatch = jnp.take(recv_x, recv_token_indices, axis=0)
71-
assignment_weights = jnp.take(recv_topk_weights.reshape(-1), sorted_assignment_indices, axis=0).astype(
72-
recv_x.dtype
73-
)
74-
valid_sorted = jnp.arange(total_assignments, dtype=jnp.int32) < total_valid
75-
x_dispatch = jnp.where(valid_sorted[:, None], x_dispatch, 0)
76-
assignment_weights = jnp.where(valid_sorted, assignment_weights, 0)
77-
return DeepEPLocalAssignments(x_dispatch, assignment_weights, recv_token_indices, local_group_sizes)
51+
with jax.named_scope("deepep_pack_local_assignments"):
52+
max_recv_tokens, topk = recv_topk_idx.shape
53+
total_assignments = max_recv_tokens * topk
54+
55+
recv_token_indices = jnp.repeat(jnp.arange(max_recv_tokens, dtype=jnp.int32), topk)
56+
expert_flat = recv_topk_idx.reshape(-1).astype(jnp.int32)
57+
recv_valid = jnp.arange(max_recv_tokens, dtype=jnp.int32) < num_recv_tokens
58+
local_mask = recv_valid[:, None] & (recv_topk_idx >= 0) & (recv_topk_idx < local_experts)
59+
local_mask_flat = local_mask.reshape(-1)
60+
local_bucket = jnp.where(local_mask_flat, expert_flat, local_experts)
61+
local_group_sizes = jnp.bincount(local_bucket, length=local_experts + 1).astype(jnp.int32)[:-1]
62+
total_valid = jnp.sum(local_group_sizes, dtype=jnp.int32)
63+
64+
flat_positions = jnp.arange(total_assignments, dtype=jnp.int32)
65+
order_key = local_bucket * total_assignments + flat_positions
66+
max_order_key = (local_experts + 1) * total_assignments
67+
selection_key = jnp.where(local_mask_flat, max_order_key - order_key, -1)
68+
_, sorted_assignment_indices = jax.lax.top_k(selection_key, total_assignments)
69+
70+
recv_token_indices = jnp.take(recv_token_indices, sorted_assignment_indices, axis=0)
71+
x_dispatch = jnp.take(recv_x, recv_token_indices, axis=0)
72+
assignment_weights = jnp.take(recv_topk_weights.reshape(-1), sorted_assignment_indices, axis=0).astype(
73+
recv_x.dtype
74+
)
75+
valid_sorted = jnp.arange(total_assignments, dtype=jnp.int32) < total_valid
76+
x_dispatch = jnp.where(valid_sorted[:, None], x_dispatch, 0)
77+
assignment_weights = jnp.where(valid_sorted, assignment_weights, 0)
78+
return DeepEPLocalAssignments(x_dispatch, assignment_weights, recv_token_indices, local_group_sizes)
7879

7980

8081
def _collapse_deepep_local_assignments(
@@ -85,14 +86,15 @@ def _collapse_deepep_local_assignments(
8586
recv_capacity: int,
8687
num_recv_tokens: Int[Array, ""],
8788
) -> Float[Array, "TR D"]:
88-
recv_out = jax.ops.segment_sum(
89-
out_dispatch * assignment_weights[:, None],
90-
recv_token_indices,
91-
num_segments=recv_capacity,
92-
indices_are_sorted=False,
93-
)
94-
recv_valid = jnp.arange(recv_capacity, dtype=jnp.int32) < num_recv_tokens
95-
return jnp.where(recv_valid[:, None], recv_out, 0)
89+
with jax.named_scope("deepep_collapse_local_assignments"):
90+
recv_out = jax.ops.segment_sum(
91+
out_dispatch * assignment_weights[:, None],
92+
recv_token_indices,
93+
num_segments=recv_capacity,
94+
indices_are_sorted=False,
95+
)
96+
recv_valid = jnp.arange(recv_capacity, dtype=jnp.int32) < num_recv_tokens
97+
return jnp.where(recv_valid[:, None], recv_out, 0)
9698

9799

98100
def _moe_mlp_ep_deepep_local(
@@ -120,32 +122,34 @@ def _moe_mlp_ep_deepep_local(
120122
max_recv_tokens = x_local.shape[0] * ep_size
121123

122124
with jax.named_scope("dispatch"):
123-
num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout(
124-
selected_experts_local,
125-
num_ranks=ep_size,
126-
num_experts=num_experts,
127-
)
128-
(
129-
recv_x,
130-
recv_topk_idx,
131-
recv_topk_weights,
132-
recv_src_idx,
133-
rank_prefix_matrix,
134-
channel_prefix_matrix,
135-
recv_channel_prefix_matrix,
136-
send_head,
137-
_local_expert_counts,
138-
num_recv_tokens,
139-
) = deepep_dispatch_intranode(
140-
x_local,
141-
selected_experts_local,
142-
combine_weights_local,
143-
num_tokens_per_rank,
144-
num_tokens_per_expert,
145-
is_token_in_rank,
146-
num_experts=num_experts,
147-
max_recv_tokens=max_recv_tokens,
148-
)
125+
with jax.named_scope("deepep_layout"):
126+
num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout(
127+
selected_experts_local,
128+
num_ranks=ep_size,
129+
num_experts=num_experts,
130+
)
131+
with jax.named_scope("deepep_dispatch_transport"):
132+
(
133+
recv_x,
134+
recv_topk_idx,
135+
recv_topk_weights,
136+
recv_src_idx,
137+
rank_prefix_matrix,
138+
channel_prefix_matrix,
139+
recv_channel_prefix_matrix,
140+
send_head,
141+
_local_expert_counts,
142+
num_recv_tokens,
143+
) = deepep_dispatch_intranode(
144+
x_local,
145+
selected_experts_local,
146+
combine_weights_local,
147+
num_tokens_per_rank,
148+
num_tokens_per_expert,
149+
is_token_in_rank,
150+
num_experts=num_experts,
151+
max_recv_tokens=max_recv_tokens,
152+
)
149153
num_recv_tokens_scalar = jnp.squeeze(num_recv_tokens, axis=0)
150154
local_assignments = _pack_deepep_local_assignments(
151155
recv_x,
@@ -175,16 +179,17 @@ def _moe_mlp_ep_deepep_local(
175179
recv_capacity=recv_x.shape[0],
176180
num_recv_tokens=num_recv_tokens_scalar,
177181
)
178-
out_local, _ = deepep_combine_intranode(
179-
recv_out,
180-
recv_topk_weights,
181-
recv_src_idx,
182-
rank_prefix_matrix,
183-
channel_prefix_matrix,
184-
recv_channel_prefix_matrix,
185-
send_head,
186-
num_recv_tokens,
187-
is_token_in_rank,
188-
)
182+
with jax.named_scope("deepep_combine_transport"):
183+
out_local, _ = deepep_combine_intranode(
184+
recv_out,
185+
recv_topk_weights,
186+
recv_src_idx,
187+
rank_prefix_matrix,
188+
channel_prefix_matrix,
189+
recv_channel_prefix_matrix,
190+
send_head,
191+
num_recv_tokens,
192+
is_token_in_rank,
193+
)
189194
dropped_total = jnp.array(0, dtype=jnp.int32)
190195
return out_local.astype(x_local.dtype), dropped_total

0 commit comments

Comments
 (0)