Merged
1 change: 1 addition & 0 deletions verl/utils/seqlen_balancing.py
@@ -68,6 +68,7 @@ def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool)
Note:
When equal_size=True, len(seqlen_list) must be divisible by k_partitions.
"""

# see: https://en.wikipedia.org/wiki/Largest_differencing_method
class Set:
def __init__(self) -> None:
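For context, the function touched above implements the largest differencing method referenced in its comment. Below is a minimal standalone sketch of that heuristic for two partitions only; it is for illustration and is not verl's karmarkar_karp, which supports k partitions and an equal_size constraint.

# Illustrative sketch of the largest differencing method (Karmarkar-Karp)
# for two partitions; not verl's implementation.
import heapq

def ldm_two_way_difference(seqlen_list: list[int]) -> int:
    """Heuristic difference between the two partition sums."""
    heap = [-x for x in seqlen_list]  # max-heap via negation
    heapq.heapify(heap)
    while len(heap) > 1:
        a = -heapq.heappop(heap)  # largest remaining value
        b = -heapq.heappop(heap)  # second largest
        heapq.heappush(heap, -(a - b))  # commit them to opposite partitions
    return -heap[0] if heap else 0

print(ldm_two_way_difference([8, 7, 6, 5, 4]))  # 2; a heuristic result, not always optimal
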
12 changes: 3 additions & 9 deletions verl/utils/torch_functional.py
@@ -200,9 +200,7 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) ->
return logprobs_labels


-def clip_by_value(
-    x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor
-) -> torch.Tensor:
+def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
"""Clip tensor values to a range defined by tensor bounds.

Extension of torch.clamp that supports tensor-valued min/max bounds
@@ -265,9 +263,7 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20
return entropy


-def masked_sum(
-    values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None
-) -> torch.Tensor:
+def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None) -> torch.Tensor:
"""Compute sum of tensor values where mask is True.

NaN values outside the mask are replaced with zeros to prevent
@@ -389,9 +385,7 @@ def compute_grad_norm(model: nn.Module) -> float:
return total_grad_square


-def broadcast_dict_tensor(
-    tensors: dict[str, torch.Tensor] | TensorDict, src: int, group
-) -> None:
+def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src: int, group) -> None:
"""Broadcast all tensors in a dictionary from source rank to all ranks.

Iterates over all tensors in the dictionary and broadcasts each one
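The three hunks above are pure formatting changes. For readers unfamiliar with these helpers, rough standalone equivalents of clip_by_value and masked_sum, based only on the docstrings shown in the diff, might look as follows; the function bodies are assumptions for illustration, not verl's exact implementations.

# Rough equivalents inferred from the docstrings above; bodies are assumptions.
import torch

def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
    # torch.clamp-style clipping, but with elementwise tensor bounds.
    return torch.max(torch.min(x, tensor_max), tensor_min)

def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis=None) -> torch.Tensor:
    # Entries outside the mask (including any NaNs there) are zeroed before summing.
    masked = torch.where(mask.bool(), values, torch.zeros_like(values))
    return masked.sum() if axis is None else masked.sum(dim=axis)

x = torch.tensor([0.5, 2.0, float("nan")])
print(clip_by_value(x[:2], torch.zeros(2), torch.ones(2)))   # tensor([0.5000, 1.0000])
print(masked_sum(x, torch.tensor([1.0, 1.0, 0.0])))          # tensor(2.5000)
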
5 changes: 4 additions & 1 deletion verl/workers/engine/mindspeed/transformer_impl.py
@@ -15,7 +15,10 @@
import logging
import os

-from mindspeed.megatron_adaptor import repatch
+try:
+    from mindspeed.megatron_adaptor import repatch
+except ImportError:
+    repatch = None

from verl.trainer.config import CheckpointConfig
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
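This change makes the mindspeed import optional: repatch is bound to None when the package is absent instead of raising at import time. A call site would then typically check the sentinel before applying the patches. The sketch below is hypothetical; in particular, the argument passed to repatch() is an assumption, not verl's actual call.

# Hypothetical call-site sketch; the repatch() argument is an assumption.
try:
    from mindspeed.megatron_adaptor import repatch
except ImportError:  # mindspeed not installed, e.g. on non-Ascend machines
    repatch = None

def maybe_repatch(transformer_config) -> None:
    # Apply the MindSpeed patches only when the package is available.
    if repatch is not None:
        repatch(transformer_config)
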
2 changes: 1 addition & 1 deletion verl/workers/megatron_workers.py
@@ -28,7 +28,7 @@
from omegaconf import DictConfig, OmegaConf

try:
-    from mindspeed.megatron_adaptor import repatch
+    from verl.workers.engine.mindspeed.transformer_impl import repatch
except ImportError:
repatch = None
Comment on lines 30 to 33

Contributor (severity: high):

The try...except ImportError block is now redundant. The verl.workers.engine.mindspeed.transformer_impl module handles the case where mindspeed is not installed by setting repatch to None. Therefore, this import will no longer raise an ImportError under normal circumstances, and this try...except block can be removed to improve code clarity.

Suggested change:
from verl.workers.engine.mindspeed.transformer_impl import repatch
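To make the suggestion concrete, here is a minimal sketch of the pattern the reviewer describes, using hypothetical module names rather than verl's actual files: the wrapper module absorbs the ImportError once, so downstream modules can import the symbol unguarded and only check for None.

# optional_dep.py (hypothetical wrapper module): owns the ImportError handling.
try:
    from some_optional_package import repatch  # assumed optional dependency
except ImportError:
    repatch = None  # sentinel meaning "feature unavailable"

# consumer.py (hypothetical): this import cannot raise ImportError anymore,
# so no try/except is needed here; callers only check the sentinel.
from optional_dep import repatch

if repatch is not None:
    repatch()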

