1 change: 1 addition & 0 deletions verl/utils/seqlen_balancing.py
@@ -68,6 +68,7 @@ def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool)
    Note:
        When equal_size=True, len(seqlen_list) must be divisible by k_partitions.
    """

    # see: https://en.wikipedia.org/wiki/Largest_differencing_method
    class Set:
        def __init__(self) -> None:
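The hunk above sits inside karmarkar_karp, which balances sequence lengths across k partitions via the largest differencing method referenced in the linked Wikipedia article. As a rough, hedged sketch of that method only (ldm_partition is a hypothetical helper, not verl's implementation, and it ignores the equal_size constraint mentioned in the docstring):

```python
import heapq


def ldm_partition(seqlen_list, k_partitions):
    """Sketch of k-way largest differencing method (Karmarkar-Karp).

    Returns k lists of indices into seqlen_list with roughly balanced sums.
    Illustrative only; verl's version also handles the equal_size constraint.
    """
    # Each heap entry is (-spread, partition_sums, partition_index_sets),
    # where spread = max(sums) - min(sums). heapq is a min-heap, so the
    # entry with the largest spread is popped first.
    heap = []
    for idx, seqlen in enumerate(seqlen_list):
        sums = [seqlen] + [0] * (k_partitions - 1)
        sets = [[idx]] + [[] for _ in range(k_partitions - 1)]
        heapq.heappush(heap, (-(sums[0] - sums[-1]), sums, sets))

    while len(heap) > 1:
        _, sums_a, sets_a = heapq.heappop(heap)
        _, sums_b, sets_b = heapq.heappop(heap)
        # Merge the two entries by pairing the heaviest partitions of one
        # with the lightest of the other, which keeps the spread small.
        a = sorted(zip(sums_a, sets_a), key=lambda t: t[0], reverse=True)
        b = sorted(zip(sums_b, sets_b), key=lambda t: t[0])
        merged = sorted(
            ((sa + sb, ia + ib) for (sa, ia), (sb, ib) in zip(a, b)),
            key=lambda t: t[0],
            reverse=True,
        )
        sums = [s for s, _ in merged]
        sets = [g for _, g in merged]
        heapq.heappush(heap, (-(sums[0] - sums[-1]), sums, sets))

    return heap[0][2]
```

For example, ldm_partition([5, 4, 3, 2], 2) groups the indices so that both partitions sum to 7.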
12 changes: 3 additions & 9 deletions verl/utils/torch_functional.py
@@ -200,9 +200,7 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) ->
    return logprobs_labels


-def clip_by_value(
-    x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor
-) -> torch.Tensor:
+def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
    """Clip tensor values to a range defined by tensor bounds.
    Extension of torch.clamp that supports tensor-valued min/max bounds
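As an aside on the docstring above: clipping with tensor-valued bounds can be expressed with elementwise min/max. A minimal sketch under that assumption (clip_by_value_sketch is illustrative, not necessarily how verl implements it):

```python
import torch


def clip_by_value_sketch(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
    # Elementwise clamp with tensor-valued bounds (both broadcast against x):
    # upper-bound by tensor_max, then lower-bound by tensor_min.
    return torch.max(torch.min(x, tensor_max), tensor_min)
```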
@@ -265,9 +263,7 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20
    return entropy


-def masked_sum(
-    values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None
-) -> torch.Tensor:
+def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None) -> torch.Tensor:
    """Compute sum of tensor values where mask is True.
    NaN values outside the mask are replaced with zeros to prevent
@@ -389,9 +385,7 @@ def compute_grad_norm(model: nn.Module) -> float:
    return total_grad_square


-def broadcast_dict_tensor(
-    tensors: dict[str, torch.Tensor] | TensorDict, src: int, group
-) -> None:
+def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src: int, group) -> None:
Contributor

high

The group parameter is missing a type hint. It's good practice to add type hints for all function parameters to improve code clarity and maintainability. Based on its usage with torch.distributed.broadcast, the type should be torch.distributed.ProcessGroup | None.

Suggested change
-def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src: int, group) -> None:
+def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src: int, group: torch.distributed.ProcessGroup | None) -> None:

"""Broadcast all tensors in a dictionary from source rank to all ranks.
Iterates over all tensors in the dictionary and broadcasts each one
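To see why the suggested torch.distributed.ProcessGroup | None annotation fits: the function forwards group to torch.distributed.broadcast for each tensor in the dict, and that argument accepts either a process group or None (the default group). A minimal sketch of such a helper with the suggested annotation (broadcast_dict_tensor_sketch is illustrative only, not verl's actual body):

```python
import torch
import torch.distributed as dist
from tensordict import TensorDict  # assumes the tensordict package, as in verl


def broadcast_dict_tensor_sketch(
    tensors: dict[str, torch.Tensor] | TensorDict,
    src: int,
    group: dist.ProcessGroup | None,
) -> None:
    # Broadcast every tensor in-place from rank `src` to all ranks in `group`.
    # Sorting the keys keeps the call order identical on every rank, which
    # collective ops require.
    for key in sorted(tensors.keys()):
        dist.broadcast(tensors[key], src=src, group=group)
```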