|
127 | 127 | evaluate_and_print_results, |
128 | 128 | get_model, |
129 | 129 | get_optimizer_param_scheduler, |
130 | | - num_floating_point_operations, |
131 | 130 | post_training_step_callbacks, |
132 | 131 | preprocess_common_state_dict, |
133 | 132 | print_datetime, |
|
161 | 160 |
|
162 | 161 | from .utils import set_wandb_writer_patch |
163 | 162 |
|
| 163 | + |
| 164 | +def num_floating_point_operations(args, batch_size): |
| 165 | + |
| 166 | + def calculate_layer_counts(): |
| 167 | + """Calculate the number of attention, Mamba, and MLP layers.""" |
| 168 | + if args.hybrid_override_pattern: |
| 169 | + counts = {"M": 0, "*": 0, "-": 0} |
| 170 | + for layer_type in args.hybrid_override_pattern: |
| 171 | + if layer_type in counts: |
| 172 | + counts[layer_type] += 1 |
| 173 | + return counts["*"], counts["M"], counts["-"] |
| 174 | + else: |
| 175 | + num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio) |
| 176 | + num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio) |
| 177 | + num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers |
| 178 | + return num_attn_layers, num_mamba_layers, num_mlp_layers |
| 179 | + |
| 180 | + def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): |
| 181 | + """Calculate FLOPs for an MLP layer.""" |
| 182 | + scale_factor = 3.0 / 2.0 if swiglu else 1.0 |
| 183 | + return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 |
| 184 | + |
| 185 | + def attn_layer_flops( |
| 186 | + batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None |
| 187 | + ): |
| 188 | + """Calculate FLOPs for an attention layer.""" |
| 189 | + p = (kv_channels * num_heads / hidden_size) if kv_channels else 1 |
| 190 | + g = gqa_groups if gqa else num_heads |
| 191 | + return ( |
| 192 | + 4 |
| 193 | + * batch_size |
| 194 | + * seq_len |
| 195 | + * hidden_size |
| 196 | + * p |
| 197 | + * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2)) |
| 198 | + ) |
| 199 | + |
| 200 | + def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, head_dim=64, num_groups=1): |
| 201 | + """Calculate FLOPs for a Mamba layer.""" |
| 202 | + # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels, |
| 203 | + # but small percent of overall layer flops |
| 204 | + d_in = 2 * hidden_size |
| 205 | + nheads = d_in // head_dim |
| 206 | + return ( |
| 207 | + ( |
| 208 | + 2 * batch_size * seq_len * hidden_size * (2 * d_in + 2 * num_groups * state_dim + nheads) |
| 209 | + ) # in_proj |
| 210 | + + (7 * batch_size * seq_len * d_in * state_dim) # scan |
| 211 | + + (2 * batch_size * seq_len * d_in * hidden_size) # out_proj |
| 212 | + ) |
| 213 | + |
| 214 | + def hybrid_flops( |
| 215 | + batch_size, |
| 216 | + seq_len, |
| 217 | + hidden_size, |
| 218 | + num_attn_layers, |
| 219 | + num_mamba_layers, |
| 220 | + num_mlp_layers, |
| 221 | + mamba_state_dim=128, |
| 222 | + mamba_head_dim=64, |
| 223 | + mamba_num_groups=8, |
| 224 | + num_attn_heads=32, |
| 225 | + gqa=True, |
| 226 | + gqa_groups=8, |
| 227 | + kv_channels=None, |
| 228 | + mlp_expansion=4.0, |
| 229 | + swiglu=False, |
| 230 | + vocab_size=256000, |
| 231 | + ): |
| 232 | + """Calculate total FLOPs for the hybrid model.""" |
| 233 | + flops_fwd = ( |
| 234 | + num_attn_layers |
| 235 | + * attn_layer_flops(batch_size, seq_len, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels) |
| 236 | + + num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, mlp_expansion, swiglu) |
| 237 | + + num_mamba_layers |
| 238 | + * mamba_layer_flops( |
| 239 | + batch_size, seq_len, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups |
| 240 | + ) |
| 241 | + + (2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation |
| 242 | + ) |
| 243 | + return flops_fwd * 3 |
| 244 | + |
| 245 | + def transformer_flops(): |
| 246 | + """Calculate FLOPs for a standard Transformer model.""" |
| 247 | + # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods. |
| 248 | + # Attention projection size. |
| 249 | + query_projection_size = args.kv_channels * args.num_attention_heads |
| 250 | + query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size |
| 251 | + # Group Query Attention. |
| 252 | + if not args.group_query_attention: |
| 253 | + args.num_query_groups = args.num_attention_heads |
| 254 | + # MoE. |
| 255 | + if args.num_experts is None: |
| 256 | + # Every Transformer MLP is dense. |
| 257 | + num_dense_layers = args.num_layers |
| 258 | + num_moe_layers = 0 |
| 259 | + num_experts_routed_to = 0 |
| 260 | + else: |
| 261 | + # Calculate number of dense and MoE Transformer MLPs. |
| 262 | + if isinstance(args.moe_layer_freq, int): |
| 263 | + moe_layer_pattern = [ |
| 264 | + 1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers) |
| 265 | + ] |
| 266 | + elif isinstance(args.moe_layer_freq, list): |
| 267 | + moe_layer_pattern = args.moe_layer_freq |
| 268 | + else: |
| 269 | + raise RuntimeError("Illegal --moe-layer-freq argument provided!") |
| 270 | + assert len(moe_layer_pattern) == args.num_layers |
| 271 | + num_moe_layers = sum(moe_layer_pattern) # Number of 1s in `moe_layer_pattern`. |
| 272 | + num_dense_layers = args.num_layers - num_moe_layers |
| 273 | + num_experts_routed_to = args.moe_router_topk |
| 274 | + |
| 275 | + moe_ffn_hidden_size = ( |
| 276 | + args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None else args.ffn_hidden_size |
| 277 | + ) |
| 278 | + shared_expert_ffn_hidden_size = ( |
| 279 | + 0 |
| 280 | + if args.moe_shared_expert_intermediate_size is None |
| 281 | + else args.moe_shared_expert_intermediate_size |
| 282 | + ) |
| 283 | + # SwiGLU. |
| 284 | + gated_linear_multiplier = 3 / 2 if args.swiglu else 1 |
| 285 | + |
| 286 | + # The 12x term below comes from the following factors; for more details, see |
| 287 | + # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473. |
| 288 | + # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass, |
| 289 | + # backward wgrad [weight gradient], backward dgrad [data gradient]). |
| 290 | + # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model |
| 291 | + # architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM |
| 292 | + # in MLP layer). |
| 293 | + # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations. |
| 294 | + expansion_factor = 3 * 2 * 2 |
| 295 | + |
| 296 | + return ( |
| 297 | + expansion_factor |
| 298 | + * batch_size |
| 299 | + * args.seq_length |
| 300 | + * args.num_layers |
| 301 | + * args.hidden_size |
| 302 | + * args.hidden_size |
| 303 | + * ( |
| 304 | + # Attention. |
| 305 | + ( |
| 306 | + ( |
| 307 | + 1 |
| 308 | + + (args.num_query_groups / args.num_attention_heads) |
| 309 | + # Only half of the attention matrix is non-zero and needs to be multiplied with V. |
| 310 | + + (args.seq_length / args.hidden_size) |
| 311 | + ) |
| 312 | + * query_projection_to_hidden_size_ratio |
| 313 | + ) |
| 314 | + # MLP. |
| 315 | + + ( |
| 316 | + ( |
| 317 | + # Dense. |
| 318 | + (args.ffn_hidden_size * num_dense_layers) |
| 319 | + + |
| 320 | + # MoE. |
| 321 | + ( |
| 322 | + ( |
| 323 | + # Routed experts. |
| 324 | + moe_ffn_hidden_size * num_experts_routed_to |
| 325 | + + |
| 326 | + # Shared experts. |
| 327 | + shared_expert_ffn_hidden_size |
| 328 | + ) |
| 329 | + * num_moe_layers |
| 330 | + ) |
| 331 | + ) |
| 332 | + * gated_linear_multiplier |
| 333 | + / (args.num_layers * args.hidden_size) |
| 334 | + ) |
| 335 | + # Logit. |
| 336 | + + (args.padded_vocab_size / (2 * args.num_layers * args.hidden_size)) |
| 337 | + ) |
| 338 | + ) |
| 339 | + |
| 340 | + # Main entrypoint for FLOPs calculation. |
| 341 | + if args.is_hybrid_model: |
| 342 | + # Calculate the number of each type of layer. |
| 343 | + num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts() |
| 344 | + |
| 345 | + # Compute hybrid model FLOPs. |
| 346 | + return hybrid_flops( |
| 347 | + batch_size=batch_size, |
| 348 | + seq_len=args.seq_length, |
| 349 | + hidden_size=args.hidden_size, |
| 350 | + num_attn_layers=num_attn_layers, |
| 351 | + num_mamba_layers=num_mamba_layers, |
| 352 | + num_mlp_layers=num_mlp_layers, |
| 353 | + mamba_state_dim=args.mamba_state_dim, |
| 354 | + mamba_head_dim=args.mamba_head_dim, |
| 355 | + mamba_num_groups=args.mamba_num_groups, |
| 356 | + num_attn_heads=args.num_attention_heads, |
| 357 | + gqa=args.group_query_attention, |
| 358 | + gqa_groups=args.num_query_groups, |
| 359 | + kv_channels=args.kv_channels, |
| 360 | + mlp_expansion=args.ffn_hidden_size / args.hidden_size, |
| 361 | + swiglu=args.swiglu, |
| 362 | + vocab_size=args.padded_vocab_size, |
| 363 | + ) |
| 364 | + else: |
| 365 | + # Compute standard Transformer model FLOPs. |
| 366 | + return transformer_flops() |
| 367 | + |
| 368 | + |
164 | 369 | # The earliest we can measure the start time. |
165 | 370 | _TRAIN_START_TIME = time.time() |
166 | 371 |
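
For the hybrid branch, `calculate_layer_counts()` reads the layer mix straight off `--hybrid-override-pattern`, one character per layer: `*` is counted as an attention layer, `M` as a Mamba layer, and `-` as an MLP layer. A quick standalone check of that counting logic, using a made-up pattern that is not from this commit:

```python
# Worked example of the override-pattern counting in calculate_layer_counts().
# The pattern below is hypothetical and only illustrates the symbol mapping.
pattern = "M*M-M*M-"
counts = {"M": 0, "*": 0, "-": 0}
for layer_type in pattern:
    if layer_type in counts:
        counts[layer_type] += 1
print(counts["*"], counts["M"], counts["-"])  # 2 attention, 4 Mamba, 2 MLP
```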
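A minimal usage sketch for the dense Transformer path, assuming the new helper is importable from this training module. The configuration values (roughly Llama-2-7B-like), the batch size, the iteration time, and the GPU count are placeholders chosen for illustration, not values taken from this commit:

```python
# Hypothetical end-to-end check of num_floating_point_operations() on the
# dense (non-hybrid, non-MoE) path, then conversion to TFLOP/s per GPU.
from argparse import Namespace

args = Namespace(
    is_hybrid_model=False,
    num_layers=32,
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    kv_channels=128,
    group_query_attention=False,
    num_query_groups=32,            # overwritten by the helper when GQA is off
    num_experts=None,               # dense MLPs only, no MoE layers
    moe_ffn_hidden_size=None,
    moe_shared_expert_intermediate_size=None,
    swiglu=True,
    seq_length=4096,
    padded_vocab_size=32000,
)

flops_per_iteration = num_floating_point_operations(args, batch_size=128)

# Placeholder timing/bookkeeping for the throughput conversion.
iteration_time_s, num_gpus = 10.0, 64
tflops_per_gpu = flops_per_iteration / (iteration_time_s * num_gpus * 1e12)
print(f"~{tflops_per_gpu:.1f} TFLOP/s/GPU")
```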