|
30 | 30 | from flashinfer import ( |
31 | 31 | fp4_quantize, |
32 | 32 | mxfp8_quantize, |
33 | | - next_positive_power_of_2, |
34 | 33 | reorder_rows_for_gated_act_gemm, |
35 | 34 | shuffle_matrix_a, |
36 | 35 | shuffle_matrix_sf_a, |
@@ -188,30 +187,6 @@ def reference_moe( |
188 | 187 | return t.to(torch.bfloat16) |
189 | 188 |
|
190 | 189 |
|
191 | | -def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int): |
192 | | - # Number of tokens in the input tensor. |
193 | | - num_tokens = x.shape[0] |
194 | | - # Factor to account for the imbalance of the experts. |
195 | | - # factor equals to the |
196 | | - # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert |
197 | | - # - 1.0 means perfect expert distribution. |
198 | | - # - > 1.0 means some experts have more |
199 | | - # tokens than the perfect distribution. |
200 | | - # - < 1.0 does not make sense. |
201 | | - imbalance_factor = 1.3 |
202 | | - # Calculate the number of tokens per expert |
203 | | - # assuming perfect distribution. |
204 | | - num_tokens_per_expert = (num_tokens * top_k) // num_experts |
205 | | - # Apply the imbalance factor. |
206 | | - num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) |
207 | | - # And pad the number to the next power of 2. |
208 | | - tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) |
209 | | - # Cap to 8-64 tokens per CTA tile |
210 | | - # as it's the range supported by the kernel. |
211 | | - tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) |
212 | | - return tile_tokens_dim |
213 | | - |
214 | | - |
215 | 190 | def tg_mxfp4_moe( |
216 | 191 | router_logits, |
217 | 192 | topk, |
@@ -460,7 +435,6 @@ def tg_mxfp4_moe( |
460 | 435 | local_expert_offset=0, |
461 | 436 | local_num_experts=num_experts, |
462 | 437 | routed_scaling_factor=None, |
463 | | - tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), |
464 | 438 | routing_method_type=1, # renormalize |
465 | 439 | do_finalize=True, |
466 | 440 | )[0] |
|
0 commit comments