@@ -206,6 +206,64 @@ def __init__(self, config_dict):
         # total: 986195089686528 / 1e12 = 986.195089686528
         "expected_flops_tuple": (283517065887744 / 1e12, 986195089686528 / 1e12),
     },
209+ "gpt_oss" : {
210+ "config" : {
211+ "model_type" : "gpt_oss" ,
212+ "vocab_size" : 201088 ,
213+ "hidden_size" : 2880 ,
214+ "num_hidden_layers" : 24 ,
215+ "num_attention_heads" : 64 ,
216+ "num_key_value_heads" : 8 ,
217+ "head_dim" : 64 ,
218+ "intermediate_size" : 2880 ,
219+ "num_local_experts" : 32 ,
220+ "num_experts_per_tok" : 4 ,
221+ "sliding_window" : 128 ,
222+ "layer_types" : [
223+ "sliding_attention" , "full_attention" , "sliding_attention" , "full_attention" ,
224+ "sliding_attention" , "full_attention" , "sliding_attention" , "full_attention" ,
225+ "sliding_attention" , "full_attention" , "sliding_attention" , "full_attention" ,
226+ "sliding_attention" , "full_attention" , "sliding_attention" , "full_attention" ,
227+ "sliding_attention" , "full_attention" , "sliding_attention" , "full_attention" ,
228+ "sliding_attention" , "full_attention" , "sliding_attention" , "full_attention"
229+ ],
230+ },
231+ "batch_seqlens_tuple" : ([512 , 1024 , 2048 ], [4096 , 4096 , 4096 ]),
+        # GPT-OSS has alternating sliding / full attention:
+        # Even layers (12 layers) use sliding window attention with window_size = 128
+        # Odd layers (12 layers) use full attention
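+        # For sliding layers, the per-sequence pair count is taken here as
+        # min(seqlen, sliding_window) * seqlen (an assumption; it reproduces the sums below)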
+        #
+        # Non-attention FLOPs:
+        # vocab part: 201088 * 2880 * 2 = 1158266880
+        # attn linear part per layer:
+        #   Q: 2880 * (64 * 64) = 11796480
+        #   K: 2880 * (8 * 64) = 1474560
+        #   V: 2880 * (8 * 64) = 1474560
+        #   O: (64 * 64) * 2880 = 11796480
+        #   attn linear total = 26542080
+        # mlp (MoE, SwiGLU) part per layer:
+        #   gate: 2880 * 32 = 92160
+        #   active experts: 3 * 2880 * 2880 * 4 = 99532800
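+        #     (gate = router: hidden_size * num_local_experts; each of the 4 active experts
+        #      has 3 hidden_size x intermediate_size matrices: gate/up/down)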
+        #   mlp total = 99624960
+        # total per layer: 26542080 + 99624960 = 126167040
+        # all layers:
+        #   126167040 * 24 = 3028008960
+        # total dense params:
+        #   3028008960 + 1158266880 = 4186275840
+        #
+        # For batch [512, 1024, 2048], tokens_sum = 3584:
+        #   dense flops: 6 * 4186275840 * 3584 = 90021675663360
+        #   seqlen_square_sum: 71565312 (calculated with sliding window logic)
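+        #     full layers: 12 * (512^2 + 1024^2 + 2048^2) = 12 * 5505024 = 66060288
+        #     sliding layers: 12 * 128 * (512 + 1024 + 2048) = 12 * 458752 = 5505024
+        #     66060288 + 5505024 = 71565312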
+        #   attn flops: 12 * 71565312 * 64 * 64 = 3517578215424
+        #   total: 93539253878784 / 1e12 = 93.539253878784
+        #
+        # For batch [4096, 4096, 4096], tokens_sum = 12288:
+        #   dense flops: 6 * 4186275840 * 12288 = 308645745131520
+        #   seqlen_square_sum: 622854144 (calculated with sliding window logic)
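+        #     full layers: 12 * 3 * 4096^2 = 12 * 50331648 = 603979776
+        #     sliding layers: 12 * 128 * 3 * 4096 = 12 * 1572864 = 18874368
+        #     603979776 + 18874368 = 622854144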
+        #   attn flops: 12 * 622854144 * 64 * 64 = 30614526885888
+        #   total: 339260272017408 / 1e12 = 339.260272017408
+        "expected_flops_tuple": (93539253878784 / 1e12, 339260272017408 / 1e12),
+    },
209267 "apertus" : {
210268 "config" : { # swiss-ai/Apertus-8B
211269 "model_type" : "apertus" ,
@@ -229,7 +287,7 @@ def __init__(self, config_dict):

 @pytest.mark.parametrize(
     "config_type",
-    ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus"],
+    ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus", "gpt_oss"],
 )
 def test_flops_counter(config_type: str):
     test_config = CONFIG[config_type]
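
For reference, a minimal sketch of the arithmetic spelled out in the comments above (not verl's actual implementation; the helper name, and the min(seqlen, sliding_window) * seqlen rule for sliding layers, are assumptions that happen to reproduce the expected values):

def estimate_gpt_oss_flops_tflops(config: dict, batch_seqlens: list) -> float:
    # Hypothetical helper mirroring the comment arithmetic above. Assumes
    # dense FLOPs = 6 * dense_params * tokens_sum and
    # attn FLOPs = 12 * seqlen_square_sum * head_dim * num_attention_heads.
    hidden = config["hidden_size"]
    head_dim = config["head_dim"]
    q_heads = config["num_attention_heads"]
    kv_heads = config["num_key_value_heads"]

    # Per-layer attention projections: Q, K, V, O.
    attn_linear = hidden * head_dim * (q_heads + 2 * kv_heads) + head_dim * q_heads * hidden
    # Per-layer MoE: router plus 3 matrices (gate/up/down) per active expert.
    moe = hidden * config["num_local_experts"] \
        + 3 * hidden * config["intermediate_size"] * config["num_experts_per_tok"]
    # All layers plus embedding and LM head.
    dense_params = (attn_linear + moe) * config["num_hidden_layers"] + 2 * config["vocab_size"] * hidden

    seqlen_square_sum = 0
    for layer_type in config["layer_types"]:
        for s in batch_seqlens:
            if layer_type == "sliding_attention":
                seqlen_square_sum += min(s, config["sliding_window"]) * s  # assumed sliding rule
            else:
                seqlen_square_sum += s * s

    return (6 * dense_params * sum(batch_seqlens) + 12 * seqlen_square_sum * head_dim * q_heads) / 1e12

# estimate_gpt_oss_flops_tflops(CONFIG["gpt_oss"]["config"], [512, 1024, 2048]) -> 93.539253878784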