Skip to content

Commit

Permalink
Update on "[Not for land] Added changes for GPT-2 perf"
Browse files Browse the repository at this point in the history
Credit: felipemello1 for most of the work here (especially around chunked cross entropy)

Running on 4xH100s:
Without these changes (`torch.compile`), the max local batch size is 5:
```
[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%
```
Without these changes (no `torch.compile`), local batch size 5:
```
[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%
```

With these changes, we can use local batch size 16:
```
[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%
```

22.7% MFU -> 34.1% MFU


[ghstack-poisoned]
  • Loading branch information
awgu committed Aug 19, 2024
1 parent 6ad9afa commit b4a24d2
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchtitan/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ def forward(self, tokens: torch.Tensor):
h = layer(h, self.freqs_cis)

h = self.norm(h) if self.norm else h
if not self.output:
return h
chunks = h.chunk(16, dim=1) # TODO: 16 is from the default `num_chunks`
return [self.output(chunk) for chunk in chunks]

Expand Down

0 comments on commit b4a24d2

Please sign in to comment.