@@ -10,6 +10,7 @@ def roll_packed_tensor(
1010 cu_seq_lens : torch .IntTensor ,
1111 shifts : int = - 1 ,
1212 dim : int = - 1 ,
13+ fill_value : float | int = 0 ,
1314) -> torch .Tensor :
1415 """Roll a packed tensor along the specified dimension.
1516
@@ -24,9 +25,12 @@ def roll_packed_tensor(
2425 Only negative shifts are supported.
2526 dim (int): Dimension along which to roll. The ``cu_seq_lens`` boundaries
2627 are applied on this dimension. Default is -1 (last dimension).
28+ fill_value (float | int): Value used to fill boundary positions after rolling.
29+ Defaults to 0. Use the loss ignore index (e.g., -100) when rolling label
30+ tensors to ensure boundary positions are excluded from loss computation.
2731
2832 Returns:
29- torch.Tensor: Rolled tensor with boundary positions zeroed .
33+ torch.Tensor: Rolled tensor with boundary positions filled with ``fill_value`` .
3034
3135 Example:
3236 For packed sequences [1,2,3] and [4,5,6] with shifts=-1, dim=-1:
@@ -39,7 +43,7 @@ def roll_packed_tensor(
3943 >>> tensor = torch.arange(12).reshape(1, 6, 2)
4044 >>> cu_seq_lens = torch.tensor([0, 3, 6], dtype=torch.int32)
4145 >>> rolled = roll_packed_tensor(tensor, cu_seq_lens, shifts=-1, dim=-2)
42- >>> rolled[0, 2] # tensor([0, 0]) (boundary zeroed )
46+ >>> rolled[0, 2] # tensor([0, 0]) (boundary filled with fill_value=0 )
4347 """
4448 assert shifts <= 0 , "Only negative shift is supported"
4549
@@ -57,13 +61,13 @@ def roll_packed_tensor(
5761 seq_slice = tensor .narrow (dim , start_idx , end_idx - start_idx ) # type: ignore[arg-type]
5862 rolled_seq = torch .roll (seq_slice , shifts = shifts , dims = dim )
5963
60- # Zero out the last |shifts| positions along dim to avoid information
64+ # Fill the last |shifts| positions along dim to avoid information
6165 # leakage across sequences. For shifts=-1 the last 1 position is
62- # zeroed ; for shifts=-2 the last 2 positions are zeroed , etc.
63- zero_len = - shifts
64- zero_start = (end_idx - start_idx ) - zero_len
65- zero_slice = rolled_seq .narrow (dim , zero_start , zero_len ) # type: ignore[arg-type]
66- zero_slice . zero_ ( )
66+ # filled ; for shifts=-2 the last 2 positions are filled , etc.
67+ fill_len = - shifts
68+ fill_start = (end_idx - start_idx ) - fill_len
69+ fill_slice = rolled_seq .narrow (dim , fill_start , fill_len ) # type: ignore[arg-type]
70+ fill_slice . fill_ ( fill_value )
6771
6872 # Write back to the rolled tensor
6973 rolled_tensor .narrow (dim , start_idx , end_idx - start_idx ).copy_ (rolled_seq ) # type: ignore[arg-type]
0 commit comments