
Conversation


@mikequan0425 mikequan0425 commented Dec 31, 2025

What does this PR do?

Add the _estimate_gpt_oss_flops function to calculate the FLOPs of GPT-OSS models, covering standard attention, sliding-window attention, and MoE layers. Update the test cases to verify the calculation accuracy for GPT-OSS models.
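For reviewers, here is a rough sketch of the kind of estimate this function performs. All config keys, the alternating full/sliding layer split, and the function name are assumptions for illustration only, not verl's actual _estimate_gpt_oss_flops implementation:

```python
# Illustrative FLOPs estimate for a GPT-OSS-style model (MoE + mixed
# full/sliding-window attention). Config keys and the 50/50 layer split
# are assumptions, not the actual implementation in this PR.
def estimate_gpt_oss_flops_sketch(cfg: dict, batch_seqlens: list[int], delta_time: float) -> float:
    hidden = cfg["hidden_size"]
    head_dim = cfg["head_dim"]
    n_heads = cfg["num_attention_heads"]
    n_kv_heads = cfg["num_key_value_heads"]
    layers = cfg["num_hidden_layers"]
    window = cfg["sliding_window"]
    moe_inter = cfg["moe_intermediate_size"]
    top_k = cfg["num_experts_per_tok"]

    tokens = sum(batch_seqlens)
    # Per-layer matmul parameters: Q/K/V/O projections plus the top-k routed
    # experts (gate, up, down projections, each hidden x moe_inter).
    qkvo = hidden * head_dim * (n_heads + 2 * n_kv_heads) + head_dim * n_heads * hidden
    moe = 3 * hidden * moe_inter * top_k
    dense_flops = 6 * tokens * layers * (qkvo + moe)

    # Attention-score FLOPs: full-attention layers cost seqlen^2 per sequence,
    # sliding-window layers roughly seqlen * window (alternating split assumed).
    full_layers = layers // 2
    sliding_layers = layers - full_layers
    seqlen_square_sum = (full_layers * sum(s * s for s in batch_seqlens)
                         + sliding_layers * sum(s * min(s, window) for s in batch_seqlens))
    attn_flops = 12 * seqlen_square_sum * head_dim * n_heads

    return (dense_flops + attn_flops) / 1e12 / delta_time  # TFLOPs/s
```

The key differences from a dense model are the top-k expert term (only activated experts count toward FLOPs) and the per-layer-type seqlen_square_sum for the attention term.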

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: [perf] feat: mistral and gemma3_text mfu compute support #2622, [misc] add support for qwen3 model (dense/moe) #1409
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

The following is the output of running the test file:

[screenshot: test output]

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.


  • Revise comments and configuration parameters to more accurately reflect the FLOPs calculation of the GPT-OSS model
  • Update configurations and comments in test cases to match the implementation logic
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.

Contributor @gemini-code-assist (bot) left a comment


Code Review

This pull request adds support for calculating MFU for GPT-OSS models by introducing the _estimate_gpt_oss_flops function and corresponding test cases. The implementation for the new model appears correct. However, I found an issue in the updated test data for the existing gemma3_text model, where the seqlen_square_sum and subsequent FLOPs calculations seem incorrect for one of the test cases. I've provided a detailed comment and a suggestion to fix it. Please review this finding.

Comment on lines +203 to +206
# seqlen_square_sum: 1373634560 (calculated with sliding window logic)
# attn flops: 12 * 1373634560 * 256 * 16 = 67515029389312
# total: 1009005293453312 / 1e12 = 1009.005293453312
"expected_flops_tuple": (283517065887744 / 1e12, 1009005293453312 / 1e12),

Severity: high

There seems to be a miscalculation in the updated test values for gemma3_text for the [4096, 4096, 4096] batch case.

While the dense_flops value was correctly fixed, the seqlen_square_sum appears incorrect. My calculation shows it should be 905,969,664, which matches the original value in the file, not the new value of 1,373,634,560.

Here's a breakdown of the calculation for seqlen_square_sum:

  • The model has 48 layers, with a sliding window pattern of 6, resulting in 8 full-attention layers and 40 sliding-window layers.
  • The sliding window size is 1024.
  • For 40 sliding layers: 40 * (3 * 4096 * 1024) = 503,316,480
  • For 8 full layers: 8 * (3 * 4096 * 4096) = 402,653,184
  • Total seqlen_square_sum: 503,316,480 + 402,653,184 = 905,969,664.
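The arithmetic above can be reproduced with a short script. The constants come from the gemma3_text test case discussed in this thread (48 layers, sliding-window pattern of 6, window size 1024, head_dim 256, 16 heads); the helper is illustrative, not the project's implementation:

```python
# Recheck seqlen_square_sum for gemma3_text with three 4096-token sequences.
# Constants are taken from the review comment above; illustrative only.
batch_seqlens = [4096, 4096, 4096]
num_layers, pattern, window = 48, 6, 1024   # every 6th layer uses full attention

full_layers = num_layers // pattern          # 8 full-attention layers
sliding_layers = num_layers - full_layers    # 40 sliding-window layers

seqlen_square_sum = (sliding_layers * sum(s * min(s, window) for s in batch_seqlens)
                     + full_layers * sum(s * s for s in batch_seqlens))
print(seqlen_square_sum)  # 905969664

head_dim, num_heads = 256, 16
attn_flops = 12 * seqlen_square_sum * head_dim * num_heads
print(attn_flops)  # 44530220924928
```

This reproduces the 905,969,664 figure, matching the original value in the file rather than the PR's new value.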

This error propagates to attn_flops, total, and expected_flops_tuple. The suggested change below corrects these values.

Suggested change
# seqlen_square_sum: 1373634560 (calculated with sliding window logic)
# attn flops: 12 * 1373634560 * 256 * 16 = 67515029389312
# total: 1009005293453312 / 1e12 = 1009.005293453312
"expected_flops_tuple": (283517065887744 / 1e12, 1009005293453312 / 1e12),
# seqlen_square_sum: 905969664 (calculated with sliding window logic)
# attn flops: 12 * 905969664 * 256 * 16 = 44530220924928
# total: 941490264064000 + 44530220924928 = 986020484988928
"expected_flops_tuple": (283517065887744 / 1e12, 986020484988928 / 1e12),

@tardis-key tardis-key mentioned this pull request Dec 31, 2025
