@@ -200,9 +200,7 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) ->
200200 return logprobs_labels
201201
202202
203- def clip_by_value(
204-     x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor
205- ) -> torch.Tensor:
203+ def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
206204 """Clip tensor values to a range defined by tensor bounds.
207205
208206 Extension of torch.clamp that supports tensor-valued min/max bounds
@@ -265,9 +263,7 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20
265263 return entropy
266264
267265
268- def masked_sum(
269-     values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None
270- ) -> torch.Tensor:
266+ def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None) -> torch.Tensor:
271267 """Compute sum of tensor values where mask is True.
272268
273269 NaN values outside the mask are replaced with zeros to prevent
@@ -389,9 +385,7 @@ def compute_grad_norm(model: nn.Module) -> float:
389385 return total_grad_square
390386
391387
392- def broadcast_dict_tensor(
393-     tensors: dict[str, torch.Tensor] | TensorDict, src: int, group
394- ) -> None:
388+ def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src: int, group) -> None:
395389 """Broadcast all tensors in a dictionary from source rank to all ranks.
396390
397391 Iterates over all tensors in the dictionary and broadcasts each one
0 commit comments