[perf]feat: Add MFU for Qwen3-VL dense #4753
@@ -118,6 +118,99 @@ def _estimate_qwen2_flops(config, tokens_sum, batch_seqlens, delta_time):
    return flops_achieved


def _estimate_qwen3_vl_flops(config, tokens_sum, batch_seqlens, delta_time, **kargs):
    # qwen3_vl uses text_config and vision_config to hold the configs of the two parts.
    hidden_size = config.text_config.hidden_size
    vocab_size = config.text_config.vocab_size
    num_hidden_layers = config.text_config.num_hidden_layers
    num_key_value_heads = config.text_config.num_key_value_heads
    num_attention_heads = config.text_config.num_attention_heads
    intermediate_size = config.text_config.intermediate_size

    head_dim = hidden_size // num_attention_heads
    q_size = num_attention_heads * head_dim
    k_size = num_key_value_heads * head_dim
    v_size = num_key_value_heads * head_dim

    # non-attn per-layer params (factor 3 = gate/up/down projections of the GLU MLP)
    mlp_N = hidden_size * intermediate_size * 3
    attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
    emd_and_lm_head_N = vocab_size * hidden_size * 2
    # non-attn params across all layers
    dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
    # non-attn fwd & bwd FLOPs over all layers and tokens (6 = 2 fwd + 4 bwd FLOPs per param per token)
    dense_N_flops = 6 * dense_N * tokens_sum

    # qwen3_vl uses deepstack to merge visual embeds and text embeds, but that step involves no tensor operation.

    # attn fwd & bwd FLOPs over all layers and tokens
    seqlen_square_sum = 0
    for seqlen in batch_seqlens:
        seqlen_square_sum += seqlen * seqlen
    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers

    # ViT FLOPs
    images_seqlens = kargs.get("images_seqlens", None)
    if images_seqlens is not None:
        vit_flops = _estimate_qwen3_vit_flop(images_seqlens, config.vision_config)
    else:
        vit_flops = 0

    # total fwd & bwd FLOPs over all layers and tokens, converted to TFLOPs/s
    flops_all_token = dense_N_flops + attn_qkv_flops + vit_flops
    flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
    return flops_achieved


def _estimate_qwen3_vit_flop(images_seqlens, config):
    """
    Estimate the FLOPS of the vision encoder for Qwen2 and Qwen2.5
Contributor
The docstring is inaccurate. This function estimates FLOPs for the Qwen3-VL vision encoder, not Qwen2 and Qwen2.5. Updating the docstring will prevent future confusion.
Suggested change:
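The suggestion presumably amounts to:

-    Estimate the FLOPS of the vision encoder for Qwen2 and Qwen2.5
+    Estimate the FLOPs of the vision encoder for Qwen3-VL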
    """

    if config is None:
        return 0
    tokens_sum = sum(images_seqlens)

    num_heads = config.num_heads
    depth = config.depth

    dim = config.hidden_size
    mlp_hidden_dim = config.intermediate_size
    out_hidden_size = config.out_hidden_size

    spatial_merge_size = config.spatial_merge_size

    head_dim = dim // num_heads

    # every vision token's patch_embed comes from a conv of (C, T, H, W) -> (dim,)
    patch_embed_N = dim * config.in_channels * config.temporal_patch_size * config.patch_size * config.patch_size
    # The Qwen3-VL vision MLP does not use a GLU, thus the factor is 2 (up and down projections).
    mlp_N = dim * mlp_hidden_dim * 2
    attn_linear_N = dim * (4 * dim)  # qkv and output proj
    merger_N = (out_hidden_size + (dim * (spatial_merge_size**2))) * (dim * (spatial_merge_size**2))

    # Qwen3-VL uses deepstack: one merger for every deepstack layer
    deepstack_merger_N = merger_N * len(config.deepstack_visual_indexes)
    # non-attn params across all layers
    dense_N = patch_embed_N + (mlp_N + attn_linear_N) * depth + deepstack_merger_N + merger_N

    # non-attn fwd & bwd FLOPs over all layers and tokens
    dense_N_flops = 6 * dense_N * tokens_sum

    # In Qwen3-VL, full attention is used in all vision layers.
    full_attn_layer_num = depth

    # full-attention fwd & bwd FLOPs over all tokens
    seqlen_square_sum = 0
    for seqlen in images_seqlens:
        seqlen_square_sum += seqlen * seqlen
    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_heads * full_attn_layer_num

    vit_flops = dense_N_flops + attn_qkv_flops

    return vit_flops


def _estimate_deepseek_v3_flops(config, tokens_sum, batch_seqlens, delta_time):
    hidden_size = config.hidden_size
    vocab_size = config.vocab_size
@@ -326,7 +419,7 @@ def _estimate_unknown_flops(config, tokens_sum, batch_seqlens, delta_time):
| "qwen2_5_vl": _estimate_qwen2_flops, | ||||||||||||
| "qwen3": _estimate_qwen2_flops, | ||||||||||||
| "qwen3_moe": _estimate_qwen2_moe_flops, | ||||||||||||
| "qwen3_vl": _estimate_qwen2_flops, | ||||||||||||
| "qwen3_vl": _estimate_qwen3_vl_flops, | ||||||||||||
| "qwen3_vl_moe": _estimate_qwen2_moe_flops, | ||||||||||||
| "deepseek_v3": _estimate_deepseek_v3_flops, | ||||||||||||
| "minicpmv": _estimate_qwen2_flops, | ||||||||||||
|
|
@@ -357,7 +450,7 @@ def __init__(self, config: PretrainedConfig):
| f"zero." | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| self.config = getattr(config, "text_config", config) | ||||||||||||
| self.config = config | ||||||||||||
|
Contributor
This change introduces a critical regression. While keeping the full config is necessary for `qwen3_vl`, whose estimator reads `text_config` and `vision_config` itself, other VL models rely on estimation functions like `_estimate_qwen2_flops` that read fields such as `hidden_size` directly from `self.config`; for those models the fields live under `text_config`, so dropping the `getattr` fallback breaks their FLOPs estimation.
To fix this, we can special-case `qwen3_vl` and keep the `text_config` fallback for every other model.
Suggested change:
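Presumably along these lines (a sketch of the special case, not the exact suggestion):

        # qwen3_vl's estimator needs the full config (both text_config and vision_config);
        # every other model keeps the previous text_config fallback.
        if config.model_type == "qwen3_vl":
            self.config = config
        else:
            self.config = getattr(config, "text_config", config)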
    # TODO: actually we can make this a static method
    def estimate_flops(self, batch_seqlens, delta_time):

@@ -377,4 +470,4 @@ def estimate_flops(self, batch_seqlens, delta_time):
        func = ESTIMATE_FUNC.get(self.config.model_type, _estimate_unknown_flops)
        estimated_flops = func(self.config, tokens_sum, batch_seqlens, delta_time)
        promised_flops = get_device_flops()
-        return estimated_flops, promised_flops
+        return estimated_flops, promised_flops
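The two return values feed the MFU metric named in the PR title; a hedged usage sketch (the caller-side names here are illustrative, not from this diff):

    estimated_flops, promised_flops = flops_counter.estimate_flops(batch_seqlens, delta_time)
    mfu = estimated_flops / promised_flops  # achieved TFLOPs/s over device peak TFLOPs/s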
Contributor
The `**kargs` argument will always be empty because the calling method `FlopsCounter.estimate_flops` does not accept or forward any keyword arguments. As a result, `images_seqlens` will always be `None`, and the ViT FLOPs term (`vit_flops`) will always be zero. This makes a significant part of this function's logic dead code and leads to incorrect FLOPs estimation for `qwen3_vl`.
To fix this, update the signature of `FlopsCounter.estimate_flops` to accept additional arguments (e.g., `**kwargs`) and pass them along to the estimation function.
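A minimal sketch of that fix, assuming estimators without a `**kwargs` parameter are never passed extra keywords:

    def estimate_flops(self, batch_seqlens, delta_time, **kwargs):
        tokens_sum = sum(batch_seqlens)
        func = ESTIMATE_FUNC.get(self.config.model_type, _estimate_unknown_flops)
        # Forward caller-supplied extras such as images_seqlens to the estimator.
        estimated_flops = func(self.config, tokens_sum, batch_seqlens, delta_time, **kwargs)
        promised_flops = get_device_flops()
        return estimated_flops, promised_flops

The trainer could then call `estimate_flops(batch_seqlens, delta_time, images_seqlens=images_seqlens)` so the ViT term is actually computed for `qwen3_vl`.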