
Commit 2386abd

[Feat] Align TFLOPs calculation (#28)
1 parent 1812c5d commit 2386abd

File tree

1 file changed: +206 -1

primus/modules/trainer/megatron/trainer.py

Lines changed: 206 additions & 1 deletion
@@ -127,7 +127,6 @@
     evaluate_and_print_results,
     get_model,
     get_optimizer_param_scheduler,
-    num_floating_point_operations,
     post_training_step_callbacks,
     preprocess_common_state_dict,
     print_datetime,
@@ -161,6 +160,212 @@
 
 from .utils import set_wandb_writer_patch
 
+
+def num_floating_point_operations(args, batch_size):
+
+    def calculate_layer_counts():
+        """Calculate the number of attention, Mamba, and MLP layers."""
+        if args.hybrid_override_pattern:
+            counts = {"M": 0, "*": 0, "-": 0}
+            for layer_type in args.hybrid_override_pattern:
+                if layer_type in counts:
+                    counts[layer_type] += 1
+            return counts["*"], counts["M"], counts["-"]
+        else:
+            num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio)
+            num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio)
+            num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers
+            return num_attn_layers, num_mamba_layers, num_mlp_layers
+
+    def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
+        """Calculate FLOPs for an MLP layer."""
+        scale_factor = 3.0 / 2.0 if swiglu else 1.0
+        return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2
+
+    def attn_layer_flops(
+        batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None
+    ):
+        """Calculate FLOPs for an attention layer."""
+        p = (kv_channels * num_heads / hidden_size) if kv_channels else 1
+        g = gqa_groups if gqa else num_heads
+        return (
+            4
+            * batch_size
+            * seq_len
+            * hidden_size
+            * p
+            * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2))
+        )
+
+    def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, head_dim=64, num_groups=1):
+        """Calculate FLOPs for a Mamba layer."""
+        # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels,
+        # but small percent of overall layer flops
+        d_in = 2 * hidden_size
+        nheads = d_in // head_dim
+        return (
+            (
+                2 * batch_size * seq_len * hidden_size * (2 * d_in + 2 * num_groups * state_dim + nheads)
+            )  # in_proj
+            + (7 * batch_size * seq_len * d_in * state_dim)  # scan
+            + (2 * batch_size * seq_len * d_in * hidden_size)  # out_proj
+        )
+
+    def hybrid_flops(
+        batch_size,
+        seq_len,
+        hidden_size,
+        num_attn_layers,
+        num_mamba_layers,
+        num_mlp_layers,
+        mamba_state_dim=128,
+        mamba_head_dim=64,
+        mamba_num_groups=8,
+        num_attn_heads=32,
+        gqa=True,
+        gqa_groups=8,
+        kv_channels=None,
+        mlp_expansion=4.0,
+        swiglu=False,
+        vocab_size=256000,
+    ):
+        """Calculate total FLOPs for the hybrid model."""
+        flops_fwd = (
+            num_attn_layers
+            * attn_layer_flops(batch_size, seq_len, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels)
+            + num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, mlp_expansion, swiglu)
+            + num_mamba_layers
+            * mamba_layer_flops(
+                batch_size, seq_len, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups
+            )
+            + (2 * batch_size * seq_len * hidden_size * vocab_size)  # logits computation
+        )
+        return flops_fwd * 3
+
+    def transformer_flops():
+        """Calculate FLOPs for a standard Transformer model."""
+        # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods.
+        # Attention projection size.
+        query_projection_size = args.kv_channels * args.num_attention_heads
+        query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size
+        # Group Query Attention.
+        if not args.group_query_attention:
+            args.num_query_groups = args.num_attention_heads
+        # MoE.
+        if args.num_experts is None:
+            # Every Transformer MLP is dense.
+            num_dense_layers = args.num_layers
+            num_moe_layers = 0
+            num_experts_routed_to = 0
+        else:
+            # Calculate number of dense and MoE Transformer MLPs.
+            if isinstance(args.moe_layer_freq, int):
+                moe_layer_pattern = [
+                    1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers)
+                ]
+            elif isinstance(args.moe_layer_freq, list):
+                moe_layer_pattern = args.moe_layer_freq
+            else:
+                raise RuntimeError("Illegal --moe-layer-freq argument provided!")
+            assert len(moe_layer_pattern) == args.num_layers
+            num_moe_layers = sum(moe_layer_pattern)  # Number of 1s in `moe_layer_pattern`.
+            num_dense_layers = args.num_layers - num_moe_layers
+            num_experts_routed_to = args.moe_router_topk
+
+        moe_ffn_hidden_size = (
+            args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None else args.ffn_hidden_size
+        )
+        shared_expert_ffn_hidden_size = (
+            0
+            if args.moe_shared_expert_intermediate_size is None
+            else args.moe_shared_expert_intermediate_size
+        )
+        # SwiGLU.
+        gated_linear_multiplier = 3 / 2 if args.swiglu else 1
+
+        # The 12x term below comes from the following factors; for more details, see
+        # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
+        # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
+        #       backward wgrad [weight gradient], backward dgrad [data gradient]).
+        # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model
+        #       architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM
+        #       in MLP layer).
+        # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
+        expansion_factor = 3 * 2 * 2
+
+        return (
+            expansion_factor
+            * batch_size
+            * args.seq_length
+            * args.num_layers
+            * args.hidden_size
+            * args.hidden_size
+            * (
+                # Attention.
+                (
+                    (
+                        1
+                        + (args.num_query_groups / args.num_attention_heads)
+                        # Only half of the attention matrix is non-zero and needs to be multiplied with V.
+                        + (args.seq_length / args.hidden_size)
+                    )
+                    * query_projection_to_hidden_size_ratio
+                )
+                # MLP.
+                + (
+                    (
+                        # Dense.
+                        (args.ffn_hidden_size * num_dense_layers)
+                        +
+                        # MoE.
+                        (
+                            (
+                                # Routed experts.
+                                moe_ffn_hidden_size * num_experts_routed_to
+                                +
+                                # Shared experts.
+                                shared_expert_ffn_hidden_size
+                            )
+                            * num_moe_layers
+                        )
+                    )
+                    * gated_linear_multiplier
+                    / (args.num_layers * args.hidden_size)
+                )
+                # Logit.
+                + (args.padded_vocab_size / (2 * args.num_layers * args.hidden_size))
+            )
+        )
+
+    # Main entrypoint for FLOPs calculation.
+    if args.is_hybrid_model:
+        # Calculate the number of each type of layer.
+        num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()
+
+        # Compute hybrid model FLOPs.
+        return hybrid_flops(
+            batch_size=batch_size,
+            seq_len=args.seq_length,
+            hidden_size=args.hidden_size,
+            num_attn_layers=num_attn_layers,
+            num_mamba_layers=num_mamba_layers,
+            num_mlp_layers=num_mlp_layers,
+            mamba_state_dim=args.mamba_state_dim,
+            mamba_head_dim=args.mamba_head_dim,
+            mamba_num_groups=args.mamba_num_groups,
+            num_attn_heads=args.num_attention_heads,
+            gqa=args.group_query_attention,
+            gqa_groups=args.num_query_groups,
+            kv_channels=args.kv_channels,
+            mlp_expansion=args.ffn_hidden_size / args.hidden_size,
+            swiglu=args.swiglu,
+            vocab_size=args.padded_vocab_size,
+        )
+    else:
+        # Compute standard Transformer model FLOPs.
+        return transformer_flops()
+
+
 # The earliest we can measure the start time.
 _TRAIN_START_TIME = time.time()
 
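For reference, the value returned by num_floating_point_operations(args, batch_size) is a FLOPs count for one full training iteration (forward plus backward, hence the 3x factors in the code above). Below is a minimal sketch, not part of this commit, of how such a per-iteration figure is commonly converted into the TFLOP/s-per-GPU number reported in training logs; the helper name and the iteration_time_s / world_size parameters are illustrative assumptions, not identifiers from this diff.

# Illustrative sketch only: convert a per-iteration FLOPs count into model TFLOP/s per GPU.
def throughput_tflops_per_gpu(flops_per_iteration, iteration_time_s, world_size):
    """Divide per-iteration FLOPs by elapsed time, GPU count, and 1e12."""
    return flops_per_iteration / (iteration_time_s * world_size * 1e12)

# Example: ~2.8e15 FLOPs per iteration, 1.9 s per iteration, 64 GPUs -> roughly 23 TFLOP/s per GPU.
print(throughput_tflops_per_gpu(2.8e15, 1.9, 64))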
0 commit comments