Skip to content

Commit 1a90450

Browse files
authored
fix good tokens reference
1 parent 65635eb commit 1a90450

File tree

1 file changed

+2
-3
lines changed
  • examples/alst_ulysses_sequence_parallelism

1 file changed

+2
-3
lines changed

examples/alst_ulysses_sequence_parallelism/sp-alst.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,15 @@ def collate_fn(batch):
127127
batch = move_to_device(batch, model.device)
128128

129129
# The model automatically receives shift_labels via **kwargs and uses it for loss computation.
130-
# Both standard transformers models and Liger-patched models handle this correctly.
130+
# Both standard transformer models and Liger-patched models handle this correctly.
131131
outputs = model(**batch)
132132
loss = outputs.loss
133-
shift_labels = batch["shift_labels"]
134133

135134
if sp_size > 1:
136135
# differentiable weighted per-shard-loss aggregation across ranks
137136
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
138137
# special dealing with SFT that has prompt tokens that aren't used in loss computation
139-
good_tokens = (shift_labels != -100).view(-1).sum()
138+
good_tokens = (batch["shift_labels"] != -100).view(-1).sum()
140139
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
141140
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
142141
total_good_tokens = sum(good_tokens_per_rank)

0 commit comments

Comments (0)