
Commit 9167859

format
1 parent 17a3e6c commit 9167859

2 files changed: 4 additions, 9 deletions


verl/utils/seqlen_balancing.py

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool)
     Note:
         When equal_size=True, len(seqlen_list) must be divisible by k_partitions.
     """
+
     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
     class Set:
         def __init__(self) -> None:
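
For context on the function this hunk touches: karmarkar_karp balances sequence lengths across k partitions via the largest differencing method referenced in the comment above. The sketch below is a generic, minimal LDM implementation, not verl's code; the name balanced_partitions is hypothetical and the equal_size constraint from the docstring is ignored here.

import heapq

def balanced_partitions(seqlen_list, k_partitions):
    # Each heap entry is (-spread, parts), where parts is a list of
    # (partition_sum, [indices]) pairs kept sorted descending by sum and
    # spread = largest sum - smallest sum within that entry.
    heap = []
    for idx, seqlen in enumerate(seqlen_list):
        parts = [(seqlen, [idx])] + [(0, []) for _ in range(k_partitions - 1)]
        heapq.heappush(heap, (-seqlen, parts))
    while len(heap) > 1:
        # Pop the two entries with the largest spreads and merge them so the
        # heaviest partition of one pairs with the lightest of the other.
        _, a = heapq.heappop(heap)
        _, b = heapq.heappop(heap)
        merged = [(sa + sb, ia + ib) for (sa, ia), (sb, ib) in zip(a, reversed(b))]
        merged.sort(key=lambda p: p[0], reverse=True)
        spread = merged[0][0] - merged[-1][0]
        heapq.heappush(heap, (-spread, merged))
    return [indices for _, indices in heap[0][1]]

# Example: split 8 sequence lengths into 2 partitions with near-equal totals.
print(balanced_partitions([512, 64, 300, 128, 1024, 256, 96, 700], 2))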

verl/utils/torch_functional.py

Lines changed: 3 additions & 9 deletions
@@ -200,9 +200,7 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) ->
     return logprobs_labels
 
 
-def clip_by_value(
-    x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor
-) -> torch.Tensor:
+def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
     """Clip tensor values to a range defined by tensor bounds.
 
     Extension of torch.clamp that supports tensor-valued min/max bounds
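
The hunk only reflows the signature; the docstring says the function clamps with tensor-valued bounds. A minimal sketch consistent with that description, assuming the usual min/max composition (verl's actual body may differ):

import torch

def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
    # Elementwise clamp: values above tensor_max are pulled down,
    # values below tensor_min are pulled up; bounds broadcast against x.
    return torch.max(torch.min(x, tensor_max), tensor_min)

x = torch.tensor([0.1, 2.5, -3.0])
lo = torch.tensor([0.0, 0.0, -1.0])
hi = torch.tensor([1.0, 2.0, 1.0])
print(clip_by_value(x, lo, hi))  # tensor([0.1000, 2.0000, -1.0000])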
@@ -265,9 +263,7 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20
     return entropy
 
 
-def masked_sum(
-    values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None
-) -> torch.Tensor:
+def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None) -> torch.Tensor:
     """Compute sum of tensor values where mask is True.
 
     NaN values outside the mask are replaced with zeros to prevent
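
A sketch of the behavior the masked_sum docstring describes: reduce only where the mask is True, after neutralizing NaNs outside the mask so they cannot poison the sum. This is an assumption-based illustration, not verl's implementation.

import torch

def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis=None) -> torch.Tensor:
    # Zero out everything outside the mask (including NaNs) before reducing.
    safe = torch.where(mask.bool(), values, torch.zeros_like(values))
    return safe.sum() if axis is None else safe.sum(dim=axis)

values = torch.tensor([1.0, float("nan"), 3.0])
mask = torch.tensor([True, False, True])
print(masked_sum(values, mask))  # tensor(4.)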
@@ -389,9 +385,7 @@ def compute_grad_norm(model: nn.Module) -> float:
     return total_grad_square
 
 
-def broadcast_dict_tensor(
-    tensors: dict[str, torch.Tensor] | TensorDict, src: int, group
-) -> None:
+def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src: int, group) -> None:
     """Broadcast all tensors in a dictionary from source rank to all ranks.
 
     Iterates over all tensors in the dictionary and broadcasts each one
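
A hedged sketch of the broadcast pattern the docstring outlines: iterate the dict in a deterministic key order and broadcast each tensor in place from src within group. It assumes an initialized torch.distributed process group; verl's actual function may order or batch the collectives differently.

import torch
import torch.distributed as dist

def broadcast_dict_tensor(tensors, src: int, group) -> None:
    # Sorted keys keep every rank issuing the collectives in the same order.
    for key in sorted(tensors.keys()):
        dist.broadcast(tensors[key], src=src, group=group, async_op=False)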
