Skip to content

Commit d8bad22

Browse files
author
--global
committed
Convert num_items_in_batch to int also in loss_utils.py
1 parent 4739189 commit d8bad22

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

unsloth_zoo/loss_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None,
294294
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
295295
if device is not None and torch.is_tensor(num_items_in_batch):
296296
num_items_in_batch = num_items_in_batch.to(device)
297+
num_items_in_batch = num_items_in_batch.item() if isinstance(num_items_in_batch, torch.Tensor) else num_items_in_batch
298+
num_items_in_batch = int(num_items_in_batch) if num_items_in_batch is not None else None
297299
except Exception as exception:
298300
raise RuntimeError(exception)
299301
pass

0 commit comments

Comments
 (0)