
Convert n_items type from torch.Tensor to int #139


Open
MoonRainy21 wants to merge 3 commits into main from fix/num_batch_items-type

Conversation

MoonRainy21

Change the type of n_items from a tensor to an integer.
In compiler.py, the type of n_items is annotated as int, but the variable passed into the model was actually a torch.Tensor.
This is not a problem in most cases, but it can cause device-related issues (such as n_items sitting on cuda:0 while the loss it is combined with is on cuda:7).
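For context, a minimal sketch of the failure mode described above (illustrative only; the devices and values are hypothetical and assume a multi-GPU machine):

import torch

loss = torch.tensor(2.5, device="cuda:7")               # loss computed on one GPU
num_items_in_batch = torch.tensor(8, device="cuda:0")   # count created on another GPU

# loss / num_items_in_batch would raise:
#   RuntimeError: Expected all tensors to be on the same device ...
# Converting the count to a plain Python int, as this PR does, sidesteps the device question:
loss = loss / int(num_items_in_batch.item())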

@MoonRainy21 MoonRainy21 force-pushed the fix/num_batch_items-type branch 3 times, most recently from b87adbf to d8bad22 on May 15, 2025 08:04
@danielhanchen
Contributor

Thanks for the PR! Would the previous line, if device is not None and torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.to(device), solve that issue?

@MoonRainy21
Author

If there's only a single device, yes.
However, when I run with multiple GPUs (not officially supported, I know), this caused an error, since num_items_in_batch is used at loss_utils.py:182, where loss could be on cuda:7 while num_items_in_batch would be on cuda:0.
Also, the function expects num_items_in_batch to be an int!
Thank you

torch_cuda_device = torch.cuda.device
def fused_linear_cross_entropy(
    hidden_states      : torch.Tensor,
    lm_weight          : torch.Tensor,
    labels             : torch.Tensor,
    num_items_in_batch : int = None,
    ignore_index       : int = -100,
    reduction          : str = "mean",
    logit_softcapping  : float = 0,
    accuracy_threshold : str = "auto",
):
    # All Unsloth Zoo code licensed under LGPLv3
    reduction = "sum" if num_items_in_batch is not None else "mean"
    if logit_softcapping == 0: logit_softcapping = None
    with torch_cuda_device(lm_weight.device):
        loss = linear_cross_entropy(
            hidden_states.to(lm_weight.dtype),
            lm_weight,
            targets      = labels,
            ignore_index = ignore_index,
            softcap      = logit_softcapping,
            reduction    = reduction,
            shift        = True,
            filter_eps   = accuracy_threshold,
        )
    if num_items_in_batch is not None: loss = loss / num_items_in_batch
    return loss
pass
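
As a rough illustration of the caller-side coercion this PR is after (hypothetical helper, not the actual compiler.py code):

import torch

def as_python_int(num_items_in_batch):
    # Turn a 0-dim tensor count into a plain Python int so the fused loss
    # never has to reconcile its device with the loss tensor's device.
    if torch.is_tensor(num_items_in_batch):
        return int(num_items_in_batch.item())
    return num_items_in_batch

Note that .item() copies the value back to the host and blocks until the GPU has produced it; that is the communication cost raised later in this thread.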

@MoonRainy21
Author

I’ve tested multi-GPU training using my fork, and everything seems to be working well. @danielhanchen, do you have any additional comments or suggestions?

@danielhanchen
Contributor

@MoonRainy21 Apologies for the delay - your PR is correct, yes, but I'm worried this will make training slower due to CPU->GPU communication. @Erland366 was working on seeing whether we can remove this bottleneck.

@MoonRainy21
Author

Then we might have to keep num_items_in_batch typed as torch.Tensor everywhere and move it to the device where the loss is calculated, or wherever else it is used. Do you think that would work?

@danielhanchen
Contributor

@MoonRainy21 I'm assuming if num_items_in_batch is not None: loss = loss / num_items_in_batch.to(loss.device) might work, maybe?

@danielhanchen
Contributor

You can ignore the num_items_in_batch : int = None,

Maybe better to do:

    with torch_cuda_device(lm_weight.device):
        loss = linear_cross_entropy(
            hidden_states.to(lm_weight.dtype),
            lm_weight,
            targets      = labels,
            ignore_index = ignore_index,
            softcap      = logit_softcapping,
            reduction    = reduction,
            shift        = True,
            filter_eps   = accuracy_threshold,
        )
    if num_items_in_batch is not None:
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss
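
For what it's worth, a hedged sketch of the two options discussed here side by side (the helper name is illustrative):

import torch

def normalise_loss(loss, num_items_in_batch):
    if num_items_in_batch is None:
        return loss
    if torch.is_tensor(num_items_in_batch):
        # Option A (above): move the count onto the loss's device; the value
        # never has to be read back to the CPU.
        num_items_in_batch = num_items_in_batch.to(loss.device)
        # Option B (this PR): num_items_in_batch = int(num_items_in_batch.item()),
        # which yields a plain int but synchronises with the GPU to read the value.
    return loss / num_items_in_batch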

@Erland366
Collaborator

I want to discuss this a bit.

I tested the behavior in vanilla Hugging Face and it hits the same issue:

Here's my test on a Kaggle notebook with 2 T4s -> https://www.kaggle.com/code/erlandpg/test-multigpu-bitsandbytes

I tried moving num_items_in_batch to the loss device and it works, but GPU utilization is around 20%. I need more testing to tell whether that is a reasonable utilization number, and also whether we want to do this at all, since HF itself does not support it.

cc: @danielhanchen

@MoonRainy21
Author

It seems there's already code handling the device of num_items_in_batch at loss_utils.py:304. I wasn't able to work out when _unsloth_get_batch_samples is called, but it looks like that code would become unnecessary.

@MoonRainy21
Author

@Erland366 Regarding utilization: GPU utilization of the running GPU was pretty high for me (~80%) when I tried a larger batch size, but only one of the GPUs was running at a time. We might need pipeline parallelism support for better utilization.

cc: @danielhanchen

@MoonRainy21 MoonRainy21 force-pushed the fix/num_batch_items-type branch from 4dbb0da to facf34d on May 26, 2025 00:19
@MoonRainy21
Author

I have tested the latest commit with 8 GPUs training Qwen3 235B A22B, but I'm not really sure about some of the code, which seems to be a string template used for building the Unsloth compiled cache.

@Erland366
Collaborator

Erland366 commented May 27, 2025

@Erland366 Regarding utilization: GPU utilization of the running GPU was pretty high for me (~80%) when I tried a larger batch size, but only one of the GPUs was running at a time. We might need pipeline parallelism support for better utilization.

cc: @danielhanchen

Communication is really slow, especially in my setup where it goes over PCIe rather than NVLink or InfiniBand. If I use only 1 GPU I get around 75% or so, but with 2 GPUs it drops down to 20%. I need to investigate what the GPU utilization is on an NVLink system, which based on this paper should be around 60% -> https://arxiv.org/abs/2505.12832v1

We cannot move forward with using an integer, since num_items_in_batch needs to support an all-gather when the number of tokens differs across GPUs. This exists because of the gradient accumulation bug found by the Unsloth team (see this article -> https://muellerzr.github.io/blog/gradient_accumulation_part2.html#problem-distributed-training ; you can see that they call accelerator.gather on num_items_in_batch).

Moving it to loss.device seems like the option here, but I don't know if we're moving forward with that solution (since HF itself did not do that, perhaps for a reason).
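
To make that constraint concrete, a rough stand-in for the accelerator.gather pattern from the linked article (simplified; assumes torch.distributed has been initialised by the trainer):

import torch
import torch.distributed as dist

def total_items_across_ranks(num_items_in_batch: torch.Tensor) -> torch.Tensor:
    # Each rank can see a different number of non-padding tokens, so the
    # per-rank counts are summed over all processes before the loss is normalised.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM)
    return num_items_in_batch  # this only works if the count stays a tensor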

@MoonRainy21
Author

In case you are interested, my setup was 8 A100 SXM GPUs (NVLink connection) and utilization was around 50% to 70% once the tensor arrives. I'm trying to use FSDP for pipeline parallelism in order to see the utilization on multiple GPUs.

@danielhanchen
Contributor

Wait @Erland366 currently the code is:

torch_cuda_device = torch.cuda.device
def fused_linear_cross_entropy(
    hidden_states      : torch.Tensor,
    lm_weight          : torch.Tensor,
    labels             : torch.Tensor,
    num_items_in_batch : int = None,
    ignore_index       : int = -100,
    reduction          : str = "mean",
    logit_softcapping  : float = 0,
    accuracy_threshold : str = "auto",
):
    # All Unsloth Zoo code licensed under LGPLv3
    reduction = "sum" if num_items_in_batch is not None else "mean"
    if logit_softcapping == 0: logit_softcapping = None
    with torch_cuda_device(lm_weight.device):
        loss = linear_cross_entropy(
            hidden_states.to(lm_weight.dtype),
            lm_weight,
            targets      = labels,
            ignore_index = ignore_index,
            softcap      = logit_softcapping,
            reduction    = reduction,
            shift        = True,
            filter_eps   = accuracy_threshold,
        )
    if num_items_in_batch is not None: loss = loss / num_items_in_batch
    return loss
pass

Are you saying you also managed to test:

torch_cuda_device = torch.cuda.device
def fused_linear_cross_entropy(
    hidden_states      : torch.Tensor,
    lm_weight          : torch.Tensor,
    labels             : torch.Tensor,
    num_items_in_batch : int = None,
    ignore_index       : int = -100,
    reduction          : str = "mean",
    logit_softcapping  : float = 0,
    accuracy_threshold : str = "auto",
):
    # All Unsloth Zoo code licensed under LGPLv3
    reduction = "sum" if num_items_in_batch is not None else "mean"
    if logit_softcapping == 0: logit_softcapping = None
    with torch_cuda_device(lm_weight.device):
        loss = linear_cross_entropy(
            hidden_states.to(lm_weight.dtype),
            lm_weight,
            targets      = labels,
            ignore_index = ignore_index,
            softcap      = logit_softcapping,
            reduction    = reduction,
            shift        = True,
            filter_eps   = accuracy_threshold,
        )
    if num_items_in_batch is not None:
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss
pass
