Skip to content

Convert n_items type from torch.Tensor to int #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def _compiled_loss_function(
logit_scale_divide : float = 0,
logit_softcapping : float = 0,
vocab_size : int = 0,
n_items : int = 0,
n_items : int | torch.Tensor = 0,
):
device = output_logits.device
if logit_scale_multiply != 0:
Expand Down Expand Up @@ -658,6 +658,8 @@ def _compiled_loss_function(
)
pass
if n_items != 0:
if torch.is_tensor(n_items):
n_items = n_items.to(loss.device)
loss = loss / n_items
else:
loss = loss / (shift_labels != -100).sum()
Expand Down Expand Up @@ -743,7 +745,7 @@ def _compiled_loss_function(
logit_scale_divide : float = 0,
logit_softcapping : float = 0,
vocab_size : int = 0,
n_items : int = 0,
n_items : int | torch.Tensor = 0,
):
device = output_logits.device
if logit_scale_multiply != 0:
Expand Down Expand Up @@ -778,6 +780,8 @@ def _compiled_loss_function(
)
pass
if n_items != 0:
if torch.is_tensor(n_items):
n_items = n_items.to(loss.device)
loss = loss / n_items
else:
loss = loss / (shift_labels != -100).sum()
Expand Down Expand Up @@ -851,7 +855,7 @@ def _compiled_loss_function(
logit_scale_divide : float = 0,
logit_softcapping : float = 0,
vocab_size : int = 0,
n_items : int = 0,
n_items : int | torch.Tensor = 0,
):
device = output_logits.device
if logit_scale_multiply != 0:
Expand Down Expand Up @@ -886,6 +890,8 @@ def _compiled_loss_function(
)
pass
if n_items != 0:
if torch.is_tensor(n_items):
n_items = n_items.to(loss.device)
loss = loss / n_items
else:
loss = loss / (shift_labels != -100).sum()
Expand Down
23 changes: 15 additions & 8 deletions unsloth_zoo/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,16 @@ def unsloth_fixed_cross_entropy(source, target, num_items_in_batch: int = None,
ignore_index = ignore_index,
reduction = reduction,
)
if reduction == "sum": loss = loss / num_items_in_batch
if reduction == "sum":
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(loss.device)
loss = loss / num_items_in_batch
return loss
pass

# Causal LM loss
def UnslothForCausalLMLoss(
logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
logits, labels, vocab_size, num_items_in_batch= None, ignore_index = -100, **kwargs
):
if labels is None: return None
shift_logits = logits
Expand Down Expand Up @@ -159,7 +162,7 @@ def fused_linear_cross_entropy(
hidden_states : torch.Tensor,
lm_weight : torch.Tensor,
labels : torch.Tensor,
num_items_in_batch : int = None,
num_items_in_batch : int | torch.Tensor = None,
ignore_index : int = -100,
reduction : str = "mean",
logit_softcapping : float = 0,
Expand All @@ -179,7 +182,10 @@ def fused_linear_cross_entropy(
shift = True,
filter_eps = accuracy_threshold,
)
if num_items_in_batch is not None: loss = loss / num_items_in_batch
if num_items_in_batch is not None:
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(loss.device)
loss = loss / num_items_in_batch
return loss
pass

Expand All @@ -188,7 +194,7 @@ def fast_linear_cross_entropy(
hidden_states : torch.Tensor,
lm_head : torch.nn.Linear,
labels : torch.Tensor,
num_items_in_batch : int = None,
num_items_in_batch : int | torch.Tensor = None,
ignore_index : int = -100,
reduction : str = "mean",
logit_softcapping : float = 0,
Expand Down Expand Up @@ -218,7 +224,10 @@ def fast_linear_cross_entropy(
chunk_size = 512,
attention_mask = attention_mask,
)
if num_items_in_batch is not None: loss = loss / num_items_in_batch
if num_items_in_batch is not None:
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(loss.device)
loss = loss / num_items_in_batch
return loss
pass

Expand Down Expand Up @@ -292,8 +301,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None,

if self.args.average_tokens_across_devices:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
if device is not None and torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(device)
except Exception as exception:
raise RuntimeError(exception)
pass
Expand Down