File tree Expand file tree Collapse file tree
lib/levanter/src/levanter/utils Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -185,6 +185,7 @@ def _compute_flops(
185 185       flops_per_token = lm_flops_per_token(
186 186           hidden_dim=model_config.hidden_dim,
187 187           intermediate_dim=model_config.intermediate_dim,
    188 +         shared_intermediate_dim=model_config.shared_expert_intermediate_dim,
188 189           num_layers=model_config.num_layers,
189 190           num_kv_heads=model_config.num_kv_heads,
190 191           num_heads=model_config.num_heads,
Original file line number Diff line number Diff line change @@ -14,9 +14,13 @@ def lm_flops_per_token(
 14  14       num_experts: int = 1,
 15  15       num_shared_experts: int = 0,
 16  16       num_experts_per_tok: int = 1,
     17 +     shared_intermediate_dim: int | None = None,
 17  18   ):
 18  19       head_dim = hidden_dim / num_heads
 19     -     mlp = 2 * (3 if glu else 2) * hidden_dim * intermediate_dim * (num_experts_per_tok + num_shared_experts)
     20 +     shared_intermediate_dim = intermediate_dim if shared_intermediate_dim is None else shared_intermediate_dim
     21 +     routed_mlp = 2 * (3 if glu else 2) * hidden_dim * intermediate_dim * num_experts_per_tok
     22 +     shared_mlp = 2 * (3 if glu else 2) * hidden_dim * shared_intermediate_dim * num_shared_experts
     23 +     mlp = routed_mlp + shared_mlp
 20  24       if num_experts > 1:
 21  25           mlp += 2 * hidden_dim * num_experts  # router layer
 22  26       qkv_proj = 2 * hidden_dim * (num_heads * head_dim + 2 * num_kv_heads * head_dim)
You can’t perform that action at this time.
0 commit comments