File tree (Expand file tree / Collapse file tree) — 1 file changed: +2 −3 lines changed
examples/alst_ulysses_sequence_parallelism (Expand file tree / Collapse file tree) — 1 file changed: +2 −3 lines changed
Original file | Line number | Diff line number | Diff line change @@ -127,16 +127,15 @@ def collate_fn(batch):
127127 batch = move_to_device (batch , model .device )
128128
129129 # The model automatically receives shift_labels via **kwargs and uses it for loss computation.
130- # Both standard transformers models and Liger-patched models handle this correctly.
130+ # Both standard transformer models and Liger-patched models handle this correctly.
131131 outputs = model (** batch )
132132 loss = outputs .loss
133- shift_labels = batch ["shift_labels" ]
134133
135134 if sp_size > 1 :
136135 # differentiable weighted per-shard-loss aggregation across ranks
137136 losses_per_rank = torch .distributed .nn .functional .all_gather (loss , group = sp_group )
138137 # special dealing with SFT that has prompt tokens that aren't used in loss computation
139- good_tokens = (shift_labels != - 100 ).view (- 1 ).sum ()
138+ good_tokens = (batch [ " shift_labels" ] != - 100 ).view (- 1 ).sum ()
140139 good_tokens_per_rank = torch .distributed .nn .functional .all_gather (good_tokens , group = sp_group )
141140 total_loss = sum (losses_per_rank [rank ] * good_tokens_per_rank [rank ] for rank in range (sp_world_size ))
142141 total_good_tokens = sum (good_tokens_per_rank )
You can’t perform that action at this time.
0 commit comments