Skip to content

Commit 7d40100

Browse files
committed
Fix Grug shared-expert flop accounting
1 parent bd70820 commit 7d40100

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

experiments/grug/moe/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def _compute_flops(
185185
flops_per_token = lm_flops_per_token(
186186
hidden_dim=model_config.hidden_dim,
187187
intermediate_dim=model_config.intermediate_dim,
188+
shared_intermediate_dim=model_config.shared_expert_intermediate_dim,
188189
num_layers=model_config.num_layers,
189190
num_kv_heads=model_config.num_kv_heads,
190191
num_heads=model_config.num_heads,

lib/levanter/src/levanter/utils/flop_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@ def lm_flops_per_token(
1414
num_experts: int = 1,
1515
num_shared_experts: int = 0,
1616
num_experts_per_tok: int = 1,
17+
shared_intermediate_dim: int | None = None,
1718
):
1819
head_dim = hidden_dim / num_heads
19-
mlp = 2 * (3 if glu else 2) * hidden_dim * intermediate_dim * (num_experts_per_tok + num_shared_experts)
20+
shared_intermediate_dim = intermediate_dim if shared_intermediate_dim is None else shared_intermediate_dim
21+
routed_mlp = 2 * (3 if glu else 2) * hidden_dim * intermediate_dim * num_experts_per_tok
22+
shared_mlp = 2 * (3 if glu else 2) * hidden_dim * shared_intermediate_dim * num_shared_experts
23+
mlp = routed_mlp + shared_mlp
2024
if num_experts > 1:
2125
mlp += 2 * hidden_dim * num_experts # router layer
2226
qkv_proj = 2 * hidden_dim * (num_heads * head_dim + 2 * num_kv_heads * head_dim)

0 commit comments

Comments (0)