How to scale your model
- Part 0: Intro
```jsx
Much of deep learning still boils down to a kind of black magic, but optimizing the performance of your models doesn't have to be, even at huge scale! Relatively simple principles apply everywhere, from a single accelerator to tens of thousands, and understanding them lets you do many useful things:
- Ballpark how close parts of your model are to their theoretical optimum.
- Make informed choices about different parallelism schemes at different scales (how you split the computation across multiple devices).
- Estimate the cost and time required to train and run large Transformer models.
- Design algorithms that take advantage of specific hardware affordances.
- Design hardware driven by an explicit understanding of what limits current algorithm performance.
Expected background: We’re going to assume you have a basic understanding of LLMs and the Transformer architecture but not necessarily how they operate at scale. You should know the basics of LLM training and ideally have some basic familiarity with JAX. Some useful background reading might include this blog post on the Transformer architecture and the original Transformer paper. Also check this list out for more useful concurrent and future reading.
Goals & Feedback: By the end, you should feel comfortable estimating the best parallelism scheme for a Transformer model on a given hardware platform, and roughly how long training and inference should take. If you don’t, email us or leave a comment! We’d love to know how we could make this clearer.
Why should you care?
Three or four years ago, I don’t think most ML researchers would have needed to understand any of the content in this book. But today even “small” models run so close to hardware limits that doing novel research requires you to think about efficiency at scale.${ }^1$ A 20% win on benchmarks is irrelevant if it comes at a 20% cost to roofline efficiency. Promising model architectures routinely fail either because they can’t run efficiently at scale or because no one puts in the work to make them do so.
The goal of “model scaling” is to be able to increase the number of chips used for training or inference while achieving a proportional, linear increase in throughput. This is known as “strong scaling”. Although adding chips (“parallelism”) usually decreases the computation time, it also adds communication between chips. When communication takes longer than computation, we become “communication-bound” and cannot scale strongly.${ }^2$ If we understand our hardware well enough to anticipate where these bottlenecks will arise, we can design or reconfigure our models to avoid them.${ }^3$
Our goal in this book is to explain how TPU (and GPU) hardware works and how the Transformer architecture has evolved to perform well on current hardware. We hope this will be useful both for researchers designing new architectures and for engineers working to make the current generation of LLMs run fast.
High-Level Outline
The overall structure of this book is as follows:
Section 1 explains roofline analysis and what factors can limit our ability to scale (communication, computation, and memory). Section 2 and Section 3 talk in detail about how TPUs and modern GPUs work, both as individual chips and — of critical importance — as an interconnected system with inter-chip links of limited bandwidth and latency. We’ll answer questions like:
How long should a matrix multiply of a certain size take? At what point is it bound by compute or by memory or communication bandwidth?
How are TPUs wired together to form training clusters? How much bandwidth does each part of the system have?
How long does it take to gather, scatter, or re-distribute arrays across multiple TPUs?
How do we efficiently multiply matrices that are distributed differently across devices?
Figure: a diagram from Section 2 showing how a TPU performs an elementwise product. Depending on the size of our arrays and the bandwidth of various links, we can find ourselves compute-bound (using the full hardware compute capacity) or comms-bound (bottlenecked by memory loading).
Five years ago ML had a colorful landscape of architectures (ConvNets, LSTMs, MLPs, Transformers), but now we mostly just have the Transformer [1]. We strongly believe it’s worth understanding every piece of the Transformer architecture: the exact sizes of every matrix, where normalization occurs, how many parameters and FLOPs${ }^4$ are in each part. Section 4 goes through this “Transformer math” carefully, showing how to count the parameters and FLOPs for both training and inference. This tells us how much memory our model will use, how much time we’ll spend on compute or comms, and when attention will become important relative to the feed-forward blocks.
Figure: a standard Transformer layer with each matrix multiplication (matmul) shown as a dot inside a circle. All parameters (excluding norms) are shown in purple. Section 4 walks through this diagram in more detail.
Section 5: Training and Section 7: Inference are the core of this essay, where we discuss the fundamental question: given a model of some size and some number of chips, how do I parallelize my model to stay in the “strong scaling” regime? This is a simple question with a surprisingly complicated answer. At a high level, there are 4 primary parallelism techniques used to split models over multiple chips (data, tensor, pipeline and expert), and a number of other techniques to reduce the memory requirements (rematerialisation, optimizer/model sharding (aka ZeRO), host offload, gradient accumulation). We discuss many of these here.
We hope that by the end of these sections you will be able to choose among them yourself for new architectures or settings. Section 6 and Section 8 are practical tutorials that apply these concepts to LLaMA-3, a popular open-source model.
Finally, Section 9 and Section 10 look at how to implement some of these ideas in JAX and how to profile and debug your code when things go wrong.
Throughout, we try to give you problems to work through yourself. Please feel no pressure to read all the sections or read them in order. And please leave feedback. For the time being, this is a draft and will continue to be revised. Thank you!
We’d like to acknowledge James Bradbury and Blake Hechtman who derived many of the ideas in this doc.
Without further ado, here is Section 1 about TPU rooflines.
Links to Sections
This series is probably longer than it needs to be, but we hope that won’t deter you. The first three chapters are preliminaries and can be skipped if familiar, although they introduce notation used later. The final three parts might be the most practically useful, since they explain how to work with real models.
Part 1: Preliminaries
Chapter 1: A Brief Intro to Roofline Analysis. Algorithms are bounded by three things: compute, communication, and memory. We can use these to approximate how fast our algorithms will run.
Chapter 2: How to Think About TPUs. How do TPUs work? How does that affect what models we can train and serve?
Chapter 3: Sharded Matrices and How to Multiply Them. Here we explain model sharding and multi-TPU parallelism by way of our favorite operation: (sharded) matrix multiplications.
Part 2: Transformers
Chapter 4: All the Transformer Math You Need to Know. How many FLOPs does a Transformer use in its forward and backwards pass? Can you calculate the number of parameters? The size of its KV caches? We work through this math here.
Chapter 5: How to Parallelize a Transformer for Training. FSDP. Megatron sharding. Pipeline parallelism. Given some number of chips, how do I train a model of a given size with a given batch size as efficiently as possible?
Chapter 6: Training LLaMA 3 on TPUs. How would we train LLaMA 3 on TPUs? How long would it take? How much would it cost?
Chapter 7: All About Transformer Inference. Once we’ve trained a model, we have to serve it. Inference adds a new consideration — latency — and changes up the memory landscape. We’ll talk about how disaggregated serving works and how to think about KV caches.
Chapter 8: Serving LLaMA 3 on TPUs. How much would it cost to serve LLaMA 3 on TPU v5e? What are the latency/throughput tradeoffs?
Part 3: Practical Tutorials
Chapter 9: How to Profile TPU Code. Real LLMs are never as simple as the theory above. Here we explain the JAX + XLA stack and how to use the JAX/TensorBoard profiler to debug and fix real issues.
Chapter 10: Programming TPUs in JAX. JAX provides a bunch of magical APIs for parallelizing computation, but you need to know how to use them. Fun examples and worked problems.
Chapter 11: Conclusions and Further Reading. Closing thoughts and further reading on TPUs and LLMs.
```
- Part 1: **All About Rooflines**
```jsx
When we run algorithms on hardware, we're bounded by three things: how fast our computer can do math (OPs/second), the bandwidth available for moving data around (bytes/second), and the total memory available to store data (bytes). These “roofline” constraints let us upper and lower bound the time of a given computation.
Where Does the Time Go?
Let's start with an extremely simple question: why does an algorithm take 50 ms instead of 50 s or 5 ms? What is actually happening within the model that takes substantial time and how long should we expect it to take?
Computation: A deep learning model is effectively a bunch of matrix multiplications, each composed of floating-point multiplication and addition 'operations' (FLOPs). Our accelerator speed determines how long these take to compute:
$$
T_{\text {math }}=\frac{\text { Computation FLOPs }}{\text { Accelerator FLOPs } / \mathrm{s}}
$$
For instance, an NVIDIA H100 can perform about 9.89e14 bfloat16${ }^1$ FLOPs/s while a TPU v6e can perform 9.1e14 FLOPs/s. That means doing 1e12 FLOPs on an H100 will take (roughly) $1\mathrm{e}12 / 9.89\mathrm{e}14 = 1.01\ \mathrm{ms}$ and $1\mathrm{e}12 / 9.1\mathrm{e}14 = 1.1\ \mathrm{ms}$ on a TPU v6e.${ }^2$
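// A minimal sketch (not from the original text) of the T_math estimate above,
// using the peak bfloat16 numbers quoted in this section. `tMathSeconds` is a
// hypothetical helper name used only for illustration.
function tMathSeconds(flops, acceleratorFlopsPerSec) {
  // T_math = Computation FLOPs / Accelerator FLOPs/s
  return flops / acceleratorFlopsPerSec;
}
const H100_BF16_FLOPS = 9.89e14;    // bfloat16 FLOPs/s quoted above
const TPU_V6E_BF16_FLOPS = 9.1e14;  // bfloat16 FLOPs/s quoted above
for (const [name, peak] of [["H100", H100_BF16_FLOPS], ["TPU v6e", TPU_V6E_BF16_FLOPS]]) {
  const ms = tMathSeconds(1e12, peak) * 1e3;
  console.log(`${name}: ${ms.toFixed(2)} ms for 1e12 FLOPs`);
}
// Prints "H100: 1.01 ms for 1e12 FLOPs" and "TPU v6e: 1.10 ms for 1e12 FLOPs".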
Communication within a chip: Within an accelerator, tensors need to be transferred between on-chip memory (HBM) and the compute cores. You'll see the bandwidth of this link referred to as "HBM bandwidth".${ }^3$ On an H100, this is about 3.35 TB/s and on TPU v6e this is about 1.6 TB/s.
Communication between chips: When we distribute a model across multiple accelerators, tensors frequently need to be transferred between them. There are often a few options for this on our hardware (ICI, DCN, and PCIe), each with different bandwidths.
Whether the communication is within a chip or between chips, we measure this in bytes/s and estimate the total communication time with:
$$
T_{\text {comms }}=\frac{\text { Communication Bytes }}{\text { Network/Memory Bandwidth Bytes/s }}
$$
Typically (but not always), computation within a single chip can be overlapped with communication within a chip and between chips. This means we can lower-bound training and inference time by using the maximum of computation and communication time. We can also upper-bound with their sum. In practice, we optimize against the maximum as the algebra is simpler and we can usually come close to this bound by overlapping our communication and computation. If we optimize with the maximum in mind then the lower and upper bounds differ by at most a factor of 2 since $T_{\text {math }}+T_{\text {comms }} \leq 2 * \max \left(T_{\text {math }}, T_{\text {comms }}\right)$. We then increase accuracy beyond this by modeling 'overlap regions' and overheads, which can be informed by profiling your specific model and target system.
$$
\begin{gathered}
T_{\text {lower }}=\max \left(T_{\text {math }}, T_{\text {comms }}\right) \\
T_{\text {upper }}=T_{\text {math }}+T_{\text {comms }}
\end{gathered}
$$
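// A minimal sketch of the two bounds above. `runtimeBounds` is an illustrative
// name, and the 1e9 bytes of HBM traffic is an assumed input. The hardware
// numbers are the TPU v6e figures quoted earlier: 9.1e14 FLOPs/s and ~1.6e12 bytes/s.
function runtimeBounds(flops, bytes, flopsPerSec, bytesPerSec) {
  const tMath = flops / flopsPerSec;
  const tComms = bytes / bytesPerSec;
  return {
    tMath,
    tComms,
    lower: Math.max(tMath, tComms), // communication fully overlapped with compute
    upper: tMath + tComms,          // no overlap at all
  };
}
const est = runtimeBounds(1e12, 1e9, 9.1e14, 1.6e12);
console.log(`lower ${(est.lower * 1e3).toFixed(2)} ms, upper ${(est.upper * 1e3).toFixed(2)} ms`);
// Here lower is T_math (~1.10 ms) and upper adds T_comms (0.625 ms), so ~1.72 ms.
// By construction, upper is never more than 2x lower.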
If we assume we can perfectly overlap communication and computation, when $T_{\text {math }}>T_{\text {comms }}$, we see full utilization from our hardware. We call this being "compute-bound". When $T_{\text {comms }}>T_{\text {math }}$, we tend to be "communication-bound" and at least some fraction of our accelerator FLOPs/s is wasted waiting for data to be passed around. One way to tell if an operation will be compute-bound or communication-bound is to look at its "arithmetic intensity" or "operational intensity".
Definition: the arithmetic intensity of an algorithm is given by the ratio of the total FLOPs it performs to the number of bytes it needs to communicate - either within a chip or between chips.
$$
\text { Arithmetic Intensity }=\frac{\text { Computation FLOPs }}{\text { Communication Bytes }}
$$
Arithmetic intensity measures the "FLOPs per byte" of a given operation. To a first order, when our arithmetic intensity is high, $T_{\text {math }}$ is large compared to $T_{\text {comms }}$ and we typically use most of the available FLOPs. When the opposite is true, we spend more time on comms and waste FLOPs. The point where this crossover happens is the "peak arithmetic intensity" of our hardware, the ratio of peak accelerator FLOPs/s to accelerator bandwidth.
$$
\begin{aligned}
T_{\text {math }}>T_{\text {comms }} & \Leftrightarrow \frac{\text { Computation FLOPs }}{\text { Accelerator FLOPs } / \mathrm{s}}>\frac{\text { Communication Bytes }}{\text { Bandwidth Bytes } / \mathrm{s}} \\
& \Leftrightarrow \frac{\text { Computation FLOPs }}{\text { Communication Bytes }}>\frac{\text { Accelerator FLOPs } / \mathrm{s}}{\text { Bandwidth Bytes } / \mathrm{s}} \\
& \Leftrightarrow \text { Intensity(Computation) }>\text { Intensity(Accelerator) }
\end{aligned}
$$
The quantity Intensity(Accelerator) is the arithmetic intensity at which our accelerator achieves its peak FLOPs/s. For the TPU v5e MXU, this is about 240 FLOPs/byte${ }^4$, since the TPU can perform 1.97e14 FLOPs/s and load 8.2e11 bytes/s from HBM. That means if an algorithm has a lower arithmetic intensity than 240${ }^5$ FLOPs/byte, it will be bound by byte loading and thus we won't make good use of our hardware. Let's look at one such example:
Example (dot product): to compute the dot product of two vectors in bfloat16 precision, $x \cdot y: \mathrm{bf16}[N], \mathrm{bf16}[N] \rightarrow \mathrm{bf16}[1]$, we need to load $x$ and $y$ from memory, each of which has $2 \cdot N = 2N$ bytes, perform $N$ multiplications and $N-1$ additions, and write 2 bytes back into HBM:
$$
\text { Intensity }(\text { dot product })=\frac{\text { Total FLOPs }}{\text { Total Bytes }}=\frac{N+N-1}{2 N+2 N+2}=\frac{2 N-1}{4 N+2} \rightarrow \frac{1}{2}
$$
as $N \rightarrow \infty$. So the dot product has an arithmetic intensity of $\frac{1}{2}$ or, put another way, the dot product does 0.5 floating point operations per byte loaded. This means our arithmetic intensity is lower than that of our hardware and we will be communication-bound. ${ }^6$
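// A minimal sketch (illustrative names, not from the original) of the
// dot-product intensity above, compared against the TPU v5e figure of
// 1.97e14 FLOPs/s over 8.2e11 bytes/s of HBM bandwidth.
function dotProductIntensity(n) {
  const flops = n + (n - 1);       // N multiplies and N - 1 adds
  const bytes = 2 * n + 2 * n + 2; // two bf16[N] inputs plus one bf16 output
  return flops / bytes;
}
const v5eCriticalIntensity = 1.97e14 / 8.2e11; // ~240 FLOPs/byte
for (const n of [1e3, 1e6, 1e9]) {
  const intensity = dotProductIntensity(n);
  const bound = intensity > v5eCriticalIntensity ? "compute-bound" : "communication-bound";
  console.log(n, intensity.toFixed(4), bound);
}
// The intensity approaches 0.5 as N grows, far below ~240, so the dot product
// is communication-bound at every size.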
Visualizing rooflines
We can visualize the tradeoff between memory and compute using a roofline plot, which plots the peak achievable FLOPs/s (throughput) of an algorithm on our hardware (the y-axis) against the arithmetic intensity of that algorithm (the x-axis). Here’s an example log-log plot:
Figure: an example roofline plot showing two algorithms with different arithmetic intensities (Algo 1 and Algo 2) and their corresponding theoretical peak throughput under different bandwidths (BW1 and BW2). In the red area, an algorithm is bandwidth-bound at both bandwidths and is wasting some fraction of the hardware's peak FLOPs/s. The yellow area is bandwidth-bound only at the lower bandwidth (BW1). The green area is compute-bound at all bandwidths. Here, we are using the peak FLOPs/s of the accelerator, and increasing bandwidth or improving intensity yields no benefit.
Above, as the intensity increases (moving left to right), we initially see a linear increase in the performance of our algorithm (in FLOPs/s) until we hit the critical arithmetic intensity of the hardware, 240 in the case of the TPU v5e. Any algorithm with a lower intensity will be bandwidth (BW) bound and limited by the peak memory bandwidth (shown in red). Any algorithm to the right will fully utilize our FLOPs (shown in green). Here, Algo 1 is comms-bound and uses only a fraction of the total hardware FLOPs/s. Algo 2 is compute-bound. We can generally improve the performance of an algorithm either by increasing its arithmetic intensity or by increasing the memory bandwidth available (moving from BW1 to BW2).
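// A minimal sketch of the roofline curve described above:
// attainable FLOPs/s = min(peak FLOPs/s, bandwidth * arithmetic intensity).
// PEAK and BW1 are the TPU v5e numbers from this section. BW2 is a
// hypothetical doubled bandwidth standing in for "BW2" in the figure.
function attainableFlopsPerSec(intensity, peakFlopsPerSec, bytesPerSec) {
  return Math.min(peakFlopsPerSec, bytesPerSec * intensity);
}
const PEAK = 1.97e14;
const BW1 = 8.2e11;
const BW2 = 2 * BW1;
for (const intensity of [30, 120, 240, 480]) {
  const frac1 = attainableFlopsPerSec(intensity, PEAK, BW1) / PEAK;
  const frac2 = attainableFlopsPerSec(intensity, PEAK, BW2) / PEAK;
  console.log(`intensity ${intensity}: ${frac1.toFixed(2)} of peak at BW1, ${frac2.toFixed(2)} at BW2`);
}
// Doubling the bandwidth halves the intensity needed to reach the flat,
// compute-bound part of the curve (from ~240 to ~120 FLOPs/byte).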
Matrix multiplication
Let's look at our soon-to-be favorite algorithm: matrix multiplication (aka matmul). We write $X * Y \rightarrow Z$ where $X$ has shape $\mathrm{bf16}[B, D]$, $Y$ has shape $\mathrm{bf16}[D, F]$, and $Z$ has shape $\mathrm{bf16}[B, F]$. To do the matmul we need to load $2DF + 2BD$ bytes, perform $2BDF$ FLOPs, and write $2BF$ bytes back.${ }^{7,8}$ Thus:
$$
\operatorname{Intensity}(\text { matmul })=\frac{2 B D F}{2 B D+2 D F+2 B F}=\frac{B D F}{B D+D F+B F}
$$
We can get a nice simplification if we assume our local "batch size" $B$ is small relative to $D$ and $F$. Then we get
$$
\begin{gathered}
\frac{B D F}{B D+D F+B F} \approx \frac{B D F}{D F}=B \\
\text { Intensity (matmul) }>\text { Intensity }(\mathrm{TPU}) \Longrightarrow B>\frac{1.97 \mathrm{e} 14}{8.20 \mathrm{e} 11}=240
\end{gathered}
$$
This is a reasonable assumption for Transformer matmuls since for most of our models we have our local batch size in tokens $B<1024$ but $D$ and $F>8000$. Thus we become compute-bound when our local batch size is greater than 240 tokens, a very simple rule!
Takeaway: for a bfloat16 matmul to be compute-bound on most TPUs, we need our local batch size in tokens to be greater than 240.
This comes with a few notable caveats we'll explore in the problems below, particularly with respect to quantization (e.g., if we quantize our activations but still do full-precision FLOPs), but it's a good rule to remember. For GPUs, this number is slightly higher (closer to 300), but the same conclusion generally holds. We'll discuss the lower-level GPU and TPU details in the next section.
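// A sketch of the bf16 matmul intensity derived above for X[B, D] @ Y[D, F],
// with 2BDF FLOPs and 2BD + 2DF + 2BF bytes moved. D and F are assumed,
// illustrative sizes. The threshold is the TPU v5e value from this section.
function matmulIntensity(B, D, F) {
  return (2 * B * D * F) / (2 * B * D + 2 * D * F + 2 * B * F);
}
const V5E_CRITICAL = 1.97e14 / 8.2e11; // ~240 FLOPs/byte
const D = 8192;
const F = 32768;
for (const B of [64, 128, 256, 512]) {
  const intensity = matmulIntensity(B, D, F);
  const regime = intensity > V5E_CRITICAL ? "compute-bound" : "bandwidth-bound";
  console.log(`B=${B}: intensity ${intensity.toFixed(1)} (${regime})`);
}
// The intensity tracks B closely (slightly below it, because the BD and BF
// terms are not zero), so the crossover lands a little above B = 240 here.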
Network communication rooflines
All the rooflines we've discussed so far have been memory-bandwidth rooflines, all within a single chip. This shouldn't be taken as a rule. In fact, most of the rooflines we'll care about in this book involve communication between chips: usually matrix multiplications that involve matrices sharded across multiple TPUs.
To pick a somewhat contrived example, say we want to multiply two big matrices $X \sim \mathrm{bfloat16}[B, D]$ and $Y \sim \mathrm{bfloat16}[D, F]$ which are split evenly across 2 TPUs/GPUs (along the $D$ dimension). To do this multiplication (as we'll see in Section 3), we can multiply half of each matrix on each TPU (`A = X[:, :D // 2] @ Y[:D // 2, :]` on TPU 0 and `B = X[:, D // 2:] @ Y[D // 2:, :]` on TPU 1) and then copy the resulting "partial sums" to the other TPU and add them together. Say we can copy 4.5e10 bytes/s in each direction and perform 1.97e14 FLOPs/s on each chip. What are $T_{\text {math }}$ and $T_{\text {comms }}$?
$T_{\text {math }}$ is clearly half of what it was before, since each TPU is doing half the work, i.e. ${ }^9$
$$
T_{\mathrm{math}}=\frac{2 B D F}{2 \cdot \text { Accelerator FLOPs } / \mathrm{s}}=\frac{B D F}{1.97 e 14}
$$
Now what about $T_{\text {comms }}$ ? This now refers to the communication time between chips! This is just the total bytes sent divided by the network bandwidth, i.e.
$$
T_{\text {comms }}=\frac{2 B F}{\text { Network Bandwidth }}=\frac{2 B F}{4.5 e 10}
$$
Therefore we become compute-bound (now with respect to the inter-chip network) when Intensity(matmul (2-chips)) $>$ Intensity(TPU w.r.t. inter-chip network) or equivalently when $\frac{B D F}{2 B F}=\frac{D}{2}>\frac{1.97 \mathrm{e} 14}{4.5 \mathrm{e} 10} \approx 4377$ or $D>8755$. Note that, unlike before, the critical threshold now depends on $D$ and not $B$! Try to think why that is. This is just one such example, but we highlight that this kind of roofline is critical to knowing when we can parallelize an operation across multiple TPUs.
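// A sketch of the two-chip example above (illustrative names and sizes). Each
// chip does half of the 2BDF FLOPs, then sends a bf16[B, F] partial sum
// (2BF bytes) over a link assumed to carry 4.5e10 bytes/s in each direction.
function twoChipTimes(B, D, F, flopsPerSec = 1.97e14, linkBytesPerSec = 4.5e10) {
  const tMath = (2 * B * D * F) / (2 * flopsPerSec);
  const tComms = (2 * B * F) / linkBytesPerSec;
  return { tMath, tComms };
}
// tMath / tComms = D * linkBytesPerSec / (2 * flopsPerSec). B and F cancel,
// which is why the threshold depends only on D.
const dCritical = (2 * 1.97e14) / 4.5e10;
console.log(dCritical.toFixed(1)); // 8755.6
for (const D of [4096, 8192, 16384]) {
  const { tMath, tComms } = twoChipTimes(1024, D, 16384);
  console.log(D, tMath > tComms ? "compute-bound" : "communication-bound");
}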
A Few Problems to Work
Question 1 [int8 matmul]: Say we want to do $X[B, D] \cdot{ }_D Y[D, F] \rightarrow Z[B, F]$ in int8 precision (1 byte per parameter) instead of bfloat16. ${ }^{10}$
1. How many bytes need to be loaded from memory? How many need to be written back to memory?
2. How many total OPs are performed?
3. What is the arithmetic intensity?
4. What is a roofline estimate for $T_{\text {math }}$ and $T_{\text {comms }}$ ? What are reasonable upper and lower bounds for the runtime of the whole operation?
Assume our HBM bandwidth is 8.1e11 bytes/s and our int8 peak OPs/s is 3.94e14.
Answer:
1. Because we're storing our parameters in int8, we have 1 byte per parameter, so we have $B D+D F$ bytes loaded from HBM and $B F$ written back.
2. This is the same as in bfloat16, but in theory int8 OPs/s should be faster. So this is still $2 B D F$ FLOPs.
3. Arithmetic intensity is $2 B D F /(B D+D F+B F)$. If we make the same assumption as above about $B \ll D$ and $B \ll F$, we get an arithmetic intensity of $2 B$, meaning our rule becomes $B>$ HBM int8 arithmetic intensity $/ 2$. Using the numbers given, this int8 intensity is $3.94 \mathrm{e} 14 / 8.1 \mathrm{e} 11=486$, so the rule is $B>486 / 2=243$. Note that this is basically unchanged!
4. $T_{\text {math }}=2 B D F / 3.94 e 14$ and $T_{\text {comms }}=(B D+D F+B F) / 8.1 e 11$, so a reasonable lower bound is $\max \left(T_{\text {math }}, T_{\text {comms }}\right)$ and an upper bound is $T_{\text {math }}+T_{\text {comms }}$.
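// A sketch of the Question 1 answer (illustrative name and sizes), using the
// numbers given there: 3.94e14 int8 OPs/s and 8.1e11 bytes/s of HBM bandwidth,
// with 1-byte int8 inputs and output.
function int8MatmulRoofline(B, D, F, opsPerSec = 3.94e14, hbmBytesPerSec = 8.1e11) {
  const ops = 2 * B * D * F;
  const bytes = B * D + D * F + B * F; // BD + DF loaded, BF written back
  const tMath = ops / opsPerSec;
  const tComms = bytes / hbmBytesPerSec;
  return {
    intensity: ops / bytes,
    lower: Math.max(tMath, tComms), // perfect overlap
    upper: tMath + tComms,          // no overlap
    tMath,
    tComms,
  };
}
// With B << D, F the intensity is ~2B, so the batch threshold is half of the
// int8 hardware intensity.
const int8BatchThreshold = 3.94e14 / 8.1e11 / 2;
console.log(Math.round(int8BatchThreshold)); // 243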
Question 2 [int8 + bf16 matmul]: In practice we often do different weight vs. activation quantization, so we might store our weights in very low precision but keep activations (and compute) in a higher precision. Say we want to quantize our weights in int8 but keep activations (and compute) in bfloat16. At what batch size do we become compute bound? Assume 1.97e14 bfloat16 FLOPs/s.
Hint: this means specifically bfloat16 [B, D] * int8 [D, F] -> bfloat16 [B, F] where $B$ is the "batch size".
Answer:
Again assuming B is small, we have 2BDF bfloat16 FLOPs but only DF bytes of weights (instead of 2DF in bfloat16). This means we become compute-bound when $2B > 240$ or $B > 120$. This is a lot lower, meaning if we can do int8 weight quantization (which is fairly easy to do) but still do bfloat16 FLOPs, we get a meaningful win in efficiency (although int8 OPs would be better).
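// A sketch of the Question 2 setting (illustrative name and sizes):
// bf16[B, D] activations (2 bytes each) times int8[D, F] weights (1 byte each),
// producing a bf16[B, F] output.
function mixedIntensity(B, D, F) {
  const flops = 2 * B * D * F;
  const bytes = 2 * B * D + D * F + 2 * B * F;
  return flops / bytes;
}
const BF16_CRITICAL = 1.97e14 / 8.2e11; // ~240 FLOPs/byte
for (const B of [64, 100, 160, 256]) {
  const intensity = mixedIntensity(B, 16384, 16384);
  console.log(B, intensity.toFixed(1), intensity > BF16_CRITICAL ? "compute-bound" : "bandwidth-bound");
}
// The weight bytes DF dominate for small B, so the intensity is ~2B and the
// crossover moves from ~240 down to roughly B = 120.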
Question 3: For the problem above, make a roofline plot of peak FLOPs vs. B for several values of $D$ and $F$.
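// One possible sketch for Question 3 that prints a table instead of drawing a
// plot (no plotting library assumed). It uses the Question 2 setting:
// bf16[B, D] x int8[D, F] with 1.97e14 FLOPs/s and 8.2e11 bytes/s of HBM.
function mixedAttainableFlops(B, D, F, peak = 1.97e14, hbmBytesPerSec = 8.2e11) {
  const intensity = (2 * B * D * F) / (2 * B * D + D * F + 2 * B * F);
  return Math.min(peak, hbmBytesPerSec * intensity);
}
const batchSizes = [16, 64, 128, 256, 1024];
for (const [D, F] of [[1024, 1024], [4096, 16384], [16384, 65536]]) {
  const fractions = batchSizes.map((B) => (mixedAttainableFlops(B, D, F) / 1.97e14).toFixed(2));
  console.log(`D=${D}, F=${F}:`, fractions.join(" "));
}
// Each row rises roughly linearly with B and then flattens at 1.00 (peak).
// Smaller D and F reach the plateau later, because the BD and BF terms make
// up a larger share of the bytes moved.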
Question 4: What if we wanted to perform $\mathrm{int8}[B, D] *_D \mathrm{int8}[B, D, F] \rightarrow \mathrm{int8}[B, F]$, where we imagine having a different matrix for each batch element? What is the arithmetic intensity of this operation?
Answer:
Let's start by looking at the total FLOPs and comms.
1. Total FLOPs: the FLOP count is basically the same, since we're doing the same number of $B D \times D F$ matmuls (this is discussed more in Section 4). So this is just $2 B D F$.
2. Total comms: we have a lot more comms here: $B D+B D F+B F$.
3. Therefore, our arithmetic intensity is now actually $2 B D F /(B D+B D F+B F)$. Since $B D F$ dominates the denominator, this is roughly 2. So instead of it depending on the batch size, this is essentially constant. This is bad because it means we'll basically always be comms-bound no matter what.
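// A sketch of Question 4 (illustrative name and sizes): each batch element has
// its own int8[D, F] matrix, so the weight bytes grow as B * D * F.
function batchedWeightIntensity(B, D, F) {
  const ops = 2 * B * D * F;
  const bytes = B * D + B * D * F + B * F; // int8 everywhere: 1 byte each
  return ops / bytes;
}
for (const B of [1, 64, 4096]) {
  console.log(B, batchedWeightIntensity(B, 4096, 4096).toFixed(3));
}
// B cancels out, leaving 2DF / (D + DF + F), so the intensity stays just
// under 2 no matter how large the batch is.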
Question 5 [Memory Rooflines for GPUs]: Using the spec sheet provided by NVIDIA for the H100, calculate the batch size at which a matrix multiplication will become compute-bound. Note that the Tensor Core FLOPs numbers are twice the true value since they're only achievable with structured sparsity.
Answer:
From the spec sheet, we see that the reported bfloat16 FLOPs value is 1.979e15 FLOPs/s with an asterisk noting "with sparsity". The true value is half this without sparsity, meaning close to 1e15 FLOPs/s. The memory bandwidth is 3.35 TB/s, or 3.35e12 bytes/s. Thus $B_{\text {crit }}$ is $1\mathrm{e}15 / 3.35\mathrm{e}12 \approx 298$, rather similar to the TPU.
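// A sketch of the Question 5 arithmetic. The spec-sheet bf16 figure assumes
// structured sparsity, so it is halved before dividing by HBM bandwidth.
const h100SparseBf16FlopsPerSec = 1.979e15;
const h100DenseBf16FlopsPerSec = h100SparseBf16FlopsPerSec / 2;
const h100HbmBytesPerSec = 3.35e12;
const bCritH100 = h100DenseBf16FlopsPerSec / h100HbmBytesPerSec;
console.log(Math.round(bCritH100)); // 295
// Without rounding the FLOPs up to 1e15 this gives ~295 rather than ~298.
// Either way, it is close to the ~240 found for TPU v5e.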
```
- Part 2: TPU
```jsx
How to Think About TPUs
This section is all about how TPUs work, how they're networked together to enable multi-chip training and inference, and how this affects the performance of our favorite algorithms. There's even some good stuff for GPU users too!
What Is a TPU?
A TPU is basically a compute core that specializes in matrix multiplication (called a TensorCore) attached to a stack of fast memory (called high-bandwidth memory or HBM) [1]. Here's a diagram:
Figure: the basic components of a TPU chip. The TensorCore is the gray left-hand box, containing the matrix multiply unit (MXU), vector unit (VPU), and vector memory (VMEM).
You can think of the TensorCore as basically just being a really good matrix multiplication machine, but it has a few other functions worth noting. The TensorCore has three key units:
- The MXU (Matrix Multiply Unit) is the core of the TensorCore. For most TPU generations, it performs one bfloat16 [8,128] @ bf16 [128,128] -> f32 [8,128] matrix multiply ${ }^1$ every 8 cycles using a systolic array (see Appendix B for details).
  - This is about 5e13 bf16 FLOPs/s per MXU at 1.5 GHz on TPU v5e. Most TensorCores have 2 or 4 MXUs, so e.g. the total bf16 FLOPs/s for TPU v5e is 2e14.
  - TPUs also support lower precision matmuls with higher throughput (e.g. each TPU v5e chip can do 4e14 int8 OPs/s).
- The VPU (Vector Processing Unit) performs general mathematical operations like ReLU activations or pointwise addition or multiplication between vectors. Reductions (sums) are also performed here. Appendix C provides more details.
- VMEM (Vector Memory) is an on-chip scratchpad located in the TensorCore, close to the compute units. It is much smaller than HBM (for example, 128 MiB on TPU v5e) but has a much higher bandwidth to the MXU. VMEM operates somewhat like an L1/L2 cache on CPUs but is much larger and programmer-controlled. Data in HBM needs to be copied into VMEM before the TensorCore can do any computation with it.
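// A back-of-the-envelope sketch of the MXU numbers above: one
// bf16[8,128] @ bf16[128,128] matmul every 8 cycles at 1.5 GHz. The 4 MXUs
// per TPU v5e chip are an assumption consistent with the 2e14 total.
const flopsPerMxuMatmul = 2 * 8 * 128 * 128; // each multiply-add counts as 2 FLOPs
const flopsPerCyclePerMxu = flopsPerMxuMatmul / 8;
const mxuFlopsPerSec = flopsPerCyclePerMxu * 1.5e9;
const v5eChipFlopsPerSec = 4 * mxuFlopsPerSec;
console.log(mxuFlopsPerSec.toExponential(2), v5eChipFlopsPerSec.toExponential(2));
// Prints "4.92e+13 1.97e+14", matching ~5e13 per MXU and ~2e14 per chip.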
TPUs are very, very fast at matrix multiplication. It's mainly what they do and they do it well. TPU v5p, one of the most powerful TPUs to date, can do 2.5e14 bf16 FLOPs/s per core or 5e14 bf16 FLOPs/s per chip. A single pod of 8960 chips can do over 4 exaFLOPs/s. That's a lot. That's one of the most powerful supercomputers in the world. And Google has a lot of them.${ }^2$
The diagram above also includes a few other components like SMEM and the scalar unit, which are used for control flow handling and are discussed briefly in Appendix C, but aren't crucial to understand. On the other hand, HBM is important and fairly simple:
- HBM (High Bandwidth Memory) is a big chunk of fast memory that stores tensors for use by the TensorCore. HBM usually has capacity on the order of tens of gigabytes (for example, TPU v5e has 16 GiB of HBM).
- When needed for a computation, tensors are streamed out of HBM through VMEM (see below) into the MXU and the result is written from VMEM back to HBM.
- The bandwidth between HBM and the TensorCore (through VMEM) is known as "HBM bandwidth" (usually around 1-2 TB/s) and limits how fast computation can be done in memory-bound workloads.
Generally, all TPU operations are pipelined and overlapped. To perform a matmul $X \cdot A \rightarrow Y$, a TPU would first need to copy chunks of matrices $A$ and $X$ from HBM into VMEM, then load them into the MXU which multiplies chunks of $8 \times 128$ (for $X$ ) and $128 \times 128$ (for $A$ ), then copy the result chunk by chunk back to HBM. To do this efficiently, the matmul is pipelined so the copies to/from VMEM are overlapped with the MXU work. This allows the MXU to continue working instead of waiting on memory transfers, keeping matmuls compute-bound, not memory-bound.
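// A toy timing model of the pipelining described above. This is an assumption
// for illustration, not how the TPU or XLA actually schedules work. A matmul is
// split into n chunks. Loading one chunk from HBM takes `load` time units, and
// the MXU work on it takes `compute` units. Write-back is ignored.
function serialTime(n, load, compute) {
  return n * (load + compute); // load a chunk, compute on it, repeat
}
function pipelinedTime(n, load, compute) {
  // Double buffering: compute on chunk i while chunk i + 1 is loading.
  return load + (n - 1) * Math.max(load, compute) + compute;
}
console.log(serialTime(64, 1, 3), pipelinedTime(64, 1, 3)); // 256 193
// When compute >= load, the pipelined time approaches n * compute, so the MXU
// never waits on memory. When load > compute, it approaches n * load and the
// loop is bandwidth-bound instead.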
Here's an example of how you might perform an elementwise product from HBM:
Figure: an animation showing a pointwise product performed on TPU, with bytes loaded from HBM. Note how bytes are streamed out of memory in chunks and partial results are pipelined back without waiting for the full array to be materialized.
A matmul would look nearly identical except it would load into the MXU instead of the VPU/Vector unit, and the loads and stores would occur in a different order, since the same weight chunk is used for multiple chunks of activations. You can see chunks of data streaming into VMEM, then into the VREGs (vector registers), then into the Vector Unit, then back into VMEM and HBM. As we're about to see, if the load from HBM to VMEM is slower than the FLOPs in the Vector Unit (or MXU), we become "bandwidth bound" since we're starving the VPU or MXU of work.
Key takeaway: TPUs are very simple. They load weights from HBM into VMEM, then from VMEM into a systolic array which can perform around 200 trillion multiply-adds per second. The HBM $\leftrightarrow$ VMEM and VMEM $\leftrightarrow$ systolic array bandwidths set fundamental limits on what computations TPUs can do efficiently.
VMEM and arithmetic intensity: VMEM is much smaller than HBM, but it has a much higher bandwidth to the MXU. As we saw in Section 1, this means that if an algorithm can fit all its inputs/outputs in VMEM, it's much less likely to hit communication bottlenecks. This is particularly helpful when a computation has poor arithmetic intensity. VMEM bandwidth is around 22x higher than HBM bandwidth, which means an MXU operation reading from/writing to VMEM requires an arithmetic intensity of only 10-20 to achieve peak FLOPs utilization. So if we can fit our weights into VMEM instead of HBM, our matrix multiplications can be FLOPs bound at much smaller batch sizes. It also means algorithms that fundamentally have a lower arithmetic intensity can still be efficient. VMEM is just so small that this is often a challenge. ${ }^3$
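The following roofline sketch makes the "10-20" figure concrete. Dividing peak FLOPs/s by bandwidth gives the arithmetic intensity (FLOPs per byte) needed to stay compute-bound. It uses TPU v5e numbers from the table later in this section. It also treats VMEM bandwidth as exactly 22x HBM bandwidth, which is only the approximation quoted above.

```python
# TPU v5e figures from the table later in this section; the VMEM number simply
# applies the rough "22x HBM" ratio quoted above, so treat it as an estimate.
PEAK_BF16_FLOPS = 1.97e14  # FLOPs/s per chip
HBM_BW = 8.1e11            # bytes/s per chip
VMEM_BW = 22 * HBM_BW      # bytes/s, approximate


def critical_intensity(flops_per_s, bytes_per_s):
    """Arithmetic intensity (FLOPs per byte) above which we are compute-bound."""
    return flops_per_s / bytes_per_s


hbm_ai = critical_intensity(PEAK_BF16_FLOPS, HBM_BW)
vmem_ai = critical_intensity(PEAK_BF16_FLOPS, VMEM_BW)
print(f"operands in HBM:  ~{hbm_ai:.0f} FLOPs/byte needed")
print(f"operands in VMEM: ~{vmem_ai:.0f} FLOPs/byte needed")
```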
A TPU chip typically (but not always) consists of two TPU cores which share memory and can be thought of as one large accelerator with twice the FLOPs (known as a "megacore" configuration). This has been true since TPU v4. In older TPU chips (TPU v3 and older), the two cores have separate memory and are regarded as two separate accelerators. Inference-optimized chips like the TPU v5e only have one TPU core per chip.
Chips are arranged in sets of 4 on a 'tray' connected to a CPU host via PCIe. This is the format most readers will be familiar with: 4 chips (8 cores, though usually treated as 4 logical megacores) exposed through Colab or a single TPU-VM. For inference chips like the TPU v5e, we have 2 trays per host instead of 1, but also only 1 core per chip, giving us 8 chips = 8 cores. ${ }^4$
PCIe bandwidth is limited: Like the HBM $\leftrightarrow$ VMEM link, the CPU $\leftrightarrow$ HBM PCIe connection has a specific bandwidth that limits how quickly you can load from host memory to HBM or vice versa. PCIe bandwidth for TPU v4 is 16 GB/second each way, for example, so close to 100x slower than HBM. We can load/offload data into host (CPU) RAM, but not very quickly.
TPU Networking
Chips are connected to each other through the ICI network in a Pod. In older generations (TPU v2 and TPU v3), inference chips (e.g., TPU v5e), and Trillium (TPU v6e), ICI ("inter-chip interconnects") connects the 4 nearest neighbors (with edge links to form a 2D torus). TPU v4 and TPU v5p are connected to the nearest 6 neighbors (forming a 3D torus). Note that these connections do not go through their hosts; they are direct links between chips.
The toroidal structure reduces the maximum distance between any two nodes from $N$ to $N / 2$, making communication much faster. TPUs also have a "twisted torus" configuration that wraps the torus in a Möbius-strip-like topology to further reduce the average distance between nodes.
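Here is a minimal sketch of hop counting on a torus. It assumes nearest-neighbor links only and minimal routing along each axis independently, so the hop count is a per-axis Manhattan distance. The helper names are illustrative, and twisted-torus topologies are not modeled. Strictly speaking, the worst case along one axis drops from $N-1$ hops to $N/2$.

```python
def axis_hops(a, b, n, wraparound):
    """Hops between positions a and b along one axis of n chips."""
    d = abs(a - b)
    return min(d, n - d) if wraparound else d


def torus_hops(src, dst, shape, wraps):
    """Minimal hop count between two chips; wraps[i] says whether axis i has a wraparound link."""
    return sum(axis_hops(s, t, n, w) for s, t, n, w in zip(src, dst, shape, wraps))


# Worst case along a single axis of 16 chips, without and with the wraparound link.
print(max(axis_hops(0, j, 16, False) for j in range(16)))  # 15
print(max(axis_hops(0, j, 16, True) for j in range(16)))   # 8
```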
TPU pods (connected by ICI) can get really big: the maximum pod size (called a superpod) is $16 \times 16 \times 16$ for TPU v4 and $16 \times 20 \times 28$ for TPU v5p. These large pods are composed of reconfigurable cubes of $4 \times 4 \times 4$ chips connected by optical wraparound links${ }^5$ that we can reconfigure to connect very large topologies. Smaller topologies (e.g. 2x2x1, 2x2x2) can also be requested, albeit with no wraparounds. This is an important caveat, since it typically doubles the time of most communication. Any multiple of a full cube (e.g. 4x4x4 or 4x4x8) will have wraparounds provided by the optical switches.${ }^6$
TPU v5e and Trillium pods consist of a single $16 \times 16$ 2D torus with wraparounds along any axis of size 16 (meaning an $8 \times 16$ has a wraparound on the long axis). TPUs v5e and v6e (Trillium) cannot expand beyond a $16 \times 16$ torus but pods can still communicate with each other over standard data-center networking (DCN), which connects TPU hosts to each other. Again, smaller topologies can be requested without wraps on dims $<16$.
This nearest-neighbor connectivity is a key difference between TPUs and GPUs. GPUs are connected with a hierarchy of switches that approximate a point-to-point connection between every GPU, rather than using local connections like a TPU. Typically, GPUs within a node (8 GPUs for H100 or as many as 500 for B200) are directly connected, while larger topologies require $O(\log N)$ hops between each GPU. On the one hand, that means GPUs can send arbitrary data within a node in a single low-latency hop. On the other hand, TPUs are dramatically cheaper (since NVLink switches are expensive) and simpler to wire together. They can also scale to much larger topologies because the number of links per device and the bandwidth per device are constant.
ICI is very fast relative to DCN, but is still slower than HBM bandwidth. For instance, a TPU v5p has:
- 2.5e12 bytes/s (2.5 TB/s) of HBM bandwidth per chip.
- 9e10 bytes/s (90 GB/s${ }^7$) of ICI bandwidth per axis, with 3 axes per chip.
- 2.5e10 bytes/s (25 GB/s) of DCN (egress) bandwidth per host. Since we typically have 8 TPUs per host, this is really closer to 3.1e9 bytes/s/chip.
This means that when we split models across multiple chips, we need to be careful to avoid bottlenecking the MXU with slower cross-device communication.
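To put the gap in perspective, this quick sketch times moving 1 GB per chip over each channel on a TPU v5p. It uses the figures listed above, counting one ICI axis only and the per-chip share of DCN. These are idealized, bandwidth-only estimates that ignore latency and contention.

```python
# TPU v5p per-chip figures quoted above; ICI counts a single axis only.
BANDWIDTH_BYTES_PER_S = {
    "HBM": 2.5e12,
    "ICI (1 axis)": 9e10,
    "DCN (per chip)": 3.1e9,
}


def transfer_ms(nbytes, bytes_per_s):
    """Idealized, bandwidth-only transfer time in milliseconds."""
    return 1e3 * nbytes / bytes_per_s


for name, bw in BANDWIDTH_BYTES_PER_S.items():
    print(f"{name:>15}: {transfer_ms(1e9, bw):7.1f} ms per GB")
```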
Multi-slice training: A set of ICI-connected TPUs is called a slice. Different slices can be connected to each other using DCN, for instance to link slices on different pods. Since DCN is a much slower connection than ICI, we should try to limit how much our computation has to wait for data from DCN. DCN is host-to-host, so to transfer buffers from TPU to TPU over DCN, we first need to transfer over PCIe to the host, then egress over the network, then ingress over the target host network, then go over PCIe into HBM.
Key Takeaways
- TPUs are simple and can in most cases be thought of as a matrix multiply unit connected to memory (super fast), other chips over ICI (rather fast), and the rest of the datacenter over DCN (somewhat fast).
- Communication is limited by our various network bandwidths, in order of speed:
  - HBM bandwidth: Between a TensorCore and its associated HBM.
  - ICI bandwidth: Between a TPU chip and its nearest 4 or 6 neighbors.
  - PCIe bandwidth: Between a CPU host and its associated tray(s) of chips.
  - DCN bandwidth: Between multiple CPU hosts, typically hosts not connected by ICI.
- Within a slice, TPUs are only connected to their nearest neighbors via ICI. This means communication over ICI between distant chips in a slice needs to hop over the intervening chips first.
- Weight matrices need to be padded to at least size 128 (256 on TPU v6) in both dimensions to fill up the MXU (in fact, smaller axes are padded to 128).
- Lower precision matrix multiplication tends to be faster. TPUs can do int8 or int4 FLOPs roughly 2x/4x faster than bfloat16 FLOPs for generations that support it. VPU operations are still performed in fp32.
- To avoid bottlenecking the TPU compute unit, we need to make sure the amount of communication across each channel is proportional to its speed.
- Here are some specific numbers for our chips:
\begin{tabular}{|l|l|l|l|l|l|l|}
\hline Model & Pod size & Host size & HBM capacity/chip & HBM BW/chip (bytes/s) & FLOPs/s/chip (bf16) & FLOPs/s/chip (int8) \\
\hline TPU v3 & $32 \times 32$ & $4 \times 2$ & 32GB & 9.0e11 & 1.4e14 & 1.4e14 \\
\hline TPU v4p & $16 \times 16 \times 16$ & $2 \times 2 \times 1$ & 32GB & 1.2e12 & 2.75e14 & 2.75e14 \\
\hline TPU v5p & $16 \times 20 \times 28$ & $2 \times 2 \times 1$ & 96GB & 2.8e12 & 4.59e14 & 9.18e14 \\
\hline TPU v5e & $16 \times 16$ & $4 \times 2$ & 16GB & 8.1e11 & 1.97e14 & 3.94e14 \\
\hline TPU v6e & $16 \times 16$ & $4 \times 2$ & 32GB & 1.6e12 & 9.20e14 & 1.84e15 \\
\hline
\end{tabular}
Host size refers to the topology of TPUs connected to a single host (e.g. TPU v5e has a single CPU host connected to 8 TPUs in a $4 \times 2$ topology). Here are interconnect figures:
\begin{tabular}{|l|l|l|}
\hline Model & ICI BW/link (one-way, bytes/s) & ICI BW/link (bidi, bytes/s) \\
\hline TPU v3 & 1e11 & 2e11 \\
\hline TPU v4p & 4.5e10 & 9e10 \\
\hline TPU v5p & 9e10 & 1.8e11 \\
\hline TPU v5e & 4.5e10 & 9e10 \\
\hline TPU v6e & 9e10 & 1.8e11 \\
\hline
\end{tabular}
We include both one-way (unidirectional) bandwidth and bidi (bidirectional) bandwidth since unidirectional bandwidth is more true to the hardware but bidirectional bandwidth occurs more often in equations involving a full ring. ${ }^8$
PCIe bandwidth is typically around 1.5e10 bytes/second per chip${ }^9$, while DCN bandwidth is typically around 2.5e10 bytes/second per host.
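This small helper turns the first table into roofline estimates. It is a sketch: a computation takes at least the larger of its compute time and its HBM transfer time, assuming the two overlap perfectly. The dictionary layout and function name are our own; the numbers are copied from the table. The loop prints, per chip, the bf16 FLOPs per byte of HBM traffic needed to be compute-bound.

```python
# Per-chip numbers copied from the first table above.
SPECS = {
    #        HBM bytes/s, bf16 FLOPs/s, int8 FLOPs/s
    "v3":  (9.0e11, 1.4e14, 1.4e14),
    "v4p": (1.2e12, 2.75e14, 2.75e14),
    "v5p": (2.8e12, 4.59e14, 9.18e14),
    "v5e": (8.1e11, 1.97e14, 3.94e14),
    "v6e": (1.6e12, 9.20e14, 1.84e15),
}


def roofline_time(flops, nbytes, chip, dtype="bf16"):
    """Lower bound on runtime (seconds), assuming compute and HBM traffic fully overlap."""
    hbm_bw, bf16_peak, int8_peak = SPECS[chip]
    peak = bf16_peak if dtype == "bf16" else int8_peak
    return max(flops / peak, nbytes / hbm_bw)


for chip, (hbm_bw, bf16_peak, _) in SPECS.items():
    # FLOPs per byte of HBM traffic needed to be compute-bound in bf16.
    print(f"{chip:>4}: {bf16_peak / hbm_bw:.0f} FLOPs/byte")
```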
Worked Problems
These numbers are a little dry, but they let you make basic roofline estimates for model performance. Let's work a few problems to explain why this is useful. You'll see more examples in Part 3.
Question 1 [bounding LLM latency]: Say you want to sample from a 200B parameter model in bf16 that's split across 32 TPU v4p. How long would it take to load all the parameters from HBM into the systolic array? Hint: use the numbers above.
Answer: We're loading $\text{sizeof}(\text{bf16}) \cdot 200\mathrm{e}9 = 400\mathrm{e}9$ bytes on 32 chips, meaning 12.5e9 bytes/chip, each with an HBM bandwidth of 1.23e12 bytes/s. So the load takes around 10 ms.
That's pretty cool, because that's a reasonable lower bound on the latency of sampling from the model. Each sampling step needs to load all parameters from HBM, so it cannot take less than 10 ms. In practice, at small batch sizes, this is close to being achievable.
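Here is the same estimate as a few lines of Python. The function name is ours; 1.23e12 bytes/s is the v4p HBM bandwidth used in the answer.

```python
def min_sampling_step_s(n_params, bytes_per_param, n_chips, hbm_bw_per_chip):
    """Lower bound on one sampling step: each chip streams its weight shard from HBM once."""
    bytes_per_chip = n_params * bytes_per_param / n_chips
    return bytes_per_chip / hbm_bw_per_chip


# 200B bf16 parameters split across 32 TPU v4p chips.
t = min_sampling_step_s(200e9, 2, 32, 1.23e12)
print(f"{t * 1e3:.1f} ms")
```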
Question 2 [TPU details]: Consider a full TPU v5e pod. How many total CPU hosts are there? How many TPU TensorCores? What is the total FLOPs/s for the whole pod? What is the total HBM? Do the same exercise for TPU v5p pod.
Answer: For TPU v5e, each pod is $16 \times 16$ and each host is a $4 \times 2$ slice, so we have $16 \cdot 16 / 8 = 32$ hosts. For TPU v5e, each TPU has only one core, so we have 256 TensorCores. The total FLOPs/s is $16 \cdot 16 \cdot 2\mathrm{e}14 \approx 5.1\mathrm{e}16$ in bfloat16. Each chip has 16 GB of HBM, so that's $256 \cdot 16\ \mathrm{GB} \approx 4$ TB of memory.
For a full TPU v5p pod, we have $16 \times 20 \times 28$ chips and each host is $2 \times 2 \times 1$, so we have $16 \cdot 20 \cdot 28 / (2 \cdot 2) = 2{,}240$ hosts. For TPU v5p, each TPU has two TensorCores, so we have $8960 \cdot 2 = 17{,}920$ cores. The total FLOPs/s is $8960 \cdot 4.5\mathrm{e}14 \approx 4\mathrm{e}18$ in bfloat16. Each chip has 96 GB of HBM, so that's $8960 \cdot 96\ \mathrm{GB} \approx 860$ TB of memory.
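The pod arithmetic can be scripted as a sketch. It uses the per-chip bf16 FLOPs/s from the table (1.97e14 for v5e, 4.59e14 for v5p), so the totals differ slightly from the rounded 2e14 and 4.5e14 used in the answer. The function name is ours.

```python
from math import prod


def pod_summary(pod_shape, host_shape, cores_per_chip, bf16_flops_per_chip, hbm_gb_per_chip):
    """Totals for a full pod built from identical hosts."""
    chips = prod(pod_shape)
    return {
        "chips": chips,
        "hosts": chips // prod(host_shape),
        "cores": chips * cores_per_chip,
        "bf16 FLOPs/s": chips * bf16_flops_per_chip,
        "HBM (GB)": chips * hbm_gb_per_chip,
    }


print(pod_summary((16, 16), (4, 2), 1, 1.97e14, 16))         # TPU v5e
print(pod_summary((16, 20, 28), (2, 2, 1), 2, 4.59e14, 96))  # TPU v5p
```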
Question 3 [PCIe operational intensity]: Imagine we're forced to store a big weight matrix $A$ of type bfloat16$[D, F]$ and a batch of activations $x$ of type bfloat16$[B, D]$ in host DRAM and want to do a matrix multiplication on them. This is running on a single host, and we're using a single TPU v6e chip attached to it. You can assume $B \ll D$ and $F=4D$ (we'll see in future chapters why these are reasonable assumptions). What is the smallest batch size $B$ we need to remain FLOPs bound over PCIe? Assume a PCIe bandwidth of 1.5e10 bytes/second.
Answer: We have to perform $2 B D F$ floating point operations, and each chip can perform 9.2e14 floating point operations per second. This then requires $2 B D F / 9.2 e 14$ seconds to perform. We have to load $2 D F+2 B D$ bytes from DRAM, and write $2 B F$ bytes back to it. We are bottlenecked by PCIe transfer speeds, so we need $2 \cdot(B D+D F+B F) / 1.5 e 10$ seconds to transfer data to and from the TPU. We want computation to take longer than data transfer, assuming we can overlap all transfers with computation. That means we want $2 B D F / 9.2 e 14>2 \cdot(B D+D F+B F) / 1.5 e 10$. We can simplify this using our assumptions that $B \ll D$ and $F=4 D$, under which the $BD$ and $BF$ terms are negligible next to $DF = 4D^2$. This gives
$$
\frac{8 B D^2}{9.2 e 14}>\frac{8 D^2}{1.5 e 10}
$$
or
$$
B>\frac{9.2 e 14}{1.5 e 10} \simeq 61,000
$$
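To see where the ~61,000 comes from, this sketch solves the inequality exactly for a given $D$, keeping the $BD$ and $BF$ terms. Dividing both sides by 2 and rearranging gives $B\,(DF/C - (D+F)/W) > DF/W$, with $C = 9.2\mathrm{e}14$ FLOPs/s and $W = 1.5\mathrm{e}10$ bytes/s. When the bracket is not positive, no batch size makes the problem FLOPs-bound. With $F = 4D$ that happens whenever $0.8D < C/W$, so the simplified bound is only self-consistent ($B \ll D$) when $D$ is well above ~61,000. The function name is ours.

```python
def min_batch_over_pcie(d, f_ratio=4, peak_flops=9.2e14, pcie_bw=1.5e10):
    """Smallest B with 2BDF / peak_flops > 2(BD + DF + BF) / pcie_bw, or None if none exists.

    Exact version of the inequality above (bf16, 2 bytes per element),
    without dropping the BD and BF terms.
    """
    f = f_ratio * d
    slope = d * f / peak_flops - (d + f) / pcie_bw
    if slope <= 0:
        return None  # activation traffic alone outpaces compute: never FLOPs-bound
    return (d * f / pcie_bw) / slope


print(min_batch_over_pcie(1e8))   # very large D: approaches 9.2e14 / 1.5e10
print(min_batch_over_pcie(8192))  # None
```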
Question 4 [general matmul latency]: Let's say we want to multiply a weight matrix int8[16384, 4096] by an activation matrix of size int8[B, 4096], where $B$ is some unknown batch size. Let's say we're on a single TPU v5e to start.
1. How long will this multiplication take as a function of B? Hint: it may help to calculate how long it will take to load the arrays from HBM and how long the multiplication will actually take. Which is bottlenecking you?
2. What if we wanted to run this operation out of VMEM? How long would it take as a function of $B$ ?
Answer: (1) The number of floating point operations we need to perform is $2 \cdot 4096 \cdot 16384 \cdot B=1.3 e 8 \cdot B$. So $T_{\text {math }}=(1.3 e 8 \cdot B) / 3.94 e 14$ seconds. We need to load $16384 \cdot 4096+4096 \cdot B$ bytes from HBM to VMEM, and write back $16384 \cdot B$ bytes from VMEM to HBM. This means $T_{\text {comms }}=(6.7 e 7+2 e 4 \cdot B) / 8.1 e 11$ seconds. Assuming as much overlap of communication and computation as possible, the whole multiplication will take approximately
$$
\max \left\{T_{\text {math }}, T_{\text {comms }}\right\}=\max \left\{\frac{6.7 e 7+2 e 4 \cdot B}{8.1 e 11}, \frac{1.3 e 8 \cdot B}{3.94 e 14}\right\}
$$
We'll be FLOPs-bound when $\frac{6.7 e 7+2 e 4 \cdot B}{8.1 e 11}<\frac{1.3 e 8 \cdot B}{3.94 e 14}$, or equivalently, $B>271$. This is slightly larger than the 240 number we derive below because we factor in the full impact of $D$ and $F$.
(2) If instead we are loading from VMEM, let's consider VMEM bandwidth to the MXU as 22 times the HBM $\leftrightarrow$ VMEM bandwidth. This turns our data loading denominator from 8.1e11 to 1.78e13, and we get $B>11$. Note that in practice we cannot dedicate all of our VMEM bandwidth to loading $W$, so the threshold will be closer to 20.
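Both parts reduce to finding where $T_{\text{math}} = T_{\text{comms}}$. The sketch below solves that crossover directly. It assumes int8 outputs as in the answer and a VMEM bandwidth of exactly 22x HBM; the function names are ours. With unrounded constants ($2 \cdot 4096 \cdot 16384 \approx 1.34\mathrm{e}8$ rather than $1.3\mathrm{e}8$), the HBM crossover lands near 263 instead of 271. The difference is only rounding.

```python
def int8_matmul_times(b, d=4096, f=16384, peak=3.94e14, bw=8.1e11):
    """(T_math, T_comms) in seconds on one TPU v5e for int8[B, D] activations against an
    int8[F, D] weight, producing int8[B, F]."""
    t_math = 2 * b * d * f / peak
    t_comms = (d * f + b * d + b * f) / bw  # 1 byte per int8 element
    return t_math, t_comms


def flops_bound_batch(bw, d=4096, f=16384, peak=3.94e14):
    """Batch size at which T_math == T_comms; larger B is FLOPs-bound."""
    return (d * f / bw) / (2 * d * f / peak - (d + f) / bw)


print(f"from HBM:  B > {flops_bound_batch(8.1e11):.0f}")
print(f"from VMEM: B > {flops_bound_batch(22 * 8.1e11):.0f}")
```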
Question 5 [ICI bandwidth]: Let's say we have a TPU v5e $4 \times 4$ slice. Let's say we want to send an array of type bfloat16[8, 128, 8192] from $\operatorname{TPU}\{0,0\}$ to $\operatorname{TPU}\{3,3\}$. Let's say the per-hop latency for TPU v5e is $1 \mu s$.
1. How soon will the first byte arrive at its destination?
2. How long will the total transfer take?
Answer: In a TPU v5e we have 2D connectivity. Because we have only a $4 \times 4$ slice (with no axes of size 16), we have no wraparound connections. Thus there are two ports from which our target chip can receive data, and likewise two ports from which our source chip can send data. The amount of data we have to transfer is $2 \cdot 8 \cdot 128 \cdot 8192 \approx 1.7\mathrm{e}7$ bytes. We can transfer from both ports simultaneously (i.e. send half the array right and half down), so we get $2 \cdot 4.5\mathrm{e}10 = 9\mathrm{e}10$ bytes transferred per second. That means it'll take about $1.7\mathrm{e}7 / 9\mathrm{e}10 \approx 188\ \mu\mathrm{s}$ to transfer the whole array (assuming we're bandwidth bound). In a $4 \times 4$ slice, we have six hops between chips $(0,0)$ and $(3,3)$, since there are no wraparound links for axes with fewer than 16 chips. Since the latency of each hop is about $1\ \mu\mathrm{s}$, the first byte will arrive in about $6\ \mu\mathrm{s}$ and the total transfer will take about $188\ \mu\mathrm{s}$.
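Here is a sketch of the same calculation; the names and parameters are ours. It assumes minimal routing with no wraparound links and that both outgoing links are fully used. With the exact byte count of 16,777,216, the bandwidth term is about 186 µs; the answer's 188 µs comes from rounding to 1.7e7 bytes.

```python
from math import prod


def ici_transfer(shape, bytes_per_elem, src, dst, per_hop_s, link_bw, n_links):
    """(bytes, first-byte latency s, bandwidth-limited time s) on a slice without wraparounds."""
    nbytes = prod(shape) * bytes_per_elem
    hops = sum(abs(s - t) for s, t in zip(src, dst))  # no wraparound links to shorten the path
    return nbytes, hops * per_hop_s, nbytes / (link_bw * n_links)


nbytes, first_byte_s, bw_s = ici_transfer(
    (8, 128, 8192), 2, (0, 0), (3, 3), per_hop_s=1e-6, link_bw=4.5e10, n_links=2
)
print(nbytes, f"first byte ~{first_byte_s * 1e6:.0f} us", f"transfer ~{bw_s * 1e6:.0f} us")
```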
Question 6 [pulling it all together, hard]: Imagine you have a big matrix A: int8 [128 * 1024, $128 * 1024$ ] sharded evenly across a TPU v5e $4 \times 4$ slice but offloaded to host DRAM on each chip. Let's say you want to copy the entire array to $\operatorname{TPU}\{0,0\}$ and multiply it by a vector bf16 [8, $128 * 1024$ ]. How long will this take? Hint: use the numbers above.
Answer: Let's start by outlining the operations we have to perform. Our array is about 16 GB. From the table above, a TPU v5e host has a $4 \times 2$ topology, so a $4 \times 4$ slice has 2 hosts. Thus, since our array is evenly sharded, each host effectively contains a chunk of $1/2$ of the array, or 8 GB. We need to copy these chunks all to TPU $\{0,0\}$, which gives us two options:
1. We can copy over DCN and then load the entire unsharded array over PCle into HBM.
2. We can load our sharded arrays onto their corresponding TPUs, then perform a gather over ICI, then perform the matmul on $\operatorname{TPU}\{0,0\}$.
It should be clear that option (2) is better. DCN is slow compared to ICI and we'd much prefer to load a big array over many PCIe links rather than just a few (the 8 on host 0). Here's a diagram of part of the system. As described above, note that TPUs are connected to their neighbors by ICI (even across hosts), all TPUs are connected to their host CPU (via PCle), and hosts are connected by DCN.
Now let's work through how long each piece will take:
1. PCIe load: we're loading chunks of $16 \mathrm{~GB} / 2=8 \mathrm{~GB}$ over 8 PCIe links, each of which has 1.5e10 bytes/second of bandwidth. Thus this will take about 67 ms.
2. ICI copy: each TPU now has $16 \mathrm{~GB} / 16=1 \mathrm{~GB}$ of our array. Our ICI bandwidth is 9e10 bytes/second per link bidirectional (4.5e10 one-way). You'll notice from the above diagram that only 2 of the 4 ICI links on TPU $\{0,0\}$ are in use in this topology. Since TPU $\{0,0\}$ needs to receive a total of 15 GB along 2 axes at 4.5e10 bytes/s/link, we can lower bound the time by $15\mathrm{e}9 / (4.5\mathrm{e}10 \cdot 2) \approx 167$ ms. In practice this probably isn't achievable because the load is very uneven, but it's probably within a factor of 2. As you'll see in Section 2, performing a full AllGather would also take roughly $16\mathrm{e}9 / (4.5\mathrm{e}10 \cdot 2)$, so this is close to optimal.
3. HBM $\rightarrow$ MXU load: to perform our final matmul, we need to load these 16e9 bytes plus the bf16[8, 128 * 1024] array (another 2 MB, so negligible) over HBM bandwidth into the MXU, which will take $16\mathrm{e}9 / 8.1\mathrm{e}11 \approx 20$ ms.
4. FLOPs: we're performing a total of $2 \cdot 8 \cdot 128 \cdot 1024 \cdot 128 \cdot 1024 \approx 2.7\mathrm{e}11$ FLOPs, and since we can perform 1.97e14 bf16 FLOPs/s, we get about 1.4 ms.
An upper bound for the total time is the sum of all of these times. However, since the TPU can typically overlap these operations, we can think of this as a pipelining problem that's bottlenecked by the slowest piece. Assuming that's true, the answer is about 150-200 ms.
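This short sketch computes the four stage estimates and the two ways of combining them. It reuses the answer's simplifying assumptions: a 16e9-byte array, 2 ICI links at 4.5e10 bytes/s into TPU {0,0}, and perfect load balance. Treat the outputs as rough bounds.

```python
# Stage times from the answer above (TPU v5e, 4x4 slice, 16e9-byte array).
ARRAY_BYTES = 16e9
stages_s = {
    "PCIe load (8 links/host)": (ARRAY_BYTES / 2) / (8 * 1.5e10),
    "ICI gather to TPU {0,0}": 15e9 / (2 * 4.5e10),
    "HBM -> MXU": ARRAY_BYTES / 8.1e11,
    "FLOPs": 2 * 8 * (128 * 1024) ** 2 / 1.97e14,
}

for name, t in stages_s.items():
    print(f"{name:>26}: {t * 1e3:6.1f} ms")
print(f"fully overlapped (slowest stage): {max(stages_s.values()) * 1e3:.0f} ms")
print(f"no overlap (sum of stages):       {sum(stages_s.values()) * 1e3:.0f} ms")
```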
Appendix A: Let's talk about GPUs
Compared to TPUs, GPUs have a simpler communication model and a more complicated programming model.
Overview of the compute model:
- GPUs are conceptually similar to TPUs: they also function as an accelerator attached to a CPU. Many components are roughly analogous:
\begin{tabular}{|l|l|}
\hline TPU & GPU \\
\hline TensorCore & SM ('Streaming Multiprocessor') \\
\hline HBM & DRAM \\
\hline MXU & Tensor Cores \\
\hline VPU & CUDA Cores \\
\hline VMEM & L1 Cache / shared memory \\
\hline ICI & NVLink/NVSwitch \\
\hline
\end{tabular}
- Compared to TPUs, GPUs have many more 'streaming multiprocessors' (an H100 has about 140), each of which can be seen as analogous to a TensorCore (which a TPU only has 1-2 of). Having more SMs makes computation more flexible (since each can do totally independent work) but also makes the hardware more complex to reason about.
- Each SM in an H100 has 128 CUDA cores, which perform SIMD scalar work (like a TPU VPU), and a small L1 cache used to speed up data access and for register spilling. A section of the memory used for the L1 cache can also be declared as shared memory, accessible from any thread in the thread block. It is used for user-defined caches, parallel reductions, synchronization, etc. (similar to VMEM on a TPU).
- GPUs also have an additional L2 cache that is shared by all SMs. Unlike VMEM, this is hardware-managed, and optimizing cache hits is often important for performance.
Networking:
- The primary difference is that NVIDIA GPUs are typically connected in 'cliques' of 8-256 GPUs via switches (NVLink $\rightarrow$ NVSwitch). These allow point-to-point communication between any two GPUs within that 'clique', but communication across more than 256 GPUs is significantly slower. As a result, training on more than 256 GPUs typically requires pipeline parallelism to scale, which is more complex (by contrast, PaLM was trained on two cliques of 3072 TPU chips each).
- For common neural net operations such as AllReduce, all-to-all connections do not hold an advantage (as the same communication patterns must occur regardless). They do, however, allow MoE models to be stored across more GPUs and their experts to be transmitted around more efficiently.
- Each GPU requires a switch that costs about as much as the GPU itself, which makes a direct chip-to-chip interconnect like ICI cheaper.
- NVIDIA deep learning performance
- NVSwitch
- Very different Tensor Parallelism / Pipeline Parallelism transition point!
Appendix B: How does a systolic array work?
At the core of the TPU MXU is a $128 \times 128$ systolic array ($256 \times 256$ on TPU v6e). When fully saturated, the systolic array can perform one bf16[8, 128] @ bf16[128, 128] -> f32[8, 128]${ }^{10}$ multiplication per 8 clock cycles.
- At its core, the systolic array is a 2D $128 \times 128(=16,384)$ grid of ALUs each capable of performing a multiply and add operation.
- Weights ( $\mathbf{W}$, the $128 \times 128$ input) are passed down from above (called the RHS) while inputs ( $\mathbf{X}$, the $8 \times 128$ input) are passed in from the left (called the LHS).
Here is a simplified animation of multiplying a set of weights (blue) with a set of activations (green). You'll notice that the weights (RHS) are partially loaded first, diagonally, and then the activations are fed in, also diagonally. In each frame below, we multiply all the overlapped green and blue units, sum the result with any residual passed in from above, and then pass the result in turn down one unit.
There is an initial pipeline bubble as the weights (RHS) and activations (LHS) are loaded. After that initial bubble, new inputs and weights can be loaded in without an additional bubble.
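A toy cycle-level simulation can make the dataflow and the initial bubble concrete. This sketch makes several simplifying assumptions. The weights are already resident, so PE $(i, j)$ holds $W[i][j]$. Activations enter from the left with one cycle of skew per row and move right one PE per cycle. Partial sums move down one PE per cycle. Weight loading, precision, and real timing are not modeled. For $B$ activation rows, $K$ contracting rows, and $N$ columns, the last result leaves the array after $B + K + N - 2$ cycles. The fill-and-drain overhead is therefore amortized only when $B$ is large.

```python
def systolic_matmul(x, w):
    """Toy weight-stationary systolic array computing x @ w.

    x: B x K activations (LHS), fed from the left with one cycle of skew per row.
    w: K x N weights (RHS), assumed already resident: PE (i, j) holds w[i][j].
    Returns (x @ w as a B x N list, number of cycles until the last result exits).
    """
    B, K, N = len(x), len(w), len(w[0])
    act = [[0.0] * N for _ in range(K)]   # activation held by PE (i, j) this cycle
    psum = [[0.0] * N for _ in range(K)]  # partial sum leaving PE (i, j) this cycle
    out = [[0.0] * N for _ in range(B)]
    cycles = B + K + N - 2
    for t in range(cycles):
        new_act = [[0.0] * N for _ in range(K)]
        new_psum = [[0.0] * N for _ in range(K)]
        for i in range(K):
            for j in range(N):
                if j == 0:
                    b = t - i  # row i sees x[b][i] at cycle b + i (the diagonal skew)
                    new_act[i][j] = x[b][i] if 0 <= b < B else 0.0  # 0.0 = bubble
                else:
                    new_act[i][j] = act[i][j - 1]  # activations move one PE right
                above = psum[i - 1][j] if i > 0 else 0.0  # partial sums move one PE down
                new_psum[i][j] = above + w[i][j] * new_act[i][j]
        act, psum = new_act, new_psum
        for j in range(N):
            b = t - (K - 1) - j  # bottom PE of column j finishes row b on this cycle
            if 0 <= b < B:
                out[b][j] = psum[K - 1][j]
    return out, cycles


x = [[1, 2, 3], [4, 5, 6]]             # B=2 activation rows, K=3
w = [[1, 0, 2], [0, 1, 1], [3, 1, 0]]  # K=3 rows, N=3 columns of weights
y, cycles = systolic_matmul(x, w)
print(y, cycles)
```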
Here’s a bad animation of a bf16[2, 3] x bf16[3, 3] matrix multiplication, which you could imagine as a matmul of a 2x3 weight matrix with an input activation of batch 1 and size 3. This is rotated compared to the previous slides and inputs flow out to the right instead of down, but you can roughly see the structure.
We can efficiently pipeline this to multiply large matrices without too large a pipeline bubble. With that said, it’s important that our matrices have shapes larger than the side dimension of the MXU, which is generally 128x128. Some TPUs (since TPU v3) have multiple MXUs, either 2 for TPU v3 and 4 for TPU v4/5, so we need to ensure tiling dimensions are larger than 128 * number of MXUs. Here’s a good animation for this.
Trillium (TPU v6e) has a 256x256 systolic array, which means it can perform 4x more FLOPs/cycle. This also means the dimensions of your tensors need to be twice as large to utilize the MXU fully.
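The next sketch shows the cost of padding. It assumes each weight dimension is padded up to a multiple of the MXU side (128, or 256 on Trillium). Under that assumption, the fraction of MXU work that is useful is the real tile area divided by the padded area. The helper names are ours, and the multi-MXU tiling mentioned above is ignored.

```python
import math


def padded(dim, mxu=128):
    """Round a dimension up to a whole number of MXU tiles."""
    return math.ceil(dim / mxu) * mxu


def mxu_utilization(rows, cols, mxu=128):
    """Fraction of the padded [rows, cols] weight tiles that holds real data."""
    return (rows * cols) / (padded(rows, mxu) * padded(cols, mxu))


print((256 // 128) ** 2)  # FLOPs/cycle ratio of a 256x256 array to a 128x128 one
for mxu in (128, 256):
    print(mxu, mxu_utilization(128, 128, mxu), mxu_utilization(4096, 1000, mxu))
```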
This blog post has another excellent animation of a systolic array multiplication for a fixed weight matrix.
Appendix C: TPU internals
Scalar Core
The TPU scalar core processes all of the instructions and executes all of the transfers from HBM into vector memory (VMEM). The scalar core is also responsible for fetching instructions for the VPU, MXU and XLU components of the chip. One side-effect of this is that each core of the TPU is only capable of creating one DMA request per cycle.
To put this in context, a single scalar core controls a VPU consisting of 2048 ALUs, 4 MXUs, 2 XLUs, and multiple DMA engines. This highly skewed ratio of control to compute is a source of hardware efficiency, but it also limits the ability to do data-dependent vectorization in any interesting way.
VPU
The TPU vector core consists of a two-dimensional vector machine (the VPU) that performs vector operations like vadd (vector addition) or vmax (elementwise max), plus a set of vector registers called VREGs that hold data for the VPU and MXU. The VPU is effectively a 2D vector arithmetic unit of shape $(8,128)$, where the 128 dimension is referred to as a lane and the dimension of 8 is referred to as a sublane. Each (lane, sublane) pair on v4 contains 2 standard floating-point and integer ALUs. From a software point of view, this creates the appearance of an $8 \times 128$ vector unit with a total of 2048 floating point adders in v4. TPU v4 has 32 VREGs of size $(8,128)$, which the VPU loads from and writes to.
The VPU executes most arithmetic instructions (like vadd, or vector add) in one cycle in each of its ALUs, with a latency of 2 cycles, so e.g. in v5 you can add 4 pairs of f32 values together from VREGs in each cycle. A typical VPU instruction might look like {v2 = vadd.8x128.f32 v0, v1}, where v0 and v1 are input VREGs and v2 is an output VREG.
All lanes and sublanes execute the same program every cycle in a pure SIMD manner, but each ALU can perform a different operation. So we can e.g. process 1 vadd and 1 vsub in a single cycle, each of which operates on two full VREGs and writes the output to a third.
```
- Part 3: All the Transformer math you need to know
```jsx
All the Transformer Math You Need to Know
Part 4 of How To Scale Your Model (Part 3: Sharding | Part 5: Training)
Here we'll do a quick review of the Transformer architecture, specifically how to calculate FLOPs, bytes, and other quantities of interest.
Counting Dots
Let's start with vectors $\boldsymbol{x}, \boldsymbol{y}$ and matrices $\boldsymbol{A}, \boldsymbol{B}$ of the following shapes:
\begin{tabular}{cc}
array & shape \\
\hline$x$ & {$[\mathrm{P}]$} \\
$y$ & {$[\mathrm{P}]$} \\
$A$ & {$[\mathrm{~N} \mathrm{P}]$} \\
$B$ & {$[\mathrm{P} \mathrm{M}]$} \\
\hline
\end{tabular}
- A dot product of $\boldsymbol{x} \cdot \boldsymbol{y}$ requires $\boldsymbol{P}$ adds and multiplies, or $2 \boldsymbol{P}$ floating-point operations total.
- A matrix-vector product $A x$ does $N$ dot-products along the rows of $A$, for $2 N P$ FLOPs.
- A matrix-matrix product $A B$ does $M$ matrix-vector products, one for each column of $B$, for $2 N P M$ FLOPs total.
- In general, suppose we have two higher-dimensional arrays $C$ and $D$, where some dimensions are CONTRACTING and some are BATCHING (e.g. $C[G H I J K L]$, $D[G H M N K L]$). The FLOPs cost of this contraction is two times the product of all of the $C$ and $D$ dimensions, where the batch and contraction dimensions are only counted once (e.g. $2\,GHIJMNKL$). Note that a dimension is only batching if it occurs in both multiplicands. (Note also that the factor of 2 won't apply if there are no contracting dimensions and this is just an elementwise product.)
\begin{tabular}{ccc}
Operation & FLOPs & Data \\
\hline$x \cdot y$ & $2 P$ & $2 P$ \\
$A x$ & $2 N P$ & $N P+P$ \\
$A B$ & $2 N P M$ & $N P+P M$ \\
{$\left[c_0, \ldots, c_N\right] \cdot\left[d_0, \ldots, d_N\right]$} & $2 \prod c_i \times \prod_{\substack{d_j \notin \mathrm{BATCH} \\ d_j \notin \mathrm{CONTRACT}}} d_j$ & $\prod c_i+\prod d_j$ \\
\hline
\end{tabular}
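The counting rule above can be written as a small function over einsum-style subscripts. This is a sketch; the function name and subscript convention are ours. Every distinct dimension is counted once, and the factor of 2 applies only when some dimension is contracted away.

```python
from math import prod


def einsum_flops(spec, sizes):
    """FLOPs for a two-operand contraction written like 'ghijkl,ghmnkl->ghijmn'."""
    inputs, output = spec.split("->")
    lhs, rhs = inputs.split(",")
    dims = set(lhs) | set(rhs)        # each distinct dimension counted once
    contracted = dims - set(output)   # summed away, so each term is a multiply plus an add
    factor = 2 if contracted else 1   # pure elementwise products have no adds
    return factor * prod(sizes[d] for d in dims)


sizes = dict(zip("ghijklmn", (2, 3, 4, 5, 6, 7, 8, 9)))
print(einsum_flops("ghijkl,ghmnkl->ghijmn", sizes))  # 2 * G*H*I*J*K*L*M*N
```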
Make note of the fact that for a matrix-matrix multiply, the compute scales cubically $O\left(N^3\right)$ while the data transfer only scales quadratically $O\left(N^2\right)$ - this means that as we scale up our matmul size, it becomes easier to hit the compute-saturated limit. This is extremely unusual, and explains in large part why we use architectures dominated by matrix multiplication - they're amenable to being scaled!
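The same point in numbers: for a bf16 $[N, N] \cdot [N, N]$ matmul that reads both inputs and writes the output once, the intensity is $2N^3 / 6N^2 = N/3$ FLOPs per byte, so it grows linearly with $N$. By that measure it passes the TPU v5e's roughly 243 FLOPs/byte (1.97e14 FLOPs/s over 8.1e11 bytes/s) once $N$ exceeds about 730. This is a sketch; the function name is ours.

```python
def square_matmul_intensity(n, bytes_per_elem=2):
    """FLOPs per byte for [n, n] @ [n, n]: read two inputs, write one output (bf16 by default)."""
    flops = 2 * n ** 3
    nbytes = 3 * n * n * bytes_per_elem
    return flops / nbytes


for n in (128, 1024, 8192):
    print(n, round(square_matmul_intensity(n), 1))
```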
Forward and reverse FLOPs
During training, we don't particularly care about the result of a given matrix multiply; we really care about its derivative. That means we do significantly more FLOPs during backpropagation.
If we imagine $\mathbf{B}$ is just one matrix in a larger network and $\mathbf{A}$ are our input activations with $\mathbf{C}=\mathbf{A}\mathbf{B}$, the derivative of the loss $\mathbf{L}$ with respect to $\mathbf{B}$ is given by the chain rule:
$$
\frac{\partial L}{\partial B}=\frac{\partial L}{\partial C} \frac{\partial C}{\partial B}=A^T\left(\frac{\partial L}{\partial C}\right)
$$
which is an outer product and requires $2 N P M$ FLOPs to compute (since it contracts over the $N$ dimension). Likewise, the derivative of the loss with respect to $\mathbf{A}$ is
$$
\frac{\partial L}{\partial A}=\frac{\partial L}{\partial C} \frac{\partial C}{\partial A}=\left(\frac{\partial L}{\partial C}\right) B^T
$$
which again requires $2 N P M$ FLOPs, since $\partial L / \partial C$ is a (co-)vector of size $[N, M]$ and the product contracts over $M$. While this quantity isn't the derivative with respect to a parameter, it's used to compute derivatives for earlier layers of the network (just as $\partial L / \partial C$ is used to compute $\partial L / \partial B$ above).
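Here is a numerical sanity check of both formulas, as a sketch in plain Python. Choosing the loss $L = \sum_{n,m} G_{nm} (AB)_{nm}$ makes $\partial L / \partial C$ exactly $G$. We can then compare $A^T G$ and $G B^T$ against central finite differences. The loss, shapes, and helper names are illustrative.

```python
import random


def matmul(p, q):
    return [[sum(p[i][k] * q[k][j] for k in range(len(q))) for j in range(len(q[0]))]
            for i in range(len(p))]


def transpose(p):
    return [list(col) for col in zip(*p)]


def numerical_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. every entry of x (edited in place)."""
    grad = [[0.0] * len(x[0]) for _ in x]
    for i in range(len(x)):
        for j in range(len(x[0])):
            orig = x[i][j]
            x[i][j] = orig + eps
            up = f()
            x[i][j] = orig - eps
            down = f()
            x[i][j] = orig  # restore the entry
            grad[i][j] = (up - down) / (2 * eps)
    return grad


random.seed(0)
N, P, M = 3, 4, 2
A = [[random.uniform(-1, 1) for _ in range(P)] for _ in range(N)]  # activations [N, P]
B = [[random.uniform(-1, 1) for _ in range(M)] for _ in range(P)]  # weights [P, M]
G = [[random.uniform(-1, 1) for _ in range(M)] for _ in range(N)]  # plays dL/dC [N, M]


def loss():
    """L = sum(G * (A @ B)), chosen so that dL/dC is exactly G."""
    C = matmul(A, B)
    return sum(G[i][j] * C[i][j] for i in range(N) for j in range(M))


dL_dB = matmul(transpose(A), G)  # A^T (dL/dC), shape [P, M]
dL_dA = matmul(G, transpose(B))  # (dL/dC) B^T, shape [N, P]
err_B = max(abs(u - v) for r, s in zip(dL_dB, numerical_grad(loss, B)) for u, v in zip(r, s))
err_A = max(abs(u - v) for r, s in zip(dL_dA, numerical_grad(loss, A)) for u, v in zip(r, s))
print(err_B, err_A)
```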
Adding these up, we see that during training, we have a total of 6NPM FLOPs, compared to 2NPM during inference: 2NPM in the forward pass, 4NPM in the backward pass. Since PM is the number of parameters in the matrix, this is the simplest form of the famous 6 * num parameters * num tokens approximation of Transformer FLOPs during training: each token requires 6 * num parameters FLOPs. We'll show a more correct derivation below.
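The bookkeeping for one weight matmul is sketched below: three products of $2NPM$ FLOPs each, with $N$ read as the number of tokens. The function name and example sizes are ours.

```python
def matmul_train_flops(tokens, p, m):
    """(forward, forward + backward) FLOPs for one [tokens, P] @ [P, M] weight matmul."""
    forward = 2 * tokens * p * m       # C = A B
    grad_weights = 2 * tokens * p * m  # dL/dB = A^T (dL/dC), contracts over tokens
    grad_inputs = 2 * tokens * p * m   # dL/dA = (dL/dC) B^T, contracts over M
    return forward, forward + grad_weights + grad_inputs


tokens, p, m = 4096, 8192, 32768
fwd, total = matmul_train_flops(tokens, p, m)
print(total / (p * m * tokens))  # FLOPs per parameter per token: 6.0
```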
Transformer Accounting
Transformers are the future. Well, they're the present at least. Maybe a few years ago, they were one of many architectures. But today, it's worth knowing pretty much every detail of the architecture. We won't reintroduce the architecture but this blog and the original Transformer paper may be helpful references.
Here’s a basic diagram of the Transformer decoder architecture:
Figure: this diagram shows one layer of a standard Transformer and flows from top to bottom. We use a single-letter convention to describe the shapes and layouts of arrays in a Transformer, again showing contracting dimensions in red and batched dimensions in blue. In a given operation, the input shape is given on the top left and the parameter shape on the top right, with the resulting shape below; e.g. BTD is the input shape for the gating einsum and DF is the weight shape.
Note [gating einsum]: The diagram above uses a "gating einsum" [1], where we split the up-projection matrix into two matrices ($W_{\text{In}1}$ and $W_{\text{In}2}$ above) whose outputs are elementwise multiplied as a kind of "gating function". Not all LLMs use this, so you will sometimes see a single $W_{\text{In}}$ matrix and a total MLP parameter count of $2DF$ instead of $3DF$. Typically in this case, $D$ and $F$ will be scaled up to keep the parameter count the same as in the 3-matrix case. With that said, some form of gating einsum is used by LLaMA, DeepSeek, and many other models.
Note 2 [MHA attention]: With self-attention, $T$ and $S$ are the same, but for cross-attention they may differ. With vanilla Multi-Head Attention (MHA), $N$ and $K$ are the same, while for Multi-Query Attention (MQA) [2] $K = 1$, and for Grouped MQA (GMQA) [3] $K$ merely has to divide $N$.
Global FLOPs and Params Calculation
Below, we compute per-layer FLOPs to avoid having to stick factors of $L$ everywhere.
MLPs
The MLPs of a Transformer typically consist of 2 input matmuls that are element-wise combined and a single output matmul:
\begin{tabular}{|l|l|l|}
\hline operation & train FLOPs & params \\
\hline $A[B, T, D] \cdot W_{\text{in1}}[D, F]$ & $6BTDF$ & $DF$ \\
\hline $A[B, T, D] \cdot W_{\text{in2}}[D, F]$ & $6BTDF$ & $DF$ \\
\hline $\sigma(A_{\text{in1}})[B, T, F] * A_{\text{in2}}[B, T, F]$ & $O(BTF)$ & \\
\hline $A[B, T, F] \cdot W_{\text{out}}[F, D]$ & $6BTDF$ & $DF$ \\
\hline total & $\approx 18BTDF$ & $3DF$ \\
\hline
\end{tabular}
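A small Python sketch of the MLP totals, assuming the gated three-matrix layout in the table. The ungated comparison scales $F$ by 3/2, which is just one illustrative way to match the gated parameter count; the note above only says that $D$ and $F$ are scaled up.

```python
def mlp_params(D: int, F: int, gated: bool = True) -> int:
    """Weights in one MLP block: W_in1, W_in2, W_out if gated, else W_in, W_out."""
    return (3 if gated else 2) * D * F


def mlp_train_flops(B: int, T: int, D: int, F: int, gated: bool = True) -> int:
    """6 FLOPs per weight per token (2 forward + 4 backward); ignores the O(BTF) elementwise op."""
    return 6 * B * T * mlp_params(D, F, gated)


D, F = 4096, 16384
print(mlp_train_flops(1, 1, D, F) == 18 * D * F)                     # True
# An ungated MLP with F scaled by 3/2 has the same parameter count as the gated one.
print(mlp_params(D, F * 3 // 2, gated=False) == mlp_params(D, F))   # True
```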
Attention
For the generic grouped-query attention case with different $\mathbf{Q}$ and $\mathbf{K V}$ head numbers, let us assume equal head dimension H for $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ projections, and estimate the cost of the $\mathbf{Q K V O}$ matmuls:
\begin{tabular}{|l|l|l|}
\hline operation & train FLOPs & params \\
\hline $A[B, T, D] \cdot W_Q[D, N, H]$ & $6BTDNH$ & $DNH$ \\
\hline $A[B, T, D] \cdot W_K[D, K, H]$ & $6BTDKH$ & $DKH$ \\
\hline $A[B, T, D] \cdot W_V[D, K, H]$ & $6BTDKH$ & $DKH$ \\
\hline $A[B, T, N, H] \cdot W_O[N, H, D]$ & $6BTDNH$ & $DNH$ \\
\hline total & $12BTD(N+K)H$ & $2D(N+K)H$ \\
\hline
\end{tabular}
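The same bookkeeping for the QKVO projections, as a sketch. The head counts and head dimension in the example are illustrative assumptions, not taken from any particular model.

```python
def qkvo_params(D: int, N: int, K: int, H: int) -> int:
    """W_Q[D, N, H] + W_K[D, K, H] + W_V[D, K, H] + W_O[N, H, D]."""
    return D * N * H + 2 * D * K * H + N * H * D


def qkvo_train_flops(B: int, T: int, D: int, N: int, K: int, H: int) -> int:
    """6 training FLOPs per projection weight per token."""
    return 6 * B * T * qkvo_params(D, N, K, H)


# Assumed example: 32 query heads, 8 KV heads, head dim 128, so D = N * H.
B, T, D, N, K, H = 8, 2048, 4096, 32, 8, 128
print(qkvo_params(D, N, K, H) == 2 * D * (N + K) * H)                          # True
print(qkvo_train_flops(B, T, D, N, K, H) == 12 * B * T * D * (N + K) * H)      # True
```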
The dot-product attention operation is more subtle, effectively being a $TH \cdot HS$ matmul batched over the $B, K$ dimensions, a softmax, and a $TS \cdot SH$ matmul again batched over the $B, K$ dimensions. Here $G = N/K$ is the number of query heads per KV head, so $N = KG$. We highlight the batched dims in blue:
$$
\begin{array}{ll}
\text{operation} & \text{train FLOPs} \\
\hline
Q[B, T, K, G, H] \cdot K[B, S, K, H] & 6BTSKGH = 6BTSNH \\
\operatorname{softmax}_S L[B, T, S, K, G] & O(BTSKG) = O(BTSN) \\
S[B, T, S, K, G] \cdot V[B, S, K, H] & 6BTSKGH = 6BTSNH \\
\hline
\text{total} & \approx 12BTSNH = 12BT^2NH
\end{array}
$$
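A sketch of the dot-product attention count, writing $N = KG$. It treats both einsums as batched over $B$ and $K$, ignores the softmax, and uses made-up example shapes.

```python
def attention_train_flops(B: int, T: int, S: int, K: int, G: int, H: int) -> int:
    """Training FLOPs of the two attention einsums, ignoring the softmax.

    Q[B, T, K, G, H] . K[B, S, K, H] -> logits[B, T, S, K, G]   (contracts H)
    probs[B, T, S, K, G] . V[B, S, K, H] -> out[B, T, K, G, H]  (contracts S)
    Each costs 2 * B*T*S*K*G*H forward, and training triples that.
    """
    qk = 6 * B * T * S * K * G * H
    pv = 6 * B * T * S * K * G * H
    return qk + pv


B, T, K, G, H = 4, 4096, 8, 4, 128   # assumed shapes
N = K * G                            # total query heads
print(attention_train_flops(B, T, T, K, G, H) == 12 * B * T**2 * N * H)   # True (self-attention, S = T)
```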
Other Operations
There are several other operations happening in a Transformer. Layernorms are comparatively cheap and can be ignored for first-order cost estimates. There is also the final enormous (though not per-layer) unembedding matrix multiply.
\begin{tabular}{ccc}
operation & train FLOPs & params \\
\hline layernorm $_D A[B, T, D]$ & $O(B T D)$ & $D$ \\
$A[B, T, D] \cdot W_{\text {unembed }}[D, V]$ & $6 B T D V$ & $D V$
\end{tabular}
General rule of thumb for Transformer FLOPs
If we neglect the cost of dot-product attention for shorter-context training, then the total FLOPs across all layers is
$$
\begin{aligned}
(18 B T D F+12 B T D(N+K) H) L & =6 * B T *(3 D F+2 D(N+K) H) L \\
& =6 * \text { num tokens } * \text { parameter count }
\end{aligned}
$$
This leads to a famous rule of thumb for estimating dense Transformer FLOPs, ignoring the attention FLOPs. (Unembedding is another simple matmul with $6BTDV$ FLOPs and $DV$ params, and follows the same rule of thumb.)
Fractional cost of attention with context length
If we do account for dot-product attention above and assume $F=4 D, D=N H$ (as is typical) and $N=K$ :
$$
\frac{\text { attention FLOPs }}{\text { matmul FLOPs }}=\frac{12 B T^2 N H}{18 B T D F+24 B T D N H}=\frac{12 B T^2 D}{4 * 18 B T D^2+24 B T D^2}=\frac{12 B T^2 D}{96 B T D^2}=\frac{T}{8 D}
$$
So the takeaway is that dot-product attention FLOPs only become dominant during training once $T > 8D$. For $D \sim 8\text{k}$, this would be $\sim 64\text{K}$ tokens. This makes some sense, since it means that as the MLP size increases, the attention FLOPs become less critical. For large models, the quadratic cost of attention is not actually a huge obstacle to longer-context training. However, for smaller models, even e.g. Gemma-27B, $D = 4608$, which means attention becomes dominant around 37k tokens ($8 \cdot 4608 = 36{,}864$). Flash Attention also helps alleviate the cost of long contexts, which we discuss briefly in Appendix A.
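A quick check of the $T/8D$ ratio under the same assumptions ($F = 4D$, $D = NH$, $N = K$). The function name is our own; $B$ cancels, so it is omitted.

```python
def attention_to_matmul_ratio(T: int, D: int) -> float:
    """Dot-product-attention FLOPs over MLP + QKVO FLOPs, with F = 4D, NH = D, N = K."""
    attention = 12 * T**2 * D                         # 12 B T^2 N H with NH = D
    matmuls = 18 * T * D * (4 * D) + 24 * T * D * D   # 18 B T D F + 24 B T D N H
    return attention / matmuls


for D in (8192, 4608):                                # D ~ 8k, and Gemma-27B's D
    print(D, 8 * D, attention_to_matmul_ratio(8 * D, D))
# 8192 65536 1.0
# 4608 36864 1.0
```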
Miscellaneous Math
Sparsity and Mixture-of-Experts
We'd be remiss not to briefly discuss Mixture of Experts (MoE) models [4], which replace the single dense MLP blocks in a standard Transformer with a set of independent MLPs that can be dynamically routed between. To a first approximation, an MoE is just a normal dense model with $E$ MLP blocks per layer instead of one. Each token activates $k$ of these experts, typically $k = 2$. Compared with the dense version, this increases the total parameter count by a factor of $O(E)$, while multiplying the number of activated parameters per token by $k$.
Figure: an example MoE layer with $n$ experts. The gating network (router) sends each token to $k$ of them, and the outputs of those $k$ MLPs are summed. Our parameter count is $n$ times the size of each expert, but only $k$ are used for each token. Source.
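A sketch of the MoE parameter bookkeeping, assuming every expert is a gated MLP of the same shape ($3DF$ weights) and ignoring the small router weights. The example values of $E$ and $k$ are hypothetical.

```python
def moe_mlp_params(D: int, F: int, E: int, k: int) -> tuple[int, int]:
    """(total, activated per token) weights of one MoE layer of gated-MLP experts.

    Assumes each expert has the dense 3DF layout and ignores the router.
    """
    expert = 3 * D * F
    return E * expert, k * expert


total, active = moe_mlp_params(D=4096, F=16384, E=8, k=2)   # hypothetical sizes
print(total // active)   # 4: we store E / k times more weights than each token uses
```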
Compared to a dense model, an MoE introduces new comms, primarily two AllToAlls (one before and one after the MoE block) that route tokens to the correct expert and bring them back to their home device.${ }^1$ However, as we saw in the previous section, the cost of each AllToAll is only $1/4$ that of a comparable AllGather along a single axis (for a bidirectional ring).
Gradient checkpointing
Backpropagation as an algorithm trades memory for compute. Instead of a backward pass requiring $O(n_{\text{layers}}^2)$ FLOPs, it requires $O(n_{\text{layers}})$ memory, saving all intermediate activations generated during the forward pass. While this is better than quadratic compute, it's incredibly expensive memory-wise: a model with $B \cdot T = 4\text{M}$ (4M total tokens per batch), $L = 64$, and $D = 8192$ that avoids all unnecessary backward-pass compute would have to save roughly $2 \cdot 20 \cdot B \cdot T \cdot D \cdot L = 84$ TB of activations in bfloat16. The factor of 20 comes from (roughly) counting every intermediate node in the Transformer diagram above, since e.g.
$$
\begin{gathered}
f(x)=\exp (g(x)) \\
\frac{d f}{d x}=\exp (g(x)) \cdot \frac{d g}{d x}
\end{gathered}
$$
so to avoid recomputing we need to save both $g(x)$ and $\exp(g(x))$ from the forward pass. To avoid saving this much memory, we can choose to save only some fraction of the intermediate activations. Here are a few strategies we use.
- Block remat: only save the input to each layer. This is the most aggressive method we use and saves only 1 checkpoint per layer, meaning we'd only save 4.2 TB in the example above. This forces us to repeat essentially all forward-pass FLOPs in the backward pass, increasing our FLOPs from $6 \cdot \text{num params} \cdot \text{num tokens}$ to roughly $8 \cdot \text{num params} \cdot \text{num tokens}$.
- Big matmuls only: another simple policy is to save only the outputs of large matmuls. This lets us avoid recomputing any large matmuls during the backward pass, but still makes us recompute other activation functions and parts of attention. This reduces the roughly 20 saved activations per layer to closer to 7.
This is by no means comprehensive. When using JAX, these are typically controlled by jax.remat / jax.checkpoint (you can read more here).
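A sketch of the activation-memory arithmetic above, assuming bf16 activations of shape $[B \cdot T, D]$ and treating each policy as saving a fixed number of such arrays per layer (about 20, 7, and 1, as in the text). The function name is ours, and real savings depend on the model and the remat policy.

```python
def saved_activation_bytes(tokens: int, D: int, L: int, saved_per_layer: int,
                           bytes_per_elem: int = 2) -> int:
    """Bytes for saved_per_layer arrays of shape [tokens, D] in each of L layers."""
    return bytes_per_elem * saved_per_layer * tokens * D * L


tokens, D, L = 4_000_000, 8192, 64   # B * T = 4M tokens, as in the example above
for policy, saved in [("save everything", 20), ("big matmuls only", 7), ("block remat", 1)]:
    print(f"{policy}: {saved_activation_bytes(tokens, D, L, saved) / 1e12:.1f} TB")
# save everything: 83.9 TB, big matmuls only: 29.4 TB, block remat: 4.2 TB
```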
Key-Value (KV) caching
As we'll see in Section 7, LLM inference has two key parts, prefill and generation.
- Prefill processes a long prompt and saves its attention activations in a Key-Value Cache (KV Cache) for use in generation, specifically the key-value projections in the attention block.
- Generation batches several of these KV caches together and samples tokens from each of them.
Each KV cache is then effectively an array of size $[2, S, L, K, H]$, where the 2 accounts for the keys and values. This is quite large! The total size of the Key-Value cache in int8 is $2SLKH$ bytes.
For a moderately-sized model with 8 k context length, 64 layers, and $K H=N H=D=8192$, this is $2 \cdot 8192 \cdot 64 \cdot 8192=8 \mathrm{GiB}$. You can see why we would want to use GMQA with $K \ll N$.
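A sketch of the KV cache size for a single sequence. The split $K = 64$, $H = 128$ is an assumption (only the product $KH = 8192$ is given above), and the 8-KV-head variant is a hypothetical GMQA configuration for comparison.

```python
def kv_cache_bytes(S: int, L: int, K: int, H: int, bytes_per_elem: int = 1) -> int:
    """Size of one sequence's KV cache, an array of shape [2, S, L, K, H] (int8 by default)."""
    return 2 * S * L * K * H * bytes_per_elem


mha = kv_cache_bytes(S=8192, L=64, K=64, H=128)   # K * H = 8192
gqa = kv_cache_bytes(S=8192, L=64, K=8, H=128)    # hypothetical 8 KV heads
print(mha / 2**30, gqa / 2**30)                   # 8.0 1.0 (GiB)
```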
What Should You Take Away from this Section?
- The overall parameters and FLOPs of a Transformer are fairly easy to calculate, and are summarized here, assuming MHA (with batch size B, vocab size V, a sequence of length T, $\mathrm{D}=\mathrm{d}_{\text {model }}$, and $\mathrm{F}=\mathrm{d}_{\mathrm{ff}}$ ):
\begin{tabular}{|l|l|l|}
\hline Component & Params per layer & Training FLOPs per layer \\
\hline MLP & 3DF & 18BTDF \\
\hline Attention & 4DNH & $24BTDNH + 12BT^2NH$ \\
\hline Other & D & BTD \\
\hline Vocab & DV (total, not per-layer) & 12BTDV \\
\hline
\end{tabular}
- The parameter count of the MLP block dominates the total parameter count and the MLP block also dominates the FLOPs budget as long as the sequence length $T<8 D$.
- The total FLOPs budget during training is well approximated by $6 \cdot$ num_params $\cdot$ num_tokens for reasonable context lengths.
- During inference, our KV caches are roughly $2 \cdot S \cdot L \cdot N \cdot H$ bytes per cache in int8 (with MHA, $K = N$), although architectural modifications such as GMQA can often reduce this.
A Few Problems to Work
Question 1: How many parameters does a model with $D=4096, F=4 \cdot D, V=32,000$, and $L=64$ have? What fraction of these are attention parameters? How large are our KV caches per token? You can assume $N \cdot H=D$ and multi-head attention with int8 KVs.
1. The total parameter count is roughly $L \cdot (3DF + 4DNH + D) + 2DV$. For the given numbers, this is
$$
64 \cdot (3 \cdot 4\text{e}3 \cdot 16\text{e}3 + 4 \cdot 4\text{e}3 \cdot 4\text{e}3 + 4\text{e}3) + 2 \cdot 4\text{e}3 \cdot 32\text{e}3 \approx 16\text{e}9, \text{ or } 16\text{B}
$$
parameters.
2. The ratio of attention parameters to total parameters is, in general, $4DNH / (4DNH + 3DF) = 4D^2 / (4D^2 + 12D^2) = 1/4$, so roughly $1/4$ of the parameters are attention parameters.
3. Per token, our KV caches are $2 \cdot L \cdot N \cdot H = 2 \cdot 64 \cdot 4096$ bytes in int8, which is 512 KiB per token.
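The arithmetic for Question 1, as a Python sketch. Using the exact $D = 4096$ rather than the rounded 4e3 gives about 17.4B parameters, a bit above the 16B estimate.

```python
D, F, V, L = 4096, 4 * 4096, 32_000, 64
N_H = D   # N * H = D, multi-head attention

per_layer = 3 * D * F + 4 * D * N_H + D    # MLP + QKVO + layernorm
total = L * per_layer + 2 * D * V          # plus embedding and unembedding
attn_fraction = (4 * D * N_H) / (4 * D * N_H + 3 * D * F)
kv_bytes_per_token = 2 * L * N_H           # keys + values, 1 byte each in int8

print(f"{total / 1e9:.1f}B params, attention fraction {attn_fraction}, "
      f"{kv_bytes_per_token // 1024} KiB/token")
# 17.4B params, attention fraction 0.25, 512 KiB/token
```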
Question 2: How many total FLOPs are required to perform $A[B_X, D_Y] \cdot_D W[D_Y, F]$ on a mesh with axes {'X': 4, 'Y': 8, 'Z': 4}? How many FLOPs are performed by each TPU?
The total "theoretical" FLOPs of the operation is $2 \cdot B \cdot D \cdot F$. However, because the computation isn't sharded across the $Z$ axis, it is repeated $Z$ times, so we actually perform $2 \cdot B \cdot D \cdot F \cdot Z$ FLOPs in total. Since the computation is sharded across the other axes, the per-device total is roughly $2 \cdot B \cdot D \cdot F / (X \cdot Y)$.
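A sketch of the FLOP accounting for Question 2, assuming the partial sums over the sharded $D$ axis are reduced separately (that communication is not counted as FLOPs here). The function and its arguments are hypothetical.

```python
from math import prod


def sharded_matmul_flops(B: int, D: int, F: int, mesh: dict[str, int],
                         sharded_axes: tuple[str, ...]) -> tuple[int, int]:
    """(total, per-device) FLOPs of A[B_X, D_Y] ._D W[D_Y, F] on a device mesh.

    Work is divided only over the mesh axes that shard the operands; every other
    mesh axis (here Z) holds replicas that redo the same computation.
    """
    num_devices = prod(mesh.values())
    split = prod(mesh[axis] for axis in sharded_axes)
    per_device = 2 * B * D * F // split
    return per_device * num_devices, per_device


mesh = {"X": 4, "Y": 8, "Z": 4}
total, per_device = sharded_matmul_flops(1024, 8192, 32768, mesh, ("X", "Y"))
print(total // (2 * 1024 * 8192 * 32768), per_device * 32 == 2 * 1024 * 8192 * 32768)  # 4 True
```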
Question 3: How many FLOPs are involved in performing
$$
A[I, J, K, L] * B[I, J, M, N, O] \rightarrow C[K, L, M, N, O] ?
$$
Following the rule above, we have $I$ and $J$ as contracting dimensions and $K$, $L$, $M$, $N$, and $O$ as non-contracting dimensions. We have no "batching dimensions", so this is just $2 \cdot I \cdot J \cdot K \cdot L \cdot M \cdot N \cdot O$, i.e. twice the product of all the axis sizes. If we had a shared (batching) axis, it would only be counted once.
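The counting rule can be sketched as a tiny helper that takes an einsum subscript string. It is our own illustration (not a library function) and assumes a plain two-operand contraction in which every distinct axis is iterated once.

```python
from math import prod


def einsum_flops(spec: str, sizes: dict[str, int]) -> int:
    """FLOPs of a two-operand einsum like "ijkl,ijmno->klmno" (hypothetical helper).

    Every distinct axis, whether contracting, batching, or free, is iterated once,
    and each iteration is a multiply plus an add.
    """
    inputs = spec.split("->")[0]
    axes = set(inputs.replace(",", ""))
    return 2 * prod(sizes[a] for a in axes)


sizes = {"i": 2, "j": 3, "k": 4, "l": 5, "m": 6, "n": 7, "o": 8}
print(einsum_flops("ijkl,ijmno->klmno", sizes))                      # 80640
print(einsum_flops("bij,bjk->bik", {"b": 4, "i": 2, "j": 3, "k": 5}))  # 240: b counted once
```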
Question 4: What is the arithmetic intensity of self-attention (ignoring the Q/K/V/O projections)? Give the answer as a function of the $Q$ and $K V$ lengths $T$ and $S$. At what context length is attention FLOPs-bound? Given the HBM bandwidth of our TPUs, plot the effective relative cost of attention to the FFW block as the context length grows.
Self-attention requires loading the $Q, K$, and $V$ activations, then computing $\operatorname{softmax}(Q \cdot K) \cdot V$, then writing the result back to HBM. This will be done with Flash Attention so there are some caveats to this math, but basically in bf16 self-attention performs
$$
\begin{gathered}
Q[B, T, N, H] \rightarrow_{\text{reshape}} Q[B, T, K, G, H] \cdot K[B, S, K, H] \rightarrow O[B, T, S, K, G] \\
U = \operatorname{softmax}_S(O[B, T, S, K, G]) \\
U[B, T, S, K, G] \cdot V[B, S, K, H] \rightarrow X[B, T, K, G, H]
\end{gathered}
$$
So our total bytes (reading $Q$ and writing the same-sized output $X$, plus reading $K$ and $V$, at 2 bytes per bf16 element) is
$$
2 \cdot \operatorname{sizeof}(Q) + 2 \cdot \operatorname{sizeof}(K \text{ or } V) = 4BTNH + 4BSKH = 4BHK \cdot (TG + S),
$$
the total FLOPs is $4BTSNH + O(BTSN)$, and the arithmetic intensity is $4BTSKGH / (4BHK \cdot (TG + S))$.
So basically, during prefill we have $S = T$, giving an arithmetic intensity of $4BT^2KGH / (4BHKT \cdot (G+1)) = TG/(G+1) = O(T)$. During generation, $T = 1$, so we have $4BSKGH / (4BHK \cdot (G+S)) = SG/(G+S) \rightarrow G$ assuming $S$ is very large. Depending on how you interpret the question, during prefill or training self-attention is compute-bound once $TG/(G+1) > 240$, i.e. around $S = T \approx 240$ for large $G$ (or $T > 480$ for MHA, where $G = 1$), assuming no sequence sharding. During generation, we are never compute-bound because $G$ is small. Nonetheless, you can see that increasing $G$ brings us closer to being compute-bound.
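A sketch of these intensity formulas. The 240 FLOPs/byte threshold is the TPU bf16 critical intensity used throughout this section; the example $T$, $S$, and $G$ values are illustrative.

```python
def attention_intensity(T: int, S: int, G: int) -> float:
    """FLOPs per byte of bf16 self-attention: 4BTSKGH / (4BHK(TG + S)) = TSG / (TG + S)."""
    return T * S * G / (T * G + S)


def min_prefill_len(G: int, critical: int = 240) -> int:
    """Smallest prefill length T (with S = T) whose intensity TG / (G + 1) exceeds critical."""
    return critical * (G + 1) // G + 1


print(attention_intensity(T=4096, S=4096, G=8))   # prefill: 4096 * 8 / 9, about 3641
print(attention_intensity(T=1, S=32768, G=8))     # generation: just under G = 8
print(min_prefill_len(1), min_prefill_len(8))     # 481 271
```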
Question 5: At what sequence length are self-attention FLOPs equal to the QKVO projection FLOPs?
This is purely a question of when $24BTDNH = 12BT^2NH$. Simplifying, we get $T = 2D$, so e.g. for $D = 4096$, this is $T = 8192$. This tells us that for most reasonable context lengths, the projection matmul FLOPs are greater.
Question 6: Say we only save the output of each of the 7 main matmuls in a Transformer layer during our forward pass ($Q$, $K$, $V$, $O$ + the three FFW matrices). How many extra FLOPs do we need to "rematerialize" during the backward pass?
Question 7: DeepSeek v3 says it was trained for 2.79M H800 hours on 14.8T tokens (source). Given that it has 37B activated parameters, roughly what hardware utilization did they achieve?
Hint: note that they used FP8 FLOPs without structured sparsity.
From the spec sheet here, we find 3,026 TFLOPs/s of FP8 performance with sparsity, or typically half this ($1.513\text{e}15$ FLOPs/s) without sparsity. 2.79M H800 hours means $2.79\text{e}6 \cdot 1.513\text{e}15 \cdot 60 \cdot 60 = 1.52\text{e}25$ total FLOPs. Given the activated parameter count of 37B, this training run should have used about $6 \cdot 37\text{e}9 \cdot 14.8\text{e}12 = 3.3\text{e}24$ FLOPs. That means the FLOPs utilization is about $3.3\text{e}24 / 1.52\text{e}25 = 21.7\%$.
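The same estimate in a few lines of Python, assuming the dense FP8 peak is half the 3,026 TFLOPs/s sparse spec figure. Without intermediate rounding the result comes out at about 21.6%.

```python
gpu_hours = 2.79e6
peak_fp8_dense = 3026e12 / 2        # H800 FP8 FLOPs/s: sparse spec figure halved
available = gpu_hours * 3600 * peak_fp8_dense
used = 6 * 37e9 * 14.8e12           # 6 * activated params * training tokens
print(f"{available:.3g} {used:.3g} {used / available:.1%}")
# 1.52e+25 3.29e+24 21.6%
```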
Question 8: Mixture of Experts (MoE) models have $E$ copies of a standard dense MLP block, and each token activates $k$ of these experts. What batch size in tokens is required to be compute-bound for an MoE with weights in int8 on TPU v5e? For DeepSeek, which has 256 (routed) experts and $k=8$, what is this number?
Because we have $E$ experts, each treated as a single $D \times F$ matrix, with int8 weights we need to load $E \cdot D \cdot F$ bytes.
Because each token activates $k$ experts, we have $2 \cdot k \cdot B \cdot D \cdot F$ FLOPs. To be compute-bound with bfloat16 FLOPs, we need an arithmetic intensity over 240, which happens when $(2 \cdot k \cdot BDF) / EDF > 240$, or $k \cdot B / E > 120$.
Therefore, we need $B>120 \cdot E / k$ to be compute bound. For DeepSeek, this gives us $B>120 \cdot 256 / 8=3840$. This is a remarkably large batch size at generation time.
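A sketch of the Question 8 bound as a function, assuming each expert is a single $D \times F$ int8 matrix and the 240 FLOPs/byte critical intensity used above.

```python
def moe_min_batch(E: int, k: int, critical_intensity: float = 240.0,
                  weight_bytes: int = 1) -> float:
    """Tokens per batch above which an MoE layer is compute-bound.

    FLOPs = 2 * k * B * D * F and bytes loaded = E * D * F * weight_bytes, so we are
    compute-bound when 2 * k * B / (E * weight_bytes) > critical_intensity.
    """
    return critical_intensity * E * weight_bytes / (2 * k)


print(moe_min_batch(E=256, k=8))   # 3840.0
```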
```
- Part 4: **How to Parallelize a Transformer for Training**
```jsx
Here we discuss four main parallelism schemes used during LLM training: data parallelism, fully-sharded data parallelism (FSDP), tensor parallelism, and pipeline parallelism. For each, we calculate at what point we become bottlenecked by communication.
What Do We Mean By Scaling?
The goal of "model scaling" is to be able to increase the number of chips used for training or inference while achieving a proportional, linear increase in throughput (we call this strong scaling). While performance on a single chip depends on the trade-off between memory bandwidth and FLOPs, performance at the cluster level depends on hiding inter-chip communication by overlapping it with useful FLOPs. This is non-trivial, because increasing the number of chips increases the communication load while reducing the amount of per-device computation we can use to hide it. As we saw in Section 3, sharded matrix multiplications often require expensive AllGathers or ReduceScatters that can block the TPUs from doing useful work. The goal of this section is to find out when these become too expensive.
In this section, we'll discuss four common parallelism schemes: (pure) data parallelism, fully-sharded data parallelism (FSDP / ZeRO sharding), tensor parallelism (also known as model parallelism), and (briefly) pipeline parallelism. For each, we'll show what communication cost we incur and at what point that cost starts to bottleneck our compute cost. ${ }^1$ For this section, you can focus solely on inter-chip communication costs, since as long as we have a large enough single-chip batch size, the transfer of data from HBM to MXU is already overlapped with computation.
We'll use the following notation to simplify calculations throughout this section.
\begin{tabular}{|l|l|}
\hline Notation & Meaning (model parameters) \\
\hline D & $d_{\text{model}}$ (the hidden dimension/residual stream dim) \\
\hline F & $d_{\text{ff}}$ (the feed-forward dimension) \\
\hline B & Batch dimension (number of tokens in the batch; total, not per-device) \\
\hline T & Sequence length \\
\hline L & Number of layers in the model \\
\hline Notation & Meaning (hardware characteristic) \\
\hline C & FLOPs/s per chip \\
\hline W & Network bandwidth (bidirectional, often subscripted as e.g. $W_{\text{ici}}$ or $W_{\text{dcn}}$) \\
\hline X & Number of chips along mesh axis X \\
\hline Y & Number of chips along an alternate mesh axis, labeled Y \\
\hline Z & Number of chips along a third mesh axis, labeled Z \\
\hline
\end{tabular}
For simplicity's sake, we'll approximate a Transformer as a stack of MLP blocks - attention is a comparatively small fraction of the FLOPs for larger models as we saw in Section 4. We will also ignore the gating matmul, leaving us with the following simple structure for each layer:
Figure: a simplified Transformer layer. We treat each FFW block as a stack of two matrices, $W_{\text{in}}$: bf16[D, F] (up-projection) and $W_{\text{out}}$: bf16[F, D] (down-projection), with an input In: bf16[B, D].
Here are the 4 parallelism schemes we will discuss. Each scheme can be thought of as uniquely defined by a sharding for In, $W_{\text{in}}$, $W_{\text{out}}$, and Out in the above diagram.
1. Data parallelism: activations sharded along batch, parameters and optimizer state are replicated on each device. Communication only occurs during the backwards pass.
$$
\operatorname{In}\left[B_X, D\right] \cdot{ }_D W_{\text {in }}[D, F] \cdot{ }_F W_{\text {out }}[F, D] \rightarrow \operatorname{Out}\left[B_X, D\right]
$$
2. Fully-sharded data parallelism (FSDP or ZeRO-3): activations sharded along batch (like pure data parallelism), parameters sharded along the same mesh axis and AllGathered just-in-time before use in the forward pass. Optimizer state is also sharded along that mesh axis. This reduces duplicated memory.
$$
\operatorname{In}\left[B_X, D\right] \cdot{ }_D W_{\text {in }}\left[D_X, F\right] \cdot{ }_F W_{\text {out }}\left[F, D_X\right] \rightarrow \operatorname{Out}\left[B_X, D\right]
$$
3. Tensor parallelism (also called Megatron sharding or model parallelism): activations sharded along $D\left(d_{\text {model }}\right)$, parameters sharded along $F\left(d_{f f}\right)$. AllGather and ReduceScatter activations before and after each block. Compatible with FSDP.
$$
\operatorname{In}\left[B, D_Y\right] \cdot{ }_D W_{\text {in }}\left[D, F_Y\right] \cdot{ }_F W_{\text {out }}\left[F_Y, D\right] \rightarrow \operatorname{Out}\left[B, D_Y\right]
$$
4. Pipeline parallelism: weights sharded along the layer dimension, activations microbatched and rolled along the layer dimension. Communication between pipeline stages is minimal (just moving activations over a single hop). To abuse notation:
$$
\operatorname{In}\left[L_Z, B, D\right][i] \cdot{ }_D W_{\text {in }}\left[L_Z, D, F\right][i] \cdot{ }_F W_{\text {out }}\left[L_Z, F, D\right][i] \rightarrow \operatorname{Out}\left[L_Z, B, D\right][i]
$$
Data Parallelism
Syntax: $\operatorname{In}\left[B_X, D\right] \cdot{ }_D W_{\text {in }}[D, F] \cdot{ }_F W_{\text {out }}[F, D] \rightarrow \operatorname{Out}\left[B_X, D\right]$
When your model fits on a single chip with even a tiny batch size ($>240$ tokens, so as to be compute-bound), you should always use simple data parallelism. Pure data parallelism splits our activations across any number of TPUs, so long as the number of TPUs is smaller than our batch size. The forward pass involves no communication, but at the end of every step, each TPU performs an AllReduce on its gradients in order to synchronize them before updating the parameters.
Figure: a diagram of pure data parallelism (forward pass). Our activations (left) are fully sharded along the batch dimension and our weights are fully replicated, so each TPU has an identical copy of the weights. This means the total memory of our weights is increased by a factor of $X$ (the number of devices), but no communication is required on the forward pass.
Here's the full algorithm for the forward and backwards pass. We abuse notation to write dL/dOut as dOut, purely for compactness.
Pure Data Parallelism Algorithm:
Forward pass: need to compute $\text{Loss}[B_X]$
1. $\text{Tmp}[B_X, F] = \text{In}[B_X, D] \cdot_D W_{\text{in}}[D, F]$
2. $\text{Out}[B_X, D] = \text{Tmp}[B_X, F] \cdot_F W_{\text{out}}[F, D]$
3. $\text{Loss}[B_X] = \ldots$
Backward pass: need to compute $dW_{\text{out}}[F, D], dW_{\text{in}}[D, F]$
1. $d\text{Out}[B_X, D] = \ldots$
2. $dW_{\text{out}}[F, D]\{U_X\} = \text{Tmp}[B_X, F] \cdot_B d\text{Out}[B_X, D]$
3. $dW_{\text{out}}[F, D] = \text{AllReduce}(dW_{\text{out}}[F, D]\{U_X\})$ (not on critical path, can be done async)
4. $d\text{Tmp}[B_X, F] = d\text{Out}[B_X, D] \cdot_D W_{\text{out}}[F, D]$
5. $dW_{\text{in}}[D, F]\{U_X\} = \text{In}[B_X, D] \cdot_B d\text{Tmp}[B_X, F]$
6. $dW_{\text{in}}[D, F] = \text{AllReduce}(dW_{\text{in}}[D, F]\{U_X\})$ (not on critical path, can be done async)
7. $d\text{In}[B_X, D] = d\text{Tmp}[B_X, F] \cdot_F W_{\text{in}}[D, F]$ (needed for previous layers)
We ignore the details of the loss function and abbreviate $\text{Tmp} = \text{In} \cdot W_{\text{in}}$. Note that, although our final loss is the average AllReduce($\text{Loss}[B_X]$), we only need to compute the AllReduce in the backward pass, when averaging weight gradients.
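A plain-Python sketch of why the per-device partial gradients in backward step 2 can simply be summed: contracting over the batch dimension splits into a sum over batch shards. The tiny matrices are made up, and the explicit sum stands in for the AllReduce.

```python
def matmul(a, b):
    """Plain-Python [n, k] @ [k, m]."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def dW_out_local(tmp_shard, dout_shard):
    # dW_out[F, D]{U_X} = Tmp[B_X, F] ._B dOut[B_X, D]: contract this device's batch rows.
    return matmul(transpose(tmp_shard), dout_shard)


# Hypothetical tiny example: B = 4 tokens, F = 3, D = 2, split over X = 2 devices.
tmp = [[1, 2, 0], [0, 1, 3], [2, 2, 1], [1, 0, 1]]
dout = [[1, -1], [2, 0], [0, 1], [3, 2]]
X = 2
rows = len(tmp) // X
partials = [dW_out_local(tmp[i * rows:(i + 1) * rows], dout[i * rows:(i + 1) * rows])
            for i in range(X)]
all_reduced = partials[0]
for p in partials[1:]:
    all_reduced = add(all_reduced, p)   # what a (sum) AllReduce produces on every device
full_batch = dW_out_local(tmp, dout)
print(all_reduced == full_batch)  # True
```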
Note that the forward pass has no communication - it's all in the backward pass! The backward pass also has the great property that the AllReduces aren't in the "critical path", meaning that each AllReduce can be performed whenever it's convenient and doesn't block you from performing subsequent operations. The overall communication cost can still bottleneck us if it exceeds our total compute cost, but it is much more forgiving from an implementation standpoint. We'll see that model/tensor parallelism doesn't have this property.
Why do this? Pure data parallelism reduces activation memory pressure by splitting our activations over the batch dimension, allowing us to almost arbitrarily increase batch size as long as we have more chips to split the batch dimension over. Especially during training when our activations often dominate our memory usage, this is very helpful.
Why not do this? Pure data parallelism does nothing to reduce memory pressure from model parameters or optimizer states, which means pure data parallelism is rarely useful for interesting models at scale, where our parameters + optimizer state don't fit on a single TPU. To give a sense of scale, if we train with parameters in bf16 and optimizer state in fp32 with Adam${ }^2$, the largest model we can fit has (TPU memory in bytes) / 10 parameters, so e.g. on a TPU v5p chip with 96 GB of HBM and pure data parallelism this is about 9B parameters.
Takeaway: the largest model we can train with Adam and pure data parallelism has num_params $=$ HBM per device/10. For TPU v5p this is roughly 9B parameters. ${ }^3$
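A sketch of the 10-bytes-per-parameter estimate, assuming it breaks down as 2 bytes of bf16 weights plus two fp32 Adam moments (4 + 4 bytes). This breakdown is our reading; it ignores activations, gradients, and any fp32 master copy of the weights, which would lower the limit further.

```python
def max_params_pure_dp(hbm_bytes: float, bytes_per_param: int = 2 + 4 + 4) -> float:
    """Largest parameter count whose weights and Adam state fit in one chip's HBM.

    Default bytes_per_param assumes bf16 weights (2) plus two fp32 Adam moments (4 + 4).
    """
    return hbm_bytes / bytes_per_param


print(max_params_pure_dp(96e9) / 1e9)   # 9.6 (billion parameters for 96 GB of HBM)
```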
To make this useful for real models during training, we'll need to at least partly shard the model parameters or optimizer.
When do we become bottlenecked by communication? As we can see above, we have two AllReduces per layer, each of size $2DF$ bytes (for bf16 weights). When does data parallelism make us communication-bound?
As in the table above, let $C=$ per-chip FLOPs, $W_{\text {ici }}=$ bidirectional network bandwidth, and $X=$ number of shards across which the batch is partitioned ${ }^4$. Let's calculate the time required to perform the relevant matmuls, $T_{\text {math }}$, and the required communication time $T_{\text {comms }}$. Since this parallelism scheme requires no communication in the forward pass, we only need to calculate these quantities for the backwards pass.
Communication time: From a previous section we know that the time required to perform an AllReduce in a 1D mesh depends only on the total bytes of the array being AllReduced and the ICI bandwidth $W_{\text{ici}}$; specifically, the AllReduce time is $2 \cdot \text{total bytes} / W_{\text{ici}}$. Since we need to AllReduce for both $W_{\text{in}}$ and $W_{\text{out}}$, we have 2 AllReduces per layer. Each AllReduce is for a weight matrix, i.e. an array of $DF$ parameters, or $2DF$ bytes. Putting this all together, the total time for the AllReduces in a single layer is
$$
T_{\mathrm{comms}}=\frac{2 \cdot 2 \cdot 2 \cdot D \cdot F}{W_{\mathrm{ici}}}
$$
Matmul time: Each layer comprises two matmuls in the forward pass, or four matmuls in the backwards pass, each of which requires $2(B / X) D F$ FLOPs. Thus, for a single layer in the backward pass, we have
$$
T_{\text {math }}=\frac{2 \cdot 2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C}
$$
Since we overlap, the total time per layer is the max of these two quantities:
$$
\begin{aligned}
& T \approx \max \left(\frac{8 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{8 \cdot D \cdot F}{W_{\mathrm{ici}}}\right) \\
& T \approx 8 \cdot D \cdot F \cdot \max \left(\frac{B}{X \cdot C}, \frac{1}{W_{\mathrm{ici}}}\right)
\end{aligned}
$$
We become compute-bound when $T_{\text {math }} / T_{\text {comms }}>1$, or when
$$
\frac{B}{X}>\frac{C}{W_{\mathrm{ici}}}
$$
The upshot is that, to remain compute-bound with data parallelism, we need the per-device batch size $B/X$ to exceed the ICI operational intensity, $C/W_{\text{ici}}$. This is ultimately a consequence of the fact that the computation time scales with the per-device batch size, while the communication time is independent of this quantity (since we are transferring model weights). Note the resemblance of the $B/X > C/W_{\text{ici}}$ condition to the single-device compute-bound rule $B > 240$; in that case as well, the rule came from the fact that computation time scaled with batch size while data-transfer size was (in the $B \ll F, D$ regime) independent of batch size.
Let's put in some real numbers to get a sense of scale. For TPU v5p, $C = 4.6\text{e}14$ and $W = 2 \cdot 9\text{e}10$ for 1D data parallelism over ICI, so our batch size per chip must be at least about 2,550 tokens to avoid being communication-bound. Since we can do data parallelism over multiple axes, if we dedicate all three axes of a TPU v5p pod to pure data parallelism, we triple our bandwidth $W_{\text{ici}}$ and can scale down to only about 850 tokens per TPU, or 7.6M tokens per batch per pod (of 8960 chips)! This tells us that it's fairly hard to become bottlenecked by pure data parallelism!
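A sketch of the per-layer timing model above, plugged in with the TPU v5p numbers from this paragraph. Like the derivation, it assumes communication overlaps perfectly with compute; the function name is our own.

```python
def dp_layer_time(B: float, X: int, D: int, F: int, C: float, W_ici: float):
    """Per-layer backward time for pure data parallelism, assuming comms fully overlap compute.

    Returns (time, t_math, t_comms) in seconds.
    """
    t_math = 8 * B * D * F / (X * C)   # four matmuls of 2 * (B / X) * D * F FLOPs each
    t_comms = 8 * D * F / W_ici        # two AllReduces of 2DF bytes, each 2 * bytes / W_ici
    return max(t_math, t_comms), t_math, t_comms


C = 4.6e14            # TPU v5p FLOPs/s per chip (bf16)
W_1d = 2 * 9e10       # bidirectional ICI bandwidth for 1D data parallelism
print(C / W_1d)                  # critical per-chip batch, about 2556 tokens
print(C / (3 * W_1d))            # using all three axes, about 852 tokens
print(C / (3 * W_1d) * 8960)     # about 7.6M tokens per batch for a full pod
```

With $B/X$ just above $C/W_{\text{ici}}$ the math term dominates; just below it, the comms term does.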
Note on context parallelism: throughout this section, we use $B$ to refer to the total batch size in tokens. Clearly, however, our batch is made up of $B/T$ sequences of $T$ tokens each, so how can we do this? As far as the MLP is concerned, tokens are tokens! It doesn't matter whether they belong to the same sequence or to two different sequences. So we are more or less free to do data parallelism over both the batch and sequence dimensions: we call this context parallelism or sequence parallelism, but you can think of it as simply being another kind of data parallelism. Attention is trickier than the MLP since it does some computation across the sequence dimension, but this can be handled by gathering KVs or Qs during attention and carefully overlapping FLOPs and comms (typically using something called "ring attention").
Throughout this section, we will just ignore our sequence dimension entirely and assume some amount of batch or sequence parallelism.
Fully-Sharded Data Parallelism (FSDP)
Syntax: $\operatorname{In}\left[B_X, D\right] \cdot{ }_D W_{\text {in }}\left[D_X, F\right] \cdot{ }_F W_{\text {out }}\left[F, D_X\right] \rightarrow \operatorname{Out}\left[B_X, D\right]$
Fully-sharded data parallelism (often called FSDP or ZeRO-sharding [1]) splits the model's optimizer states and weights across the data-parallel shards and efficiently gathers and scatters them as needed. Compared to pure data parallelism, FSDP drastically reduces per-device memory usage and saves on backward pass FLOPs, with very minimal overhead.
Figure: FSDP shards the contracting dimension of $W_{\text{in}}$ and the output dimension of $W_{\text{out}}$ along the data dimension. This reduces memory but (from Section 3) requires us to gather the full weights before we perform each matmul. Note that the activations (left) are not sharded along the contracting dimension, which is what forces us to gather. Our weight optimizer state is likewise sharded along the contracting dimension.
You'll remember (from Section 3) that an AllReduce can be decomposed into an AllGather and a ReduceScatter. This means that, instead of doing the full gradient AllReduce for standard data parallelism, we can shard the weights and optimizer states across chips, AllGather the weights at each layer during the forward pass, and ReduceScatter the gradients during the backward pass, at no extra cost.
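The decomposition can be checked with a toy simulation in which each "device" is just a Python list. This is a sketch, assuming the vector length divides evenly by the number of devices; the collective helpers are hypothetical stand-ins for real collectives, written out by hand.

```python
def reduce_scatter(per_device):
    """Device i ends up with the sum of chunk i of every device's vector."""
    n = len(per_device)
    chunk = len(per_device[0]) // n
    return [
        [sum(vec[j] for vec in per_device) for j in range(i * chunk, (i + 1) * chunk)]
        for i in range(n)
    ]

def all_gather(shards):
    """Every device ends up with the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def all_reduce(per_device):
    """Reference AllReduce: every device gets the elementwise sum."""
    total = [sum(vals) for vals in zip(*per_device)]
    return [list(total) for _ in per_device]

grads = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]   # 2 devices, 4 params
assert all_gather(reduce_scatter(grads)) == all_reduce(grads)
print(reduce_scatter(grads))   # [[11.0, 22.0], [33.0, 44.0]]
```

FSDP exploits the intermediate state: after the ReduceScatter, each device already holds exactly the gradient shard it needs to update its own slice of the weights, so the gradient AllGather is replaced by a weight AllGather at the point of use.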
Here's the full algorithm for FSDP.
Fully-Sharded Data Parallelism (FSDP):
Forward pass: need to compute $\text{Loss}[B_X]$
1. $W_\text{in}[D, F] = \text{AllGather}(W_\text{in}[D_X, F])$ (not on critical path, can do it during previous layer)
2. $\text{Tmp}[B_X, F] = \text{In}[B_X, D] *_D W_\text{in}[D, F]$ (can throw away $W_\text{in}[D, F]$ now)
3. $W_\text{out}[F, D] = \text{AllGather}(W_\text{out}[F, D_X])$ (not on critical path, can do it during previous layer)
4. $\text{Out}[B_X, D] = \text{Tmp}[B_X, F] *_F W_\text{out}[F, D]$
5. $\text{Loss}[B_X] = \ldots$
Backward pass: need to compute $dW_\text{out}[F, D_X], dW_\text{in}[D_X, F]$
1. $d\text{Out}[B_X, D] = \ldots$
2. $dW_\text{out}[F, D]\{U_X\} = \text{Tmp}[B_X, F] *_B d\text{Out}[B_X, D]$
3. $dW_\text{out}[F, D_X] = \text{ReduceScatter}(dW_\text{out}[F, D]\{U_X\})$ (not on critical path, can be done async)
4. $W_\text{out}[F, D] = \text{AllGather}(W_\text{out}[F, D_X])$ (can be done ahead of time)
5. $d\text{Tmp}[B_X, F] = d\text{Out}[B_X, D] *_D W_\text{out}[F, D]$ (can throw away $W_\text{out}[F, D]$ here)
6. $dW_\text{in}[D, F]\{U_X\} = d\text{Tmp}[B_X, F] *_B \text{In}[B_X, D]$
7. $dW_\text{in}[D_X, F] = \text{ReduceScatter}(dW_\text{in}[D, F]\{U_X\})$ (not on critical path, can be done async)
8. $W_\text{in}[D, F] = \text{AllGather}(W_\text{in}[D_X, F])$ (can be done ahead of time)
9. $d\text{In}[B_X, D] = d\text{Tmp}[B_X, F] *_F W_\text{in}[D, F]$ (needed for previous layers) (can throw away $W_\text{in}[D, F]$ here)
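The algorithm above can be sanity-checked in a single process. This is a minimal sketch, assuming $X = 2$ simulated devices, pure-Python lists instead of real arrays, a purely linear MLP (matching the algorithm, which has no nonlinearity), and a toy loss $\tfrac{1}{2}\lVert \text{Out} \rVert^2$ so that $d\text{Out} = \text{Out}$. It covers forward steps 1 to 4 and backward steps 2 and 3; $dW_\text{in}$ follows the same pattern. All helper names are hypothetical.

```python
import random

def matmul(a, b):
    """Plain-Python product of a [n, k] and b [k, m]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def max_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))

random.seed(1)
X, B, D, F = 2, 4, 4, 6    # devices, global batch (tokens), model dim, hidden dim
rand = lambda r, c: [[random.uniform(-1, 1) for _ in range(c)] for _ in range(r)]
W_in, W_out, In = rand(D, F), rand(F, D), rand(B, D)

# What each simulated device stores: 1/X of every weight and of the batch.
d, b = D // X, B // X
W_in_sh = [W_in[i * d:(i + 1) * d] for i in range(X)]                    # W_in[D_X, F]
W_out_sh = [[row[i * d:(i + 1) * d] for row in W_out] for i in range(X)]  # W_out[F, D_X]
In_sh = [In[i * b:(i + 1) * b] for i in range(X)]                        # In[B_X, D]

# Forward steps 1 and 3: AllGather the weights so each device holds full copies.
W_in_full = [row for shard in W_in_sh for row in shard]
W_out_full = [sum((shard[r] for shard in W_out_sh), []) for r in range(F)]
# Forward steps 2 and 4: purely local matmuls on each device's batch shard.
Tmp = [matmul(x, W_in_full) for x in In_sh]      # Tmp[B_X, F]
Out = [matmul(t, W_out_full) for t in Tmp]        # Out[B_X, D]

# Toy loss 0.5 * ||Out||^2, so dOut = Out (an assumption for this sketch).
dOut = Out
# Backward step 2: each device forms an unreduced dW_out[F, D]{U_X} from its shard.
partial = [matmul(transpose(t), g) for t, g in zip(Tmp, dOut)]
# Backward step 3: ReduceScatter, i.e. sum over devices, device i keeps its D-slice.
summed = [[sum(p[r][c] for p in partial) for c in range(D)] for r in range(F)]
dW_out_sh = [[row[i * d:(i + 1) * d] for row in summed] for i in range(X)]  # dW_out[F, D_X]

# Reference: the same layer computed without any sharding.
ref_tmp = matmul(In, W_in)
ref_out = matmul(ref_tmp, W_out)
ref_dW_out = matmul(transpose(ref_tmp), ref_out)

out_err = max_diff([row for o in Out for row in o], ref_out)
grad_err = max_diff([sum((s[r] for s in dW_out_sh), []) for r in range(F)], ref_dW_out)
print(f"Out error: {out_err:.1e}, dW_out error: {grad_err:.1e}")
```

Each device only ever keeps its $1/X$ slice of the weights and gradients between steps; the full matrices exist transiently, which is where the memory saving comes from.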
This is also called "ZeRO sharding", from "Zero Redundancy Optimizer", since we don't perform any unnecessary compute or store any unnecessary state. ZeRO-1, ZeRO-2, and ZeRO-3 refer to sharding the optimizer states, gradients, and weights in this way, respectively. Since all have the same communication cost$^5$, we can basically always do ZeRO-3 sharding, which shards the parameters, gradients, and optimizer states across a set of devices.
Why would we do this? Standard data parallelism involves a lot of duplicated work. Each TPU AllReduces the full gradient, then updates the full optimizer state (identical work on all TPUs), then updates the parameters (again, fully duplicated). For ZeRO sharding (sharding the gradients/optimizer state), instead of an AllReduce, you can ReduceScatter the gradients, update only your shard of the optimizer state, update a shard of the parameters, then AllGather the parameters as needed for your forward pass.
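Here is a toy illustration of that optimizer-step reshuffle, assuming 2 simulated devices and using SGD with momentum as a stand-in for a real optimizer such as Adam (the hyperparameters and helper names are hypothetical). It checks that updating only a local shard of the state and parameters, followed by an AllGather, reproduces the fully replicated update.

```python
LR, BETA = 0.1, 0.9    # hypothetical hyperparameters

def momentum_step(params, momentum, grad):
    """SGD with momentum on plain lists; returns (new_params, new_momentum)."""
    new_m = [BETA * m + g for m, g in zip(momentum, grad)]
    new_p = [p - LR * m for p, m in zip(params, new_m)]
    return new_p, new_m

N_DEV = 2
params = [1.0, -2.0, 0.5, 3.0]
momentum = [0.0] * len(params)
local_grads = [[0.2, 0.4, -0.6, 0.8], [0.0, -0.4, 0.2, 0.4]]   # one gradient per device

# Standard data parallelism: AllReduce, then every device runs the identical full update.
total = [sum(g) for g in zip(*local_grads)]
rep_params, rep_momentum = momentum_step(params, momentum, total)

# ZeRO: ReduceScatter the gradient, update only the local shard of the optimizer
# state and parameters, then AllGather the parameters for the next forward pass.
chunk = len(params) // N_DEV
param_shards, momentum_shards = [], []
for i in range(N_DEV):
    lo, hi = i * chunk, (i + 1) * chunk
    grad_shard = [sum(g[j] for g in local_grads) for j in range(lo, hi)]   # ReduceScatter
    p, m = momentum_step(params[lo:hi], momentum[lo:hi], grad_shard)
    param_shards.append(p)
    momentum_shards.append(m)      # each device keeps only its slice of the state
zero_params = [x for shard in param_shards for x in shard]              # AllGather

print(zero_params == rep_params)  # True
```

The per-device optimizer work and optimizer memory both shrink by a factor of `N_DEV`, while the final parameters match the replicated version.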
When do we become bottlenecked by communication? Our relative FLOPs and comms costs are exactly the same as pure data parallelism, since each AllReduce in the backward pass has become an AllGather + ReduceScatter. Recall that an AllReduce is implemented as an AllGather and a ReduceScatter, each with half the cost. Here we model the forward pass since it has the same FLOPs-to-comms ratio as the backward pass:
$$
\begin{aligned}
T_\text{math} & = \frac{2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\
T_\text{comm} & = \frac{2 \cdot 2 \cdot D \cdot F}{W_\text{ici}} \\
T & \approx \max\left(\frac{4 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{4 \cdot D \cdot F}{W_\text{ici}}\right) \\
T & \approx 4 \cdot D \cdot F \cdot \max\left(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}\right)
\end{aligned}
$$
Therefore, as with pure data parallelism, we are compute-bound when $B / X > C / W_\text{ici}$, i.e. when the per-device batch size $B / X$ exceeds the "ICI operational intensity" $C / W_\text{ici}$ ($4.59\text{e}14 / 1.8\text{e}11 = 2550$ for v5p). This is great for us, because it means that if our per-device batch size is big enough to be compute-bound for pure data parallelism, we can simply upgrade to FSDP without worrying about leaving the compute-bound regime, saving ourselves a massive amount of parameter and optimizer state memory! Though we did have to add communication to the forward pass, this cost is immaterial since it overlaps with forward-pass FLOPs.
Takeaway: both FSDP and pure data parallelism become bandwidth-bound on TPUv5p when the batch size per device is less than $2550 / n_\text{axes}$.
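The roofline comparison can be written as a small function. This is a sketch that assumes the TPUv5p constants used in this section and bf16 weights; `fsdp_layer_times` and the layer sizes below are illustrative, and $D$ and $F$ cancel out of the comparison anyway.

```python
def fsdp_layer_times(B, D, F, X, C=4.59e14, W_ici=1.8e11, n_axes=1):
    """Forward-pass time estimates (seconds) for one FSDP MLP layer.

    B: global batch in tokens, X: number of chips. Both bf16 weight matrices
    (2 bytes per element) are AllGathered; n_axes scales the ICI bandwidth.
    """
    t_math = 2 * 2 * B * D * F / (X * C)
    t_comm = 2 * 2 * D * F / (W_ici * n_axes)
    return t_math, t_comm

D, F, X = 8192, 30_000, 256     # illustrative sizes; D and F cancel in the ratio
for n_axes in (1, 3):
    for per_chip in (1_000, 5_000):
        t_math, t_comm = fsdp_layer_times(per_chip * X, D, F, X, n_axes=n_axes)
        regime = "compute-bound" if t_math > t_comm else "comms-bound"
        print(f"n_axes={n_axes}, B/X={per_chip}: {regime}")
```

With one axis, 1,000 tokens per chip falls below the 2,550 threshold and is comms-bound; with three axes the threshold drops to 850, so the same per-chip batch becomes compute-bound.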
For example, DeepSeek-V2 (one of the few recent strong models to release information about its training batch size) used a batch size of $\sim 40\text{M}$ tokens. Using all three ICI axes (850 tokens per chip), this would allow us to scale to roughly 47,000 chips, or around 5 TPUv5p pods, before we hit a bandwidth limit.
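A quick check of the DeepSeek-V2 numbers, assuming all three ICI axes are used so that the per-chip floor is $2550 / 3 = 850$ tokens:

```python
BATCH_TOKENS = 40e6        # DeepSeek-V2's ~40M-token batch, from the text
MIN_PER_CHIP = 2550 / 3    # v5p threshold with all three ICI axes (850 tokens)
CHIPS_PER_POD = 8960

max_chips = BATCH_TOKENS / MIN_PER_CHIP
print(f"{max_chips:,.0f} chips, about {max_chips / CHIPS_PER_POD:.1f} pods")
```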
For LLaMA-3 70B, which was trained for approximately $6.3\text{e}24$ ($15\text{e}12 \times 70\text{e}9 \times 6$) FLOPs, we could split a batch of 16M tokens over roughly $16\text{e}6 / (2550 / 3) = 18{,}823$ chips (roughly 2 pods of 8960 chips), each with $4.59\text{e}14$ FLOPs/s running at 50% of peak FLOPs utilization (often called MFU), and train it in approximately 17 days. Not bad! But let's explore how we can do better.
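The training-time estimate follows from dividing total FLOPs by the cluster's effective throughput. This sketch uses only the numbers quoted in the paragraph (the usual $6 \cdot \text{tokens} \cdot \text{params}$ FLOPs estimate, 850 tokens per chip, 50% MFU):

```python
TOKENS, PARAMS = 15e12, 70e9
train_flops = 6 * TOKENS * PARAMS          # ~6 FLOPs per parameter per token
chips = 16e6 / (2550 / 3)                  # 16M-token batch at 850 tokens/chip
flops_per_chip = 4.59e14                   # TPUv5p peak bf16 FLOPs/s
mfu = 0.5                                  # assumed model FLOPs utilization

seconds = train_flops / (chips * flops_per_chip * mfu)
print(f"{train_flops:.2e} FLOPs on {chips:,.0f} chips: {seconds / 86400:.1f} days")
```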
Note on critical batch size: somewhat unintuitively, we become more communication bottlenecked as our total batch size decreases (with fixed chip number). Data parallelism and FSDP let us scale to arbitrarily many chips so long as we can keep increasing our batch size! However, in practice, as our batch size increases, we tend to see diminishing returns in training since our gradients become almost noise-free. We also sometimes see training instability. Thus, the game of finding an optimal sharding scheme in the "unlimited compute regime" often starts from a fixed batch size, determined by scaling laws, and a known (large) number of chips, and then aims to find a partitioning that allows us to fit that small batch size on so many chips.
Tensor Parallelism
Syntax: $\operatorname{In}\left[B, D_Y\right] \cdot{ }_D W_{\text {in }}\left[D, F_Y\right] \cdot{ }_F W_{\text {out }}\left[F_Y, D\right] \rightarrow \operatorname{Out}\left[B, D_Y\right]$ (we use $Y$ to eventually combine with FSDP)
In FSDP, we move the weights across chips. We can also shard the feedforward dimension of the model and move the activations during the layer: this is called "1D model parallelism" or Megatron sharding [2]. This can let us stay compute-bound with a smaller batch size per pod. The figure below shows an example of a single matrix sharded in this way:
Figure: an example of basic tensor parallelism. Since we're only sharding our activations over $Y$ (unlike in FSDP, where we shard over $X$), we replicate our activations over $X$. Using our standard syntax, this is $A[B, D_Y] *_D B[D, F_Y] \rightarrow C[B, F_Y]$. Because we're only sharding over one of the contracting dimensions, we typically AllGather the activations $A$ before the matmul.
As noted, $\text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]$ means we have to gather our activations before the first matmul. This is cheaper than ZeRO sharding when the activations are smaller than the weights. This is typically true only with some amount of ZeRO sharding added (which reduces the size of the gather). This is one of the reasons we tend to mix ZeRO sharding and model parallelism.
Here's the algorithm for tensor parallelism!
Tensor Parallelism:
Forward pass: need to compute $\text{Loss}[B]$
1. $\text{In}[B, D] = \text{AllGather}(\text{In}[B, D_Y])$ (on critical path)
2. $\text{Tmp}[B, F_Y] = \text{In}[B, D] *_D W_\text{in}[D, F_Y]$ (not sharded along the contracting dimension, so no comms)
3. $\text{Out}[B, D]\{U_Y\} = \text{Tmp}[B, F_Y] *_F W_\text{out}[F_Y, D]$
4. $\text{Out}[B, D_Y] = \text{ReduceScatter}(\text{Out}[B, D]\{U_Y\})$ (on critical path)
5. $\text{Loss}[B] = \ldots$
Backward pass: need to compute $dW_\text{out}[F_Y, D], dW_\text{in}[D, F_Y]$
1. $d\text{Out}[B, D_Y] = \ldots$
2. $d\text{Out}[B, D] = \text{AllGather}(d\text{Out}[B, D_Y])$ (on critical path)
3. $dW_\text{out}[F_Y, D] = \text{Tmp}[B, F_Y] *_B d\text{Out}[B, D]$
4. $d\text{Tmp}[B, F_Y] = d\text{Out}[B, D] *_D W_\text{out}[F_Y, D]$ (can throw away $d\text{Out}[B, D]$ here)
5. $\text{In}[B, D] = \text{AllGather}(\text{In}[B, D_Y])$ (this can be skipped by sharing with (1) from the forward pass)
6. $dW_\text{in}[D, F_Y] = d\text{Tmp}[B, F_Y] *_B \text{In}[B, D]$
7. $d\text{In}[B, D]\{U_Y\} = d\text{Tmp}[B, F_Y] *_F W_\text{in}[D, F_Y]$ (needed for previous layers)
8. $d\text{In}[B, D_Y] = \text{ReduceScatter}(d\text{In}[B, D]\{U_Y\})$ (on critical path)
One nice thing about tensor parallelism is that it interacts nicely with the two matrices in our Transformer forward pass. Naively, we would do an AllReduce after each of the two matrices. But here we first do $\text{In}[B, D_Y] * W_\text{in}[D, F_Y] \rightarrow \text{Tmp}[B, F_Y]$ and then $\text{Tmp}[B, F_Y] * W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]$. This means we AllGather In at the beginning and ReduceScatter Out at the end, rather than doing an AllReduce.
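The forward pass of the tensor-parallel algorithm can be simulated the same way. This is a minimal sketch, assuming $Y = 2$ simulated devices, pure-Python lists, and a linear MLP; the AllGather and ReduceScatter are written out by hand and the helper names are hypothetical. Note that the only communication is on the activations, at the start and end of the layer.

```python
import random

def matmul(a, b):
    """Plain-Python product of a [n, k] and b [k, m]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

random.seed(2)
Y, B, D, F = 2, 3, 4, 6    # devices, batch, model dim, hidden dim
rand = lambda r, c: [[random.uniform(-1, 1) for _ in range(c)] for _ in range(r)]
In, W_in, W_out = rand(B, D), rand(D, F), rand(F, D)
d, f = D // Y, F // Y

In_sh = [[row[i * d:(i + 1) * d] for row in In] for i in range(Y)]        # In[B, D_Y]
W_in_sh = [[row[i * f:(i + 1) * f] for row in W_in] for i in range(Y)]    # W_in[D, F_Y]
W_out_sh = [W_out[i * f:(i + 1) * f] for i in range(Y)]                   # W_out[F_Y, D]

# Step 1: AllGather the activations along D.
In_full = [sum((shard[r] for shard in In_sh), []) for r in range(B)]
# Step 2: local matmul; the sharded F is not contracted here, so no comms.
Tmp = [matmul(In_full, w) for w in W_in_sh]                               # Tmp[B, F_Y]
# Step 3: contracting over the sharded F leaves an unreduced Out[B, D]{U_Y}.
partial = [matmul(t, w) for t, w in zip(Tmp, W_out_sh)]
# Step 4: ReduceScatter, i.e. sum the partials, device i keeps columns i*d:(i+1)*d.
Out_sh = [
    [[sum(p[r][c] for p in partial) for c in range(i * d, (i + 1) * d)] for r in range(B)]
    for i in range(Y)
]

reference = matmul(matmul(In, W_in), W_out)
max_err = max(
    abs(Out_sh[i][r][c - i * d] - reference[r][c])
    for i in range(Y) for r in range(B) for c in range(i * d, (i + 1) * d)
)
print(f"max abs difference vs. unsharded: {max_err:.2e}")
```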
How costly is this? Let's only model the forward pass, since the backward pass is just the transpose of each operation here. In 1D model parallelism we AllGather the activations before the first matmul and ReduceScatter them after the second, sending two bytes per element (bf16). Let's figure out when we're bottlenecked by communication.
$$
\begin{aligned}
T_\text{math} & = \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} \\
T_\text{comms} & = \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}} \\
T & \approx \max\left(\frac{4 \cdot B \cdot D \cdot F}{Y \cdot C}, \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\right)
\end{aligned}
$$
Noting that we want compute cost to be greater than comms cost, we get:
$$
\begin{gathered}
\frac{4 \cdot B \cdot D \cdot F}{Y \cdot C}>\frac{2 \cdot 2 \cdot(B \cdot D)}{W_{\mathrm{ici}}} \\
\frac{F}{Y \cdot C}>\frac{1}{W_{\mathrm{ici}}} \\
F>Y \cdot \frac{C}{W_{\mathrm{ici}}}
\end{gathered}
$$
Thus, for instance, for TPUv5p, $C / W_\text{ici} = 2550$ in bf16, so we can only do tensor parallelism up to $Y < F / 2550$. When we have multiple ICI axes, our $T_\text{comms}$ is reduced by a factor of $n_\text{axes}$, so we get $Y < n_\text{axes} \cdot F / 2550$.
Takeaway: model parallelism becomes communication-bound when $Y > n_\text{axes} \cdot F / 2550$. For most models this is between 8- and 16-way model parallelism.
Note that this doesn't depend on the precision of the computation, since e.g. for int8 on TPUv5p, $C_\text{int8} / W_\text{ici}$ is 5100 instead of 2550, but the comms volume is also halved, so the two factors of two cancel.
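Both claims (that the batch size drops out and that int8 changes nothing) can be checked numerically. This sketch assumes the TPUv5p constants above and, as the 5100-vs-2550 remark implies, an int8 peak of exactly twice the bf16 peak; `tp_layer_times` and the layer sizes are illustrative.

```python
def tp_layer_times(B, D, F, Y, C, W_ici=1.8e11, bytes_per_elem=2, n_axes=1):
    """Forward-pass time estimates (seconds) for one tensor-parallel MLP layer.

    Comms = AllGather of In[B, D] plus ReduceScatter of Out[B, D], each moving
    bytes_per_elem * B * D bytes; n_axes scales the ICI bandwidth.
    """
    t_math = 2 * 2 * B * D * F / (Y * C)
    t_comms = 2 * bytes_per_elem * B * D / (W_ici * n_axes)
    return t_math, t_comms

C_BF16 = 4.59e14
C_INT8 = 2 * C_BF16    # assumption: int8 peak is exactly 2x bf16 (5100 vs. 2550)

D, F, Y = 8192, 30_000, 8    # illustrative sizes
for B in (1_000, 1_000_000):
    for label, C, nbytes in (("bf16", C_BF16, 2), ("int8", C_INT8, 1)):
        t_math, t_comms = tp_layer_times(B, D, F, Y, C=C, bytes_per_elem=nbytes)
        print(f"B={B:>9,} {label}: T_math / T_comms = {t_math / t_comms:.3f}")
```

Every line prints the same ratio, $F / (Y \cdot 2550) \approx 1.471$: the batch size cancels, and int8's doubled FLOPs are offset by its halved bytes.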
Let's think about some examples:
- On TPUv5p with LLaMA 3-70B ($D = 8192$, $F \approx 30{,}000$), we can comfortably do 8-way model parallelism, but will be communication-bound with 16-way model parallelism. The $F$ required for 8-way model sharding is about $8 \times 2550 \approx 20\text{k}$.
- For Gemma 7B, $F \approx 50\text{k}$, so we become communication-bound beyond roughly 19-way model parallelism ($50{,}000 / 2550 \approx 19.6$). That means we could likely do 16-way and still see good performance.
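The two examples reduce to evaluating $n_\text{axes} \cdot F / 2550$. This sketch assumes, as the examples appear to, a single ICI axis; the helper name `max_tp_degree` is hypothetical and the $F$ values are the approximate ones given above.

```python
C_OVER_W = 2550    # TPUv5p bf16 ICI operational intensity, from the text

def max_tp_degree(F: float, n_axes: int = 1) -> float:
    """Largest tensor-parallel degree Y that stays compute-bound: Y < n_axes * F / 2550."""
    return n_axes * F / C_OVER_W

for name, F in (("LLaMA 3-70B", 30_000), ("Gemma 7B", 50_000)):
    print(f"{name}: compute-bound up to Y ~ {max_tp_degree(F):.1f}")
print(f"F needed for 8-way sharding: {8 * C_OVER_W:,}")
```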