From 216730260d9ec18c35f276d1707dcb92a4724ae0 Mon Sep 17 00:00:00 2001
From: Kimrass
Date: Fri, 23 May 2025 16:35:44 +0900
Subject: [PATCH] Debug: num_items_in_batch on a different device from loss.

---
 unsloth_zoo/loss_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py
index edaf1f60..ebc80f50 100644
--- a/unsloth_zoo/loss_utils.py
+++ b/unsloth_zoo/loss_utils.py
@@ -21,6 +21,7 @@ from triton import __version__ as triton_version
 major, minor = torch.cuda.get_device_capability()
 import inspect
+from typing import Union
 
 global HAS_CUT_CROSS_ENTROPY
 global UNSLOTH_STUDIO_ENABLED
@@ -159,12 +160,15 @@ def fused_linear_cross_entropy(
     hidden_states : torch.Tensor,
     lm_weight : torch.Tensor,
     labels : torch.Tensor,
-    num_items_in_batch : int = None,
+    num_items_in_batch : Union[int, torch.Tensor] = None,
     ignore_index : int = -100,
     reduction : str = "mean",
     logit_softcapping : float = 0,
     accuracy_threshold : str = "auto",
 ):
+    if isinstance(num_items_in_batch, torch.Tensor):
+        num_items_in_batch = num_items_in_batch.detach().cpu().item() # `torch.Tensor` -> `int`.
+
     # All Unsloth Zoo code licensed under LGPLv3
     reduction = "sum" if num_items_in_batch is not None else "mean"
     if logit_softcapping == 0: logit_softcapping = None
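
Context note (not part of the patch): the caller (e.g. the HF Trainer) may pass num_items_in_batch as a zero-dim tensor, and if that tensor lives on a different device than the summed loss (for example another GPU), the later loss / num_items_in_batch division can fail with a device-mismatch error. The guard added above collapses the count to a plain Python int before any arithmetic. The snippet below is a minimal, CPU-only sketch of the same normalization; normalize_num_items is a hypothetical helper name used only for illustration and does not exist in unsloth_zoo.

    import torch
    from typing import Union

    def normalize_num_items(num_items_in_batch: Union[int, torch.Tensor, None]) -> Union[int, None]:
        # Mirrors the guard in the patch: a tensor count is detached from autograd,
        # moved to the CPU, and unwrapped into a Python scalar, so a later
        # `loss / num_items_in_batch` never sees two tensors on different devices.
        if isinstance(num_items_in_batch, torch.Tensor):
            num_items_in_batch = num_items_in_batch.detach().cpu().item()
        return num_items_in_batch

    # Usage sketch: a summed loss divided by the normalized count.
    loss = torch.tensor(12.0)                  # stand-in for a summed cross-entropy loss
    count = torch.tensor(4)                    # stand-in for num_items_in_batch as a tensor
    print(loss / normalize_num_items(count))   # tensor(3.)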