
Commit facf34d

author --global committed
Move num_items_in_batch and n_items (the same quantity under two names) onto the device of the loss they divide, at the point where the division happens.
1 parent 37dfb23 commit facf34d
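
For context, the change addresses a device mismatch: num_items_in_batch / n_items can arrive as a 0-dim tensor that lives on a different device than the loss (for example on another GPU after a gather across processes), and the bare division then fails. The sketch below only illustrates the guard this commit applies at every division site; scale_loss_by_count is a hypothetical name, not a function in unsloth_zoo.

import torch

def scale_loss_by_count(loss: torch.Tensor, n_items) -> torch.Tensor:
    # n_items may be a plain int or a 0-dim tensor produced elsewhere.
    # In multi-GPU runs that tensor can sit on a different GPU than the
    # loss, where a bare `loss / n_items` raises a device-mismatch
    # RuntimeError, so move the count next to the loss first.
    if torch.is_tensor(n_items):
        n_items = n_items.to(loss.device)
    return loss / n_items

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss = torch.tensor(12.0, device = device)    # summed per-token losses
    count = torch.tensor(4)                       # token count, wherever it was computed
    print(scale_loss_by_count(loss, count))       # 3.0
    print(scale_loss_by_count(loss, 4))           # plain ints still work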

File tree

2 files changed: +24 -11 lines

unsloth_zoo/compiler.py

Lines changed: 9 additions & 3 deletions
@@ -623,7 +623,7 @@ def _compiled_loss_function(
     logit_scale_divide : float = 0,
     logit_softcapping : float = 0,
     vocab_size : int = 0,
-    n_items : int = 0,
+    n_items : int | torch.Tensor = 0,
 ):
     device = output_logits.device
     if logit_scale_multiply != 0:
@@ -658,6 +658,8 @@ def _compiled_loss_function(
     )
     pass
     if n_items != 0:
+        if torch.is_tensor(n_items):
+            n_items = n_items.to(loss.device)
         loss = loss / n_items
     else:
         loss = loss / (shift_labels != -100).sum()
@@ -743,7 +745,7 @@ def _compiled_loss_function(
     logit_scale_divide : float = 0,
     logit_softcapping : float = 0,
     vocab_size : int = 0,
-    n_items : int = 0,
+    n_items : int | torch.Tensor = 0,
 ):
     device = output_logits.device
     if logit_scale_multiply != 0:
@@ -778,6 +780,8 @@ def _compiled_loss_function(
     )
     pass
     if n_items != 0:
+        if torch.is_tensor(n_items):
+            n_items = n_items.to(loss.device)
         loss = loss / n_items
     else:
         loss = loss / (shift_labels != -100).sum()
@@ -851,7 +855,7 @@ def _compiled_loss_function(
     logit_scale_divide : float = 0,
     logit_softcapping : float = 0,
     vocab_size : int = 0,
-    n_items : int = 0,
+    n_items : int | torch.Tensor = 0,
 ):
     device = output_logits.device
     if logit_scale_multiply != 0:
@@ -886,6 +890,8 @@ def _compiled_loss_function(
     )
     pass
     if n_items != 0:
+        if torch.is_tensor(n_items):
+            n_items = n_items.to(loss.device)
         loss = loss / n_items
     else:
         loss = loss / (shift_labels != -100).sum()
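
The same two-line guard is added at three places in _compiled_loss_function. As an illustration of the pattern only (normalize_loss is a hypothetical helper, not code from the repo), the tail of the compiled loss can be read as:

import torch

def normalize_loss(loss: torch.Tensor, n_items, shift_labels: torch.Tensor) -> torch.Tensor:
    # Divide the summed loss either by an externally supplied token count
    # (n_items, which may now be a tensor living on another device) or,
    # when no count was given, by the number of non-ignored labels.
    if n_items != 0:
        if torch.is_tensor(n_items):
            n_items = n_items.to(loss.device)   # the guard this commit adds
        return loss / n_items
    return loss / (shift_labels != -100).sum()

labels = torch.tensor([5, 7, -100, 2])
summed = torch.tensor(9.0)
print(normalize_loss(summed, 0, labels))                 # 9 / 3 valid tokens = 3.0
print(normalize_loss(summed, torch.tensor(6), labels))   # 9 / 6 = 1.5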

unsloth_zoo/loss_utils.py

Lines changed: 15 additions & 8 deletions
@@ -83,13 +83,16 @@ def unsloth_fixed_cross_entropy(source, target, num_items_in_batch: int = None,
         ignore_index = ignore_index,
         reduction = reduction,
     )
-    if reduction == "sum": loss = loss / num_items_in_batch
+    if reduction == "sum":
+        if torch.is_tensor(num_items_in_batch):
+            num_items_in_batch = num_items_in_batch.to(loss.device)
+        loss = loss / num_items_in_batch
     return loss
 pass
 
 # Causal LM loss
 def UnslothForCausalLMLoss(
-    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
+    logits, labels, vocab_size, num_items_in_batch = None, ignore_index = -100, **kwargs
 ):
     if labels is None: return None
     shift_logits = logits
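
On the normalization itself: as the hunk above shows, when a global token count is supplied the loss is computed with reduction="sum" and divided by num_items_in_batch afterwards, which gives the mean over every token of the full gradient-accumulation batch rather than a mean per micro-batch. Because that count is computed (and possibly gathered) outside this function, it can be a tensor on another device, hence the .to(loss.device) move. A small numeric sketch, independent of the library code:

import torch

# Two micro-batches with different numbers of valid tokens.
micro_losses = [torch.tensor([0.5, 1.5, 1.0]), torch.tensor([2.0])]
num_items_in_batch = sum(t.numel() for t in micro_losses)    # 4 tokens total

# Sum-then-divide by the global count: the true mean over all 4 tokens.
global_mean = sum(t.sum() for t in micro_losses) / num_items_in_batch
print(global_mean)   # tensor(1.2500)

# Averaging per-micro-batch means would over-weight the short micro-batch.
naive = sum(t.mean() for t in micro_losses) / len(micro_losses)
print(naive)         # tensor(1.5000)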
@@ -159,7 +162,7 @@ def fused_linear_cross_entropy(
     hidden_states : torch.Tensor,
     lm_weight : torch.Tensor,
     labels : torch.Tensor,
-    num_items_in_batch : int = None,
+    num_items_in_batch : int | torch.Tensor = None,
     ignore_index : int = -100,
     reduction : str = "mean",
     logit_softcapping : float = 0,
@@ -179,7 +182,10 @@ def fused_linear_cross_entropy(
         shift = True,
         filter_eps = accuracy_threshold,
     )
-    if num_items_in_batch is not None: loss = loss / num_items_in_batch
+    if num_items_in_batch is not None:
+        if torch.is_tensor(num_items_in_batch):
+            num_items_in_batch = num_items_in_batch.to(loss.device)
+        loss = loss / num_items_in_batch
     return loss
 pass
 
@@ -188,7 +194,7 @@ def fast_linear_cross_entropy(
     hidden_states : torch.Tensor,
     lm_head : torch.nn.Linear,
     labels : torch.Tensor,
-    num_items_in_batch : int = None,
+    num_items_in_batch : int | torch.Tensor = None,
     ignore_index : int = -100,
     reduction : str = "mean",
     logit_softcapping : float = 0,
@@ -218,7 +224,10 @@ def fast_linear_cross_entropy(
         chunk_size = 512,
         attention_mask = attention_mask,
     )
-    if num_items_in_batch is not None: loss = loss / num_items_in_batch
+    if num_items_in_batch is not None:
+        if torch.is_tensor(num_items_in_batch):
+            num_items_in_batch = num_items_in_batch.to(loss.device)
+        loss = loss / num_items_in_batch
     return loss
 pass
 
@@ -292,8 +301,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None,
 
         if self.args.average_tokens_across_devices:
             num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
-        if device is not None and torch.is_tensor(num_items_in_batch):
-            num_items_in_batch = num_items_in_batch.to(device)
     except Exception as exception:
         raise RuntimeError(exception)
     pass
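
The two deleted lines moved the gathered count onto a caller-supplied device on the producer side; after this commit the consumers above move it onto the loss's own device instead, so the producer-side move is redundant. A rough sketch of the resulting flow under stated assumptions (count_items is a hypothetical stand-in for the batch-sampling hook, and the count is assumed to be the number of non-ignored label tokens):

import torch

def count_items(micro_batches):
    # Sum the non-ignored label tokens of every micro-batch into one tensor,
    # left on whatever device the labels live on.
    return sum((b["labels"] != -100).sum() for b in micro_batches)

micro_batches = [
    {"labels": torch.tensor([[5, 7, -100]])},
    {"labels": torch.tensor([[2, -100, -100]])},
]
num_items_in_batch = count_items(micro_batches)    # tensor(3)
loss = torch.tensor(6.0)                           # summed loss, possibly on another device
loss = loss / num_items_in_batch.to(loss.device)   # consumer-side move, as in this commit
print(loss)                                        # tensor(2.)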
