Commit 854ca34

[perf]feat: GPT-OSS mfu compute support (#4750)
1 parent 4017d63 commit 854ca34

File tree: tests/utils/test_flops_counter.py, verl/utils/flops_counter.py

2 files changed: +129 -1 lines changed


tests/utils/test_flops_counter.py

Lines changed: 59 additions & 1 deletion
@@ -206,6 +206,64 @@ def __init__(self, config_dict):
         # total: 986195089686528 / 1e12 = 986.195089686528
         "expected_flops_tuple": (283517065887744 / 1e12, 986195089686528 / 1e12),
     },
+    "gpt_oss": {
+        "config": {
+            "model_type": "gpt_oss",
+            "vocab_size": 201088,
+            "hidden_size": 2880,
+            "num_hidden_layers": 24,
+            "num_attention_heads": 64,
+            "num_key_value_heads": 8,
+            "head_dim": 64,
+            "intermediate_size": 2880,
+            "num_local_experts": 32,
+            "num_experts_per_tok": 4,
+            "sliding_window": 128,
+            "layer_types": [
+                "sliding_attention", "full_attention", "sliding_attention", "full_attention",
+                "sliding_attention", "full_attention", "sliding_attention", "full_attention",
+                "sliding_attention", "full_attention", "sliding_attention", "full_attention",
+                "sliding_attention", "full_attention", "sliding_attention", "full_attention",
+                "sliding_attention", "full_attention", "sliding_attention", "full_attention",
+                "sliding_attention", "full_attention", "sliding_attention", "full_attention"
+            ],
+        },
+        "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
+        # GPT-OSS has alternating sliding / full attention
+        #   Even layers (12 layers) use sliding window attention with window_size = 128
+        #   Odd layers (12 layers) use full attention
+        #
+        # Non-attention FLOPs:
+        #   vocab part: 201088 * 2880 * 2 = 1158266880
+        #   attn linear part per layer:
+        #     Q: 2880 * (64 * 64) = 11796480
+        #     K: 2880 * (8 * 64) = 1474560
+        #     V: 2880 * (8 * 64) = 1474560
+        #     O: (64 * 64) * 2880 = 11796480
+        #     attn linear total = 26542080
+        #   mlp (MoE, SwiGLU) part per layer:
+        #     gate: 2880 * 32 = 92160
+        #     active experts: 3 * 2880 * 2880 * 4 = 99532800
+        #     mlp total = 99624960
+        #   total per layer: 26542080 + 99624960 = 126167040
+        #   all layers:
+        #     126167040 * 24 = 3028008960
+        #   total dense params:
+        #     3028008960 + 1158266880 = 4186275840
+        #
+        # For batch [512, 1024, 2048], tokens_sum = 3584:
+        #   dense flops: 6 * 4186275840 * 3584 = 90021675663360
+        #   seqlen_square_sum: 71565312 (calculated with sliding window logic)
+        #   attn flops: 12 * 71565312 * 64 * 64 = 3517578215424
+        #   total: 93539253878784 / 1e12 = 93.539253878784
+        #
+        # For batch [4096, 4096, 4096], tokens_sum = 12288:
+        #   dense flops: 6 * 4186275840 * 12288 = 308645745131520
+        #   seqlen_square_sum: 622854144 (calculated with sliding window logic)
+        #   attn flops: 12 * 622854144 * 64 * 64 = 30614526885888
+        #   total: 339260272017408 / 1e12 = 339.260272017408
+        "expected_flops_tuple": (93539253878784 / 1e12, 339260272017408 / 1e12),
+    },
     "apertus": {
         "config": { # swiss-ai/Apertus-8B
             "model_type": "apertus",
@@ -229,7 +287,7 @@ def __init__(self, config_dict):
 
 @pytest.mark.parametrize(
     "config_type",
-    ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus"],
+    ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus", "gpt_oss"],
 )
 def test_flops_counter(config_type: str):
     test_config = CONFIG[config_type]
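
As a sanity check on the comment arithmetic above, here is a minimal standalone sketch that re-derives both expected totals from nothing but the config values shown in the test. It is not part of the commit, and all names are local to the snippet.

HIDDEN, VOCAB, LAYERS = 2880, 201088, 24
HEADS, KV_HEADS, HEAD_DIM = 64, 8, 64
INTERMEDIATE, EXPERTS, TOP_K, WINDOW = 2880, 32, 4, 128

# Linear (GEMM) parameters counted once per token: Q + K + V + O, router + active experts,
# and the tied embedding / LM head pair.
attn_linear = HIDDEN * (HEADS * HEAD_DIM + 2 * KV_HEADS * HEAD_DIM + HEADS * HEAD_DIM)
moe_mlp = HIDDEN * EXPERTS + 3 * HIDDEN * INTERMEDIATE * TOP_K
dense_params = (attn_linear + moe_mlp) * LAYERS + 2 * VOCAB * HIDDEN  # 4186275840

def expected_tflops(batch_seqlens):
    tokens_sum = sum(batch_seqlens)
    dense_flops = 6 * dense_params * tokens_sum
    # 12 sliding layers cap the attended context at WINDOW tokens; the other 12 attend fully.
    seqlen_square_sum = sum(12 * s * min(s, WINDOW) + 12 * s * s for s in batch_seqlens)
    attn_flops = 12 * seqlen_square_sum * HEAD_DIM * HEADS
    return (dense_flops + attn_flops) / 1e12

print(expected_tflops([512, 1024, 2048]))    # ≈ 93.539253878784
print(expected_tflops([4096, 4096, 4096]))   # ≈ 339.260272017408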

verl/utils/flops_counter.py

Lines changed: 70 additions & 0 deletions
@@ -313,6 +313,75 @@ def _estimate_apertus_flops(config, tokens_sum, batch_seqlens, delta_time):
     flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
     return flops_achieved
 
+def _estimate_gpt_oss_flops(config, tokens_sum, batch_seqlens, delta_time):
+    hidden_size = config.hidden_size
+    vocab_size = config.vocab_size
+    num_hidden_layers = config.num_hidden_layers
+    num_key_value_heads = config.num_key_value_heads
+    num_attention_heads = config.num_attention_heads
+
+    # MoE params
+    moe_intermediate_size = config.intermediate_size
+    num_experts = config.num_local_experts
+    num_experts_per_tok = config.num_experts_per_tok
+    mlp_matrices = 3
+
+    # Head dim
+    head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads)
+    q_size = num_attention_heads * head_dim
+    k_size = num_key_value_heads * head_dim
+    v_size = num_key_value_heads * head_dim
+
+    # 1. Attention Block (GQA)
+    attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+    # 2. MLP / MoE Block
+    # Gate network
+    moe_gate_N = hidden_size * num_experts
+    # Expert forward calculation, Active parameters: mlp_matrices * H * I * num_experts_per_tok
+    moe_expert_N = hidden_size * moe_intermediate_size * mlp_matrices * num_experts_per_tok
+
+    moe_mlp_N = moe_gate_N + moe_expert_N
+
+    emd_and_lm_head_N = vocab_size * hidden_size * 2
+
+    # Total non-attn params per layer * layers + embeddings
+    # (moe_mlp_N + attn_linear_N) * layers
+    dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
+
+    # FLOPs for dense part (fwd + bwd = 6 * N)
+    dense_N_flops = 6 * dense_N * tokens_sum
+
+    # 3. Attention Matrix FLOPs
+    seqlen_square_sum = 0
+
+    # Handle sliding window attention
+    layer_types = getattr(config, "layer_types", None)
+    sliding_window = getattr(config, "sliding_window", 128)
+
+    if layer_types:
+        for layer_type in layer_types:
+            is_sliding = layer_type == "sliding_attention"
+
+            for seqlen in batch_seqlens:
+                if is_sliding and sliding_window:
+                    # Sliding window limits each token to attend to at most window_size tokens
+                    effective_seqlen = min(seqlen, sliding_window)
+                    seqlen_square_sum += seqlen * effective_seqlen
+                else:
+                    # Full attention
+                    seqlen_square_sum += seqlen * seqlen
+    else:
+        # Default to full attention for all layers
+        for seqlen in batch_seqlens:
+            seqlen_square_sum += seqlen * seqlen
+        seqlen_square_sum *= num_hidden_layers
+
+    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads
+
+    # Total FLOPs
+    flops_all_token = dense_N_flops + attn_qkv_flops
+    flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+    return flops_achieved
 
 def _estimate_unknown_flops(config, tokens_sum, batch_seqlens, delta_time):
     return 0
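
The estimator only reads plain attributes off `config`, so it can be exercised in isolation by passing any attribute container in place of the Hugging Face config object. The `SimpleNamespace` stand-in below is illustrative and not part of the commit; with `delta_time=1.0` the return value is simply the total TFLOPs for the batch, which should match the first expected value in the test above.

from types import SimpleNamespace

gpt_oss_cfg = SimpleNamespace(
    hidden_size=2880, vocab_size=201088, num_hidden_layers=24,
    num_attention_heads=64, num_key_value_heads=8, head_dim=64,
    intermediate_size=2880, num_local_experts=32, num_experts_per_tok=4,
    sliding_window=128,
    layer_types=["sliding_attention", "full_attention"] * 12,  # 24 alternating layers
)

batch_seqlens = [512, 1024, 2048]
total_tflops = _estimate_gpt_oss_flops(gpt_oss_cfg, sum(batch_seqlens), batch_seqlens, delta_time=1.0)
print(total_tflops)  # ≈ 93.539253878784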
@@ -336,6 +405,7 @@ def _estimate_unknown_flops(config, tokens_sum, batch_seqlens, delta_time):
     "seed_oss": _estimate_qwen2_flops,
     "apertus": _estimate_apertus_flops,
     "glm4v": _estimate_qwen2_flops,
+    "gpt_oss": _estimate_gpt_oss_flops,
 }
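
With the `gpt_oss` entry dispatching to the new estimator, the per-second value it returns can be turned into an MFU figure by dividing by the hardware's peak throughput. The sketch below reuses `gpt_oss_cfg` from the previous snippet; the timing and the peak-TFLOPS constant are illustrative assumptions, not values from this commit.

batch_seqlens = [4096, 4096, 4096]
delta_time = 2.5  # assumed: seconds this rank spent on the batch
achieved_tflops = _estimate_gpt_oss_flops(gpt_oss_cfg, sum(batch_seqlens), batch_seqlens, delta_time)
# ≈ 339.26 total TFLOPs / 2.5 s ≈ 135.7 achieved TFLOPS on this rank

PEAK_TFLOPS = 989.0  # assumed BF16 dense peak of the accelerator; substitute your hardware's number
mfu = achieved_tflops / PEAK_TFLOPS
print(f"MFU ≈ {mfu:.1%}")  # ≈ 13.7%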
