Commit 997eae8

fix loss computation

committed, 1 parent b521400, commit 997eae8

2 files changed: +9 additions, -16 deletions


docs/source/concept_guides/sequence_parallelism.md

Lines changed: 5 additions & 9 deletions

@@ -171,16 +171,12 @@ for iter, batch in enumerate(dl):
     optimizer.zero_grad()

     batch = move_to_device(batch, model.device)
-    outputs = model(**batch)

-    # only if not using liger-kernel
+    # The model automatically receives shift_labels via **kwargs and uses it for loss computation.
+    # Both standard transformers models and Liger-patched models handle this correctly.
+    outputs = model(**batch)
+    loss = outputs.loss
     shift_labels = batch["shift_labels"]
-    loss = unwrapped_model.loss_function(
-        logits=outputs.logits,
-        labels=None,
-        shift_labels=shift_labels,
-        vocab_size=unwrapped_model.config.vocab_size,
-    )

     if sp_size > 1:
         # differentiable weighted per-shard-loss aggregation across ranks
@@ -206,7 +202,7 @@ for iter, batch in enumerate(dl):
     optimizer.step()
 ```

-If you use [Liger Kernel](https://github.com/linkedin/Liger-Kernel) it already knows how to handle `shift_labels` so you don't need to go through manual loss calculation, just calling `model(**batch)` will already get the `loss` calculated and done in a very memory-efficient way. If you didn't know about Liger-Kernel - it's highly recommended to be used especially for long sequence length, since it liberates a lot of working GPU memory that can be used for handling longer sequences. For example, it performs a fused logit-loss computation, never manifesting the full logits tensor in memory.
+Note that models automatically handle `shift_labels` when it's present in the batch. The model's forward pass receives `shift_labels` via `**kwargs` and passes it to the loss function, which correctly computes the loss for sequence parallelism. If you use [Liger Kernel](https://github.com/linkedin/Liger-Kernel), it also handles `shift_labels` seamlessly and computes loss in a very memory-efficient way. Liger is highly recommended for long sequence lengths, as it liberates GPU memory by using fused operations (e.g., fused logit-loss computation that never materializes the full logits tensor in memory).

 If you want to see what HF Accelerate did behind the scenes please read [this full integration tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).

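For readers unfamiliar with `shift_labels`: under the usual next-token-prediction convention in transformers, the label for position t is the token at position t+1, with the final position masked out of the loss. A minimal plain-Python sketch of that convention (the helper name is hypothetical and not part of this commit; the -100 ignore index is the common transformers convention):

```python
IGNORE_INDEX = -100  # conventional "exclude this position from the loss" label


def make_shift_labels(input_ids):
    # Hypothetical helper: position t is trained to predict token t+1, so the
    # labels are the input shifted left by one; the last position has no
    # target and is masked with IGNORE_INDEX.
    return input_ids[1:] + [IGNORE_INDEX]


print(make_shift_labels([5, 8, 13, 21]))  # → [8, 13, 21, -100]
```

Precomputing `shift_labels` in the collator (rather than letting the model shift `labels` internally) is what makes sequence-parallel sharding work: each rank's shard already carries the correct targets for its own positions.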
examples/alst_ulysses_sequence_parallelism/sp-alst.py

Lines changed: 4 additions & 7 deletions

@@ -125,15 +125,12 @@ def collate_fn(batch):
     if rank == 0:
         print(f"batch {iter}: seqlen: {len(batch['input_ids'][0])}")
     batch = move_to_device(batch, model.device)
-    outputs = model(**batch)

+    # The model automatically receives shift_labels via **kwargs and uses it for loss computation.
+    # Both standard transformers models and Liger-patched models handle this correctly.
+    outputs = model(**batch)
+    loss = outputs.loss
     shift_labels = batch["shift_labels"]
-    loss = unwrapped_model.loss_function(
-        logits=outputs.logits,
-        labels=None,
-        shift_labels=shift_labels,
-        vocab_size=unwrapped_model.config.vocab_size,
-    )

     if sp_size > 1:
         # differentiable weighted per-shard-loss aggregation across ranks
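The `sp_size > 1` branch that follows this hunk aggregates each rank's per-shard loss. The core weighting arithmetic can be sketched in plain Python (an illustration of the idea only, with a hypothetical function name; the real code performs this with differentiable collectives so gradients flow back to every rank):

```python
def aggregate_shard_losses(shard_losses, shard_token_counts):
    # Each rank reports the mean loss over its own sequence shard. Shards
    # can contribute different numbers of loss-carrying tokens, so a plain
    # average of the means would be wrong: weight each shard's mean loss by
    # its token count, then divide by the global token count.
    total_tokens = sum(shard_token_counts)
    return sum(l * n for l, n in zip(shard_losses, shard_token_counts)) / total_tokens


# Two ranks: one shard has 3 loss-contributing tokens, the other has 1.
print(aggregate_shard_losses([2.0, 4.0], [3, 1]))  # → 2.5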
