Skip to content

Commit 6ad9afa

Browse files
author
Andrew Gu
committed
[Not for land] Added changes for GPT-2 perf
[ghstack-poisoned]
1 parent 8afa545 commit 6ad9afa

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

torchtitan/models/llama/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,8 @@ def forward(self, tokens: torch.Tensor):
439439
h = layer(h, self.freqs_cis)
440440

441441
h = self.norm(h) if self.norm else h
442-
output = self.output(h).float() if self.output else h
443-
return output
442+
chunks = h.chunk(16, dim=1) # TODO: 16 is from the default `num_chunks`
443+
return [self.output(chunk) for chunk in chunks]
444444

445445
@classmethod
446446
def from_model_args(cls, model_args: ModelArgs) -> "Transformer":

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,15 +312,15 @@ def apply_fsdp(
312312
# all-gathers, which can be expensive and non-overlapped
313313
reshard_after_forward = False
314314
else:
315-
# As an optimization, do not reshard after forward for the last
316-
# transformer block since FSDP would prefetch it immediately
317-
reshard_after_forward = int(layer_id) < len(model.layers) - 1
315+
# For small models (e.g. GPT-2), parameter memory is low, so there
316+
# is no need to reshard after forward
317+
reshard_after_forward = False
318318
fully_shard(
319319
transformer_block,
320320
**fsdp_config,
321321
reshard_after_forward=reshard_after_forward,
322322
)
323-
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
323+
fully_shard(model, **fsdp_config)
324324

325325
logger.info("Applied FSDP to the model")
326326

train.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import time
1010
from datetime import timedelta
11+
from typing import List
1112

1213
import torch
1314
from torch.distributed.elastic.multiprocessing.errors import record
@@ -44,6 +45,36 @@ def context():
4445
return context
4546

4647

48+
class TokenChunkedCrossEntropyLoss(torch.nn.Module):
49+
def __init__(self, num_chunks: int = 16, ignore_index: int = -100):
50+
super(TokenChunkedCrossEntropyLoss, self).__init__()
51+
self.num_chunks = num_chunks
52+
self.ignore_index = ignore_index
53+
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(
54+
reduction="sum", ignore_index=self.ignore_index
55+
)
56+
57+
@torch.compile()
58+
def _compute_cross_entropy(self, logits: torch.Tensor, labels: torch.Tensor):
59+
return self.cross_entropy_loss(logits.float(), labels)
60+
61+
def forward(self, logits: List[torch.Tensor], labels: torch.Tensor):
62+
"""
63+
Args:
64+
logits (List[torch.Tensor]): List of chunked logits of length
65+
``self.num_chunks``, where each chunk has shape
66+
(batch_size, num_tokens / num_chunks, vocab_size).
67+
labels (torch.Tensor): Ground truth labels of shape (batch_size, num_tokens).
68+
"""
69+
total_elements = (labels != self.ignore_index).sum()
70+
labels = [target_chunk.reshape(-1) for target_chunk in labels.chunk(self.num_chunks, dim=1)]
71+
logits = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
72+
total_loss = 0.0
73+
for logits_chunk, labels_chunk in zip(logits, labels):
74+
total_loss += self._compute_cross_entropy(logits_chunk, labels_chunk)
75+
return total_loss / total_elements
76+
77+
4778
# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
4879
@record
4980
def main(job_config: JobConfig):
@@ -132,9 +163,16 @@ def main(job_config: JobConfig):
132163
f"{color.blue}Model {model_name} {job_config.model.flavor} "
133164
f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
134165
)
166+
token_chunked_cross_entropy_loss = TokenChunkedCrossEntropyLoss()
135167

136168
# loss function to be shared by Pipeline Parallel and SPMD training
137169
def loss_fn(pred, labels):
170+
if isinstance(pred, torch.Tensor):
171+
pred_chunks = pred.chunk(token_chunked_cross_entropy_loss.num_chunks, dim=1)
172+
else:
173+
assert isinstance(pred, list)
174+
pred_chunks = pred
175+
return token_chunked_cross_entropy_loss(pred_chunks, labels)
138176
return torch.nn.functional.cross_entropy(
139177
pred.flatten(0, 1), labels.flatten(0, 1)
140178
)

0 commit comments

Comments
 (0)