@@ -48,33 +48,34 @@ def _pack_deepep_local_assignments(
4848 local_experts : int ,
4949 num_recv_tokens : Int [Array , "" ],
5050) -> DeepEPLocalAssignments :
51- max_recv_tokens , topk = recv_topk_idx .shape
52- total_assignments = max_recv_tokens * topk
53-
54- recv_token_indices = jnp .repeat (jnp .arange (max_recv_tokens , dtype = jnp .int32 ), topk )
55- expert_flat = recv_topk_idx .reshape (- 1 ).astype (jnp .int32 )
56- recv_valid = jnp .arange (max_recv_tokens , dtype = jnp .int32 ) < num_recv_tokens
57- local_mask = recv_valid [:, None ] & (recv_topk_idx >= 0 ) & (recv_topk_idx < local_experts )
58- local_mask_flat = local_mask .reshape (- 1 )
59- local_bucket = jnp .where (local_mask_flat , expert_flat , local_experts )
60- local_group_sizes = jnp .bincount (local_bucket , length = local_experts + 1 ).astype (jnp .int32 )[:- 1 ]
61- total_valid = jnp .sum (local_group_sizes , dtype = jnp .int32 )
62-
63- flat_positions = jnp .arange (total_assignments , dtype = jnp .int32 )
64- order_key = local_bucket * total_assignments + flat_positions
65- max_order_key = (local_experts + 1 ) * total_assignments
66- selection_key = jnp .where (local_mask_flat , max_order_key - order_key , - 1 )
67- _ , sorted_assignment_indices = jax .lax .top_k (selection_key , total_assignments )
68-
69- recv_token_indices = jnp .take (recv_token_indices , sorted_assignment_indices , axis = 0 )
70- x_dispatch = jnp .take (recv_x , recv_token_indices , axis = 0 )
71- assignment_weights = jnp .take (recv_topk_weights .reshape (- 1 ), sorted_assignment_indices , axis = 0 ).astype (
72- recv_x .dtype
73- )
74- valid_sorted = jnp .arange (total_assignments , dtype = jnp .int32 ) < total_valid
75- x_dispatch = jnp .where (valid_sorted [:, None ], x_dispatch , 0 )
76- assignment_weights = jnp .where (valid_sorted , assignment_weights , 0 )
77- return DeepEPLocalAssignments (x_dispatch , assignment_weights , recv_token_indices , local_group_sizes )
51+ with jax .named_scope ("deepep_pack_local_assignments" ):
52+ max_recv_tokens , topk = recv_topk_idx .shape
53+ total_assignments = max_recv_tokens * topk
54+
55+ recv_token_indices = jnp .repeat (jnp .arange (max_recv_tokens , dtype = jnp .int32 ), topk )
56+ expert_flat = recv_topk_idx .reshape (- 1 ).astype (jnp .int32 )
57+ recv_valid = jnp .arange (max_recv_tokens , dtype = jnp .int32 ) < num_recv_tokens
58+ local_mask = recv_valid [:, None ] & (recv_topk_idx >= 0 ) & (recv_topk_idx < local_experts )
59+ local_mask_flat = local_mask .reshape (- 1 )
60+ local_bucket = jnp .where (local_mask_flat , expert_flat , local_experts )
61+ local_group_sizes = jnp .bincount (local_bucket , length = local_experts + 1 ).astype (jnp .int32 )[:- 1 ]
62+ total_valid = jnp .sum (local_group_sizes , dtype = jnp .int32 )
63+
64+ flat_positions = jnp .arange (total_assignments , dtype = jnp .int32 )
65+ order_key = local_bucket * total_assignments + flat_positions
66+ max_order_key = (local_experts + 1 ) * total_assignments
67+ selection_key = jnp .where (local_mask_flat , max_order_key - order_key , - 1 )
68+ _ , sorted_assignment_indices = jax .lax .top_k (selection_key , total_assignments )
69+
70+ recv_token_indices = jnp .take (recv_token_indices , sorted_assignment_indices , axis = 0 )
71+ x_dispatch = jnp .take (recv_x , recv_token_indices , axis = 0 )
72+ assignment_weights = jnp .take (recv_topk_weights .reshape (- 1 ), sorted_assignment_indices , axis = 0 ).astype (
73+ recv_x .dtype
74+ )
75+ valid_sorted = jnp .arange (total_assignments , dtype = jnp .int32 ) < total_valid
76+ x_dispatch = jnp .where (valid_sorted [:, None ], x_dispatch , 0 )
77+ assignment_weights = jnp .where (valid_sorted , assignment_weights , 0 )
78+ return DeepEPLocalAssignments (x_dispatch , assignment_weights , recv_token_indices , local_group_sizes )
7879
7980
8081def _collapse_deepep_local_assignments (
@@ -85,14 +86,15 @@ def _collapse_deepep_local_assignments(
8586 recv_capacity : int ,
8687 num_recv_tokens : Int [Array , "" ],
8788) -> Float [Array , "TR D" ]:
88- recv_out = jax .ops .segment_sum (
89- out_dispatch * assignment_weights [:, None ],
90- recv_token_indices ,
91- num_segments = recv_capacity ,
92- indices_are_sorted = False ,
93- )
94- recv_valid = jnp .arange (recv_capacity , dtype = jnp .int32 ) < num_recv_tokens
95- return jnp .where (recv_valid [:, None ], recv_out , 0 )
89+ with jax .named_scope ("deepep_collapse_local_assignments" ):
90+ recv_out = jax .ops .segment_sum (
91+ out_dispatch * assignment_weights [:, None ],
92+ recv_token_indices ,
93+ num_segments = recv_capacity ,
94+ indices_are_sorted = False ,
95+ )
96+ recv_valid = jnp .arange (recv_capacity , dtype = jnp .int32 ) < num_recv_tokens
97+ return jnp .where (recv_valid [:, None ], recv_out , 0 )
9698
9799
98100def _moe_mlp_ep_deepep_local (
@@ -120,32 +122,34 @@ def _moe_mlp_ep_deepep_local(
120122 max_recv_tokens = x_local .shape [0 ] * ep_size
121123
122124 with jax .named_scope ("dispatch" ):
123- num_tokens_per_rank , num_tokens_per_expert , is_token_in_rank = deepep_get_dispatch_layout (
124- selected_experts_local ,
125- num_ranks = ep_size ,
126- num_experts = num_experts ,
127- )
128- (
129- recv_x ,
130- recv_topk_idx ,
131- recv_topk_weights ,
132- recv_src_idx ,
133- rank_prefix_matrix ,
134- channel_prefix_matrix ,
135- recv_channel_prefix_matrix ,
136- send_head ,
137- _local_expert_counts ,
138- num_recv_tokens ,
139- ) = deepep_dispatch_intranode (
140- x_local ,
141- selected_experts_local ,
142- combine_weights_local ,
143- num_tokens_per_rank ,
144- num_tokens_per_expert ,
145- is_token_in_rank ,
146- num_experts = num_experts ,
147- max_recv_tokens = max_recv_tokens ,
148- )
125+ with jax .named_scope ("deepep_layout" ):
126+ num_tokens_per_rank , num_tokens_per_expert , is_token_in_rank = deepep_get_dispatch_layout (
127+ selected_experts_local ,
128+ num_ranks = ep_size ,
129+ num_experts = num_experts ,
130+ )
131+ with jax .named_scope ("deepep_dispatch_transport" ):
132+ (
133+ recv_x ,
134+ recv_topk_idx ,
135+ recv_topk_weights ,
136+ recv_src_idx ,
137+ rank_prefix_matrix ,
138+ channel_prefix_matrix ,
139+ recv_channel_prefix_matrix ,
140+ send_head ,
141+ _local_expert_counts ,
142+ num_recv_tokens ,
143+ ) = deepep_dispatch_intranode (
144+ x_local ,
145+ selected_experts_local ,
146+ combine_weights_local ,
147+ num_tokens_per_rank ,
148+ num_tokens_per_expert ,
149+ is_token_in_rank ,
150+ num_experts = num_experts ,
151+ max_recv_tokens = max_recv_tokens ,
152+ )
149153 num_recv_tokens_scalar = jnp .squeeze (num_recv_tokens , axis = 0 )
150154 local_assignments = _pack_deepep_local_assignments (
151155 recv_x ,
@@ -175,16 +179,17 @@ def _moe_mlp_ep_deepep_local(
175179 recv_capacity = recv_x .shape [0 ],
176180 num_recv_tokens = num_recv_tokens_scalar ,
177181 )
178- out_local , _ = deepep_combine_intranode (
179- recv_out ,
180- recv_topk_weights ,
181- recv_src_idx ,
182- rank_prefix_matrix ,
183- channel_prefix_matrix ,
184- recv_channel_prefix_matrix ,
185- send_head ,
186- num_recv_tokens ,
187- is_token_in_rank ,
188- )
182+ with jax .named_scope ("deepep_combine_transport" ):
183+ out_local , _ = deepep_combine_intranode (
184+ recv_out ,
185+ recv_topk_weights ,
186+ recv_src_idx ,
187+ rank_prefix_matrix ,
188+ channel_prefix_matrix ,
189+ recv_channel_prefix_matrix ,
190+ send_head ,
191+ num_recv_tokens ,
192+ is_token_in_rank ,
193+ )
189194 dropped_total = jnp .array (0 , dtype = jnp .int32 )
190195 return out_local .astype (x_local .dtype ), dropped_total
0 commit comments