|
8 | 8 | import os
|
9 | 9 | import time
|
10 | 10 | from datetime import timedelta
|
| 11 | +from typing import List |
11 | 12 |
|
12 | 13 | import torch
|
13 | 14 | from torch.distributed.elastic.multiprocessing.errors import record
|
@@ -44,6 +45,36 @@ def context():
|
44 | 45 | return context
|
45 | 46 |
|
46 | 47 |
|
| 48 | +class TokenChunkedCrossEntropyLoss(torch.nn.Module): |
| 49 | + def __init__(self, num_chunks: int = 16, ignore_index: int = -100): |
| 50 | + super(TokenChunkedCrossEntropyLoss, self).__init__() |
| 51 | + self.num_chunks = num_chunks |
| 52 | + self.ignore_index = ignore_index |
| 53 | + self.cross_entropy_loss = torch.nn.CrossEntropyLoss( |
| 54 | + reduction="sum", ignore_index=self.ignore_index |
| 55 | + ) |
| 56 | + |
| 57 | + @torch.compile() |
| 58 | + def _compute_cross_entropy(self, logits: torch.Tensor, labels: torch.Tensor): |
| 59 | + return self.cross_entropy_loss(logits.float(), labels) |
| 60 | + |
| 61 | + def forward(self, logits: List[torch.Tensor], labels: torch.Tensor): |
| 62 | + """ |
| 63 | + Args: |
| 64 | + logits (List[torch.Tensor]): List of chunked logits of length |
| 65 | + ``self.num_chunks``, where each chunk has shape |
| 66 | + (batch_size, num_tokens / num_chunks, vocab_size). |
| 67 | + labels (torch.Tensor): Ground truth labels of shape (batch_size, num_tokens). |
| 68 | + """ |
| 69 | + total_elements = (labels != self.ignore_index).sum() |
| 70 | + labels = [target_chunk.reshape(-1) for target_chunk in labels.chunk(self.num_chunks, dim=1)] |
| 71 | + logits = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] |
| 72 | + total_loss = 0.0 |
| 73 | + for logits_chunk, labels_chunk in zip(logits, labels): |
| 74 | + total_loss += self._compute_cross_entropy(logits_chunk, labels_chunk) |
| 75 | + return total_loss / total_elements |
| 76 | + |
| 77 | + |
47 | 78 | # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
|
48 | 79 | @record
|
49 | 80 | def main(job_config: JobConfig):
|
@@ -132,9 +163,16 @@ def main(job_config: JobConfig):
|
132 | 163 | f"{color.blue}Model {model_name} {job_config.model.flavor} "
|
133 | 164 | f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
|
134 | 165 | )
|
| 166 | + token_chunked_cross_entropy_loss = TokenChunkedCrossEntropyLoss() |
135 | 167 |
|
136 | 168 | # loss function to be shared by Pipeline Parallel and SPMD training
|
137 | 169 | def loss_fn(pred, labels):
|
| 170 | + if isinstance(pred, torch.Tensor): |
| 171 | + pred_chunks = pred.chunk(token_chunked_cross_entropy_loss.num_chunks, dim=1) |
| 172 | + else: |
| 173 | + assert isinstance(pred, list) |
| 174 | + pred_chunks = pred |
| 175 | + return token_chunked_cross_entropy_loss(pred_chunks, labels) |
138 | 176 | return torch.nn.functional.cross_entropy(
|
139 | 177 | pred.flatten(0, 1), labels.flatten(0, 1)
|
140 | 178 | )
|
|
0 commit comments