Skip to content

Commit

Permalink
[Not for land] Added changes for GPT-2 perf
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
awgu committed Aug 19, 2024
1 parent 8afa545 commit 6ad9afa
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 6 deletions.
4 changes: 2 additions & 2 deletions torchtitan/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,8 @@ def forward(self, tokens: torch.Tensor):
h = layer(h, self.freqs_cis)

h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
return output
chunks = h.chunk(16, dim=1) # TODO: 16 is from the default `num_chunks`
return [self.output(chunk) for chunk in chunks]

@classmethod
def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
Expand Down
8 changes: 4 additions & 4 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,15 +312,15 @@ def apply_fsdp(
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
# For small models (e.g. GPT-2), parameter memory is low, so there
# is no need to reshard after forward
reshard_after_forward = False
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
fully_shard(model, **fsdp_config)

logger.info("Applied FSDP to the model")

Expand Down
38 changes: 38 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import time
from datetime import timedelta
from typing import List

import torch
from torch.distributed.elastic.multiprocessing.errors import record
Expand Down Expand Up @@ -44,6 +45,36 @@ def context():
return context


class TokenChunkedCrossEntropyLoss(torch.nn.Module):
def __init__(self, num_chunks: int = 16, ignore_index: int = -100):
super(TokenChunkedCrossEntropyLoss, self).__init__()
self.num_chunks = num_chunks
self.ignore_index = ignore_index
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(
reduction="sum", ignore_index=self.ignore_index
)

@torch.compile()
def _compute_cross_entropy(self, logits: torch.Tensor, labels: torch.Tensor):
return self.cross_entropy_loss(logits.float(), labels)

def forward(self, logits: List[torch.Tensor], labels: torch.Tensor):
"""
Args:
logits (List[torch.Tensor]): List of chunked logits of length
``self.num_chunks``, where each chunk has shape
(batch_size, num_tokens / num_chunks, vocab_size).
labels (torch.Tensor): Ground truth labels of shape (batch_size, num_tokens).
"""
total_elements = (labels != self.ignore_index).sum()
labels = [target_chunk.reshape(-1) for target_chunk in labels.chunk(self.num_chunks, dim=1)]
logits = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):
total_loss += self._compute_cross_entropy(logits_chunk, labels_chunk)
return total_loss / total_elements


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
Expand Down Expand Up @@ -132,9 +163,16 @@ def main(job_config: JobConfig):
f"{color.blue}Model {model_name} {job_config.model.flavor} "
f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
)
token_chunked_cross_entropy_loss = TokenChunkedCrossEntropyLoss()

# loss function to be shared by Pipeline Parallel and SPMD training
def loss_fn(pred, labels):
if isinstance(pred, torch.Tensor):
pred_chunks = pred.chunk(token_chunked_cross_entropy_loss.num_chunks, dim=1)
else:
assert isinstance(pred, list)
pred_chunks = pred
return token_chunked_cross_entropy_loss(pred_chunks, labels)
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
)
Expand Down

0 comments on commit 6ad9afa

Please sign in to comment.