@@ -83,13 +83,16 @@ def unsloth_fixed_cross_entropy(source, target, num_items_in_batch: int = None,
83
83
ignore_index = ignore_index ,
84
84
reduction = reduction ,
85
85
)
86
- if reduction == "sum" : loss = loss / num_items_in_batch
86
+ if reduction == "sum" :
87
+ if torch .is_tensor (num_items_in_batch ):
88
+ num_items_in_batch = num_items_in_batch .to (loss .device )
89
+ loss = loss / num_items_in_batch
87
90
return loss
88
91
pass
89
92
90
93
# Causal LM loss
91
94
def UnslothForCausalLMLoss (
92
- logits , labels , vocab_size : int , num_items_in_batch : int = None , ignore_index : int = - 100 , ** kwargs
95
+ logits , labels , vocab_size , num_items_in_batch = None , ignore_index = - 100 , ** kwargs
93
96
):
94
97
if labels is None : return None
95
98
shift_logits = logits
@@ -159,7 +162,7 @@ def fused_linear_cross_entropy(
159
162
hidden_states : torch .Tensor ,
160
163
lm_weight : torch .Tensor ,
161
164
labels : torch .Tensor ,
162
- num_items_in_batch : int = None ,
165
+ num_items_in_batch : int | torch . Tensor = None ,
163
166
ignore_index : int = - 100 ,
164
167
reduction : str = "mean" ,
165
168
logit_softcapping : float = 0 ,
@@ -179,7 +182,10 @@ def fused_linear_cross_entropy(
179
182
shift = True ,
180
183
filter_eps = accuracy_threshold ,
181
184
)
182
- if num_items_in_batch is not None : loss = loss / num_items_in_batch
185
+ if num_items_in_batch is not None :
186
+ if torch .is_tensor (num_items_in_batch ):
187
+ num_items_in_batch = num_items_in_batch .to (loss .device )
188
+ loss = loss / num_items_in_batch
183
189
return loss
184
190
pass
185
191
@@ -188,7 +194,7 @@ def fast_linear_cross_entropy(
188
194
hidden_states : torch .Tensor ,
189
195
lm_head : torch .nn .Linear ,
190
196
labels : torch .Tensor ,
191
- num_items_in_batch : int = None ,
197
+ num_items_in_batch : int | torch . Tensor = None ,
192
198
ignore_index : int = - 100 ,
193
199
reduction : str = "mean" ,
194
200
logit_softcapping : float = 0 ,
@@ -218,7 +224,10 @@ def fast_linear_cross_entropy(
218
224
chunk_size = 512 ,
219
225
attention_mask = attention_mask ,
220
226
)
221
- if num_items_in_batch is not None : loss = loss / num_items_in_batch
227
+ if num_items_in_batch is not None :
228
+ if torch .is_tensor (num_items_in_batch ):
229
+ num_items_in_batch = num_items_in_batch .to (loss .device )
230
+ loss = loss / num_items_in_batch
222
231
return loss
223
232
pass
224
233
@@ -292,8 +301,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None,
292
301
293
302
if self .args .average_tokens_across_devices :
294
303
num_items_in_batch = self .accelerator .gather (num_items_in_batch ).sum ()
295
- if device is not None and torch .is_tensor (num_items_in_batch ):
296
- num_items_in_batch = num_items_in_batch .to (device )
297
304
except Exception as exception :
298
305
raise RuntimeError (exception )
299
306
pass
0 commit comments