Merged
224 changes: 190 additions & 34 deletions verl/utils/torch_functional.py
@@ -46,17 +46,25 @@
NPU_CROSS_ENTROPY_LOSS_AVAILABLE = False


def gather_from_labels(data, label):
"""Gather the label from data. The value in label should be [0, vocab_size)
def gather_from_labels(data: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
"""Gather values from data tensor at positions specified by label indices.

Selects elements from the last dimension of `data` based on indices in `label`.
Commonly used to extract log-probabilities for specific token IDs from a
vocabulary distribution.

Args:
data: (..., vocab_size)
label (torch.IntTensor) : (...,)
data: Input tensor of shape (..., vocab_size) containing values to gather from.
label: Index tensor of shape (...,) with values in range [0, vocab_size).

Returns:
torch.Tensor: Gathered values with shape (...,), same as label shape.

Example:
>>> logits = torch.randn(2, 3, 100) # [batch, seq, vocab]
>>> labels = torch.randint(0, 100, (2, 3)) # [batch, seq]
>>> gathered = gather_from_labels(logits, labels) # [batch, seq]
"""

output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)
return output
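
As an illustrative sketch (not part of the diff), the gather above is equivalent to indexing each position's label along the vocab dimension; `verl.utils.torch_functional` is the module this diff touches:

import torch
from verl.utils.torch_functional import gather_from_labels

logits = torch.randn(2, 3, 100)           # [batch, seq, vocab]
labels = torch.randint(0, 100, (2, 3))    # [batch, seq]
gathered = gather_from_labels(logits, labels)
# same value as direct indexing for any single position
assert torch.equal(gathered[0, 1], logits[0, 1, labels[0, 1]])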

@@ -92,30 +100,89 @@ def logprobs_from_logits(logits, labels, inplace_backward=True):
return output


def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True):
def logprobs_from_logits_flash_attn(
logits: torch.Tensor, labels: torch.Tensor, inplace_backward: bool = True
) -> torch.Tensor:
"""Compute log-probabilities using Flash Attention's optimized cross-entropy.

Uses the Flash Attention library's Triton-based cross-entropy implementation
for efficient computation on NVIDIA GPUs.

Args:
logits: Model output logits of shape (batch_size, vocab_size).
labels: Target token indices of shape (batch_size,).
inplace_backward: If True, perform backward pass in-place for memory efficiency.

Returns:
torch.Tensor: Log-probabilities for target labels, shape (batch_size,).

Raises:
AssertionError: If flash-attn version < 2.4.3 (different return format).
"""
output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward)
assert isinstance(output, tuple), (
"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
)
return -output[0]
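
A hedged usage sketch, assuming flash-attn >= 2.4.3 is installed and a CUDA device is available; the docstring above describes 2-D logits, so batch and sequence dims are flattened into rows first:

import torch
from verl.utils.torch_functional import logprobs_from_logits_flash_attn

logits = torch.randn(2 * 16, 32000, device="cuda", dtype=torch.bfloat16)  # (rows, vocab)
labels = torch.randint(0, 32000, (2 * 16,), device="cuda")
logp = logprobs_from_logits_flash_attn(logits, labels)  # (rows,)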


def logprobs_from_logits_torch_npu(logits, labels):
def logprobs_from_logits_torch_npu(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Compute log-probabilities using Ascend NPU's optimized cross-entropy.

Uses torch_npu's native cross-entropy implementation for efficient
computation on Huawei Ascend NPU devices.

Args:
logits: Model output logits of shape (..., vocab_size).
labels: Target token indices of shape (...,).

Returns:
torch.Tensor: Log-probabilities for target labels, same shape as labels.
"""
batch_dim = logits.shape[:-1]
logits = logits.reshape(-1, logits.shape[-1])
loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels.reshape(-1), reduction="none")
return -loss.view(*batch_dim)
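
A minimal call sketch, under the assumption that torch_npu is installed and an Ascend NPU is visible as device "npu":

import torch
from verl.utils.torch_functional import logprobs_from_logits_torch_npu

logits = torch.randn(2, 8, 32000, device="npu")
labels = torch.randint(0, 32000, (2, 8), device="npu")
logp = logprobs_from_logits_torch_npu(logits, labels)  # shape (2, 8), same as labels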


def logprobs_from_logits_naive(logits, labels):
def logprobs_from_logits_naive(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Compute log-probabilities using standard log-softmax approach.

Simple implementation using PyTorch's log_softmax followed by gathering.
Less memory-efficient than specialized implementations but works on all devices.

Args:
logits: Model output logits of shape (..., vocab_size).
labels: Target token indices of shape (...,).

Returns:
torch.Tensor: Log-probabilities for target labels, same shape as labels.
"""
logp = F.log_softmax(logits, dim=-1)
logpy = gather_from_labels(logp, labels)
return logpy
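
Since this path runs on any device, it doubles as a reference implementation; a small cross-check against the memory-efficient variant defined later in this file (illustrative only):

import torch
from verl.utils.torch_functional import (
    logprobs_from_logits_naive,
    logprobs_from_logits_v2,
)

logits = torch.randn(2, 5, 1000, dtype=torch.float32)
labels = torch.randint(0, 1000, (2, 5))
torch.testing.assert_close(
    logprobs_from_logits_naive(logits, labels),
    logprobs_from_logits_v2(logits, labels),
)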


def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
"""
A memory efficient implementation of logprobs_from_logits
def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) -> torch.Tensor:
Contributor review comment (severity: high):

The type hint torch.FloatTensor for logits is too restrictive. The function body explicitly handles other dtypes like bfloat16 in the else block. Using torch.Tensor would be more accurate and align with modern PyTorch practices, as torch.FloatTensor is a legacy alias for a torch.float32 tensor.

Suggested change
def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) -> torch.Tensor:
def logprobs_from_logits_v2(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:

"""Memory-efficient log-probability computation using row-wise processing.

Computes log-probabilities by processing one row at a time to reduce peak
memory consumption. Uses logsumexp for float32/float64, falls back to
log_softmax for bfloat16 due to numerical stability concerns.

The mathematical identity used is: log_softmax(x_i) = x_i - logsumexp(x)

Args:
logits: Model output logits of shape (batch_size, seq_len, vocab_size)
or (batch_size, vocab_size).
labels: Target token indices matching logits shape without vocab dimension.

Returns:
torch.Tensor: Log-probabilities for target labels.

Note:
This implementation trades compute for memory by iterating over the batch
dimension, making it suitable for large vocabulary sizes.
"""
if logits.dtype in [torch.float32, torch.float64]:
logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
@@ -133,24 +200,62 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
return logprobs_labels
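
A standalone illustration of the row-wise identity the docstring relies on; this re-derives the idea with plain torch ops and is not the elided function body:

import torch

logits = torch.randn(3, 7, 100, dtype=torch.float32)
labels = torch.randint(0, 100, (3, 7))

picked = torch.gather(logits, -1, labels.unsqueeze(-1)).squeeze(-1)   # x[label]
lse = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])   # one batch row at a time
reference = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
# log_softmax(x)[label] == x[label] - logsumexp(x)
torch.testing.assert_close(picked - lse, reference)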


def clip_by_value(x, tensor_min, tensor_max):
"""
Tensor extenstion to torch.clamp
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
def clip_by_value(
x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor
) -> torch.Tensor:
"""Clip tensor values to a range defined by tensor bounds.

Extension of torch.clamp that supports tensor-valued min/max bounds
instead of only scalar bounds.

Args:
x: Input tensor to clip.
tensor_min: Minimum bound tensor (broadcastable to x).
tensor_max: Maximum bound tensor (broadcastable to x).

Returns:
torch.Tensor: Clipped tensor with values in [tensor_min, tensor_max].

See Also:
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
"""
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
return clipped
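
A short usage sketch with tensor-valued bounds (the numbers are arbitrary, for illustration):

import torch
from verl.utils.torch_functional import clip_by_value

x = torch.tensor([0.2, 1.5, -3.0])
lo = torch.tensor([0.0, 0.0, -1.0])
hi = torch.tensor([1.0, 1.0, 1.0])
clip_by_value(x, lo, hi)  # tensor([ 0.2000,  1.0000, -1.0000])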


def entropy_from_logits(logits: torch.Tensor):
"""Calculate entropy from logits."""
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
"""Calculate Shannon entropy from unnormalized logits.

Computes H(p) = -sum(p * log(p)) using the numerically stable formula:
entropy = logsumexp(logits) - sum(softmax(logits) * logits)

Args:
logits: Unnormalized log-probabilities of shape (..., vocab_size).

Returns:
torch.Tensor: Entropy values with shape (...,), one per distribution.
"""
pd = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
return entropy
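
A quick sanity check against the direct definition H(p) = -sum(p * log p) (illustrative only):

import torch
import torch.nn.functional as F
from verl.utils.torch_functional import entropy_from_logits

logits = torch.randn(4, 100)
p = F.softmax(logits, dim=-1)
direct = -(p * torch.log(p)).sum(dim=-1)
torch.testing.assert_close(entropy_from_logits(logits), direct)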


def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048):
"""Memory-efficient entropy calculation with chunking."""
def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048) -> torch.Tensor:
"""Memory-efficient entropy calculation using chunked processing.

Computes entropy by processing the batch in chunks to reduce peak memory
usage. Useful for large batch sizes or when memory is constrained.

Args:
logits: Unnormalized log-probabilities of shape (batch_size, vocab_size).
chunk_size: Number of samples to process at once. Defaults to 2048.

Returns:
torch.Tensor: Entropy values with shape (batch_size,).

Note:
Converts chunks to float32 for numerical stability during computation.
"""
entropy = torch.zeros(logits.shape[0], device=logits.device)
for i in range(0, logits.shape[0], chunk_size):
logits_chunk = logits[i : i + chunk_size].float()
@@ -160,8 +265,23 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20
return entropy
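
A sketch comparing the chunked path against the unchunked one; the chunk size is set below the batch size so several chunks are actually exercised (sizes are arbitrary):

import torch
from verl.utils.torch_functional import (
    entropy_from_logits,
    entropy_from_logits_with_chunking,
)

logits = torch.randn(1000, 5000)
torch.testing.assert_close(
    entropy_from_logits_with_chunking(logits, chunk_size=256),
    entropy_from_logits(logits),
)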


def masked_sum(values, mask, axis=None):
"""Compute mean of tensor with a masked values."""
def masked_sum(
values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None
) -> torch.Tensor:
"""Compute sum of tensor values where mask is True.

NaN values outside the mask are replaced with zeros to prevent
contaminating the sum.

Args:
values: Input tensor containing values to sum.
mask: Boolean or numeric mask tensor (same shape as values).
Non-zero values indicate elements to include.
axis: Dimension(s) along which to sum. None sums all elements.

Returns:
torch.Tensor: Sum of masked values, reduced along specified axis.
"""
# If NaNs exist out of mask, replace NaNs in values with a value that
# won't affect the sum (e.g., 0 for masked regions)
valid_values = torch.where(mask.bool(), values, 0.0)
@@ -246,35 +366,71 @@ def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2,
return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)


def compute_grad_norm(model: nn.Module):
def compute_grad_norm(model: nn.Module) -> float:
"""Compute the squared L2 norm of all gradients in a model.

Sums the squared values of all gradient tensors across all parameters.
Useful for monitoring gradient magnitudes during training.

Args:
model: PyTorch model with computed gradients.

Returns:
float: Sum of squared gradient values (not the square root).

Note:
Returns the squared norm, not the norm itself. To get the actual
L2 norm, take the square root of the returned value.
"""
total_grad_square = 0
for param in model.parameters():
if param.grad is not None:
total_grad_square += torch.sum(torch.square(param.grad.detach())).item()
return total_grad_square
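
A usage sketch; the square root at the end recovers the actual L2 norm, since the function returns the squared sum (the model and loss here are placeholders):

import math
import torch
import torch.nn as nn
from verl.utils.torch_functional import compute_grad_norm

model = nn.Linear(16, 4)
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
grad_norm = math.sqrt(compute_grad_norm(model))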


def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src, group):
"""
TODO: optimize this. Technically, we only need one broadcast
"""
def broadcast_dict_tensor(
tensors: dict[str, torch.Tensor] | TensorDict, src: int, group
) -> None:
"""Broadcast all tensors in a dictionary from source rank to all ranks.

Iterates over all tensors in the dictionary and broadcasts each one
from the source rank to all other ranks in the process group.

Args:
tensors: Dictionary or TensorDict containing tensors to broadcast.
src: Source rank from which to broadcast.
group: Process group for the broadcast operation.

Note:
This implementation broadcasts tensors one at a time. Could be optimized
to use a single broadcast with packed tensors.
"""
for key in tensors.sorted_keys:
torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False)
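
A hedged sketch of a typical call; a TensorDict is used because the body iterates tensors.sorted_keys, and the initialized process group, ranks, and CUDA devices are assumed context rather than shown:

import torch
import torch.distributed as dist
from tensordict import TensorDict
from verl.utils.torch_functional import broadcast_dict_tensor

# assumes dist.init_process_group(...) has already run on every rank
batch = TensorDict(
    {
        "input_ids": torch.zeros(4, 128, dtype=torch.long, device="cuda"),
        "attention_mask": torch.zeros(4, 128, dtype=torch.long, device="cuda"),
    },
    batch_size=[4],
)
broadcast_dict_tensor(batch, src=0, group=dist.group.WORLD)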


def allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size, group, dim=0):
"""
TODO: optimize this.
- We can use async ops
- We can use only one allgather
def allgather_dict_tensors(
tensors: dict[str, torch.Tensor] | TensorDict, size: int, group, dim: int = 0
) -> dict[str, torch.Tensor] | TensorDict:
"""Gather tensors from all ranks and concatenate them.

Performs all_gather on each tensor in the dictionary and concatenates
the results along the specified dimension.

Args:
tensors:
size:
group:
tensors: Dictionary or TensorDict containing tensors to gather.
size: Number of ranks in the process group.
group: Process group for the all_gather operation.
dim: Dimension along which to concatenate gathered tensors. Defaults to 0.

Returns:
Dictionary or TensorDict (matching input type) with gathered and
concatenated tensors. Each tensor's size along `dim` is multiplied by `size`.

Note:
This implementation gathers tensors one at a time synchronously.
Could be optimized using async ops or packed all_gather.
"""
if isinstance(tensors, TensorDict):
is_tensor_dict = True