Skip to content

Commit 385cc26

Browse files
committed
Update
[ghstack-poisoned]
1 parent c148aac commit 385cc26

2 files changed

Lines changed: 13 additions & 9 deletions

File tree

xtuner/v1/model/moe/moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def _forward(
         for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)):
             shifted_tensor = mtp_ctx.loss_kwargs.shifted_labels
             mtp_ctx.loss_kwargs.shifted_labels = roll_packed_tensor(
-                shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1
+                shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1, fill_value=-100
             )

             mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden

xtuner/v1/module/mtp/utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def roll_packed_tensor(
     cu_seq_lens: torch.IntTensor,
     shifts: int = -1,
     dim: int = -1,
+    fill_value: float | int = 0,
 ) -> torch.Tensor:
     """Roll a packed tensor along the specified dimension.

@@ -24,9 +25,12 @@ def roll_packed_tensor(
             Only negative shifts are supported.
         dim (int): Dimension along which to roll. The ``cu_seq_lens`` boundaries
             are applied on this dimension. Default is -1 (last dimension).
+        fill_value (float | int): Value used to fill boundary positions after rolling.
+            Defaults to 0. Use the loss ignore index (e.g., -100) when rolling label
+            tensors to ensure boundary positions are excluded from loss computation.

     Returns:
-        torch.Tensor: Rolled tensor with boundary positions zeroed.
+        torch.Tensor: Rolled tensor with boundary positions filled with ``fill_value``.

     Example:
         For packed sequences [1,2,3] and [4,5,6] with shifts=-1, dim=-1:
@@ -39,7 +43,7 @@ def roll_packed_tensor(
         >>> tensor = torch.arange(12).reshape(1, 6, 2)
         >>> cu_seq_lens = torch.tensor([0, 3, 6], dtype=torch.int32)
         >>> rolled = roll_packed_tensor(tensor, cu_seq_lens, shifts=-1, dim=-2)
-        >>> rolled[0, 2]  # tensor([0, 0]) (boundary zeroed)
+        >>> rolled[0, 2]  # tensor([0, 0]) (boundary filled with fill_value=0)
     """
     assert shifts <= 0, "Only negative shift is supported"


@@ -57,13 +61,13 @@ def roll_packed_tensor(
         seq_slice = tensor.narrow(dim, start_idx, end_idx - start_idx)  # type: ignore[arg-type]
         rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dim)

-        # Zero out the last |shifts| positions along dim to avoid information
+        # Fill the last |shifts| positions along dim to avoid information
         # leakage across sequences. For shifts=-1 the last 1 position is
-        # zeroed; for shifts=-2 the last 2 positions are zeroed, etc.
-        zero_len = -shifts
-        zero_start = (end_idx - start_idx) - zero_len
-        zero_slice = rolled_seq.narrow(dim, zero_start, zero_len)  # type: ignore[arg-type]
-        zero_slice.zero_()
+        # filled; for shifts=-2 the last 2 positions are filled, etc.
+        fill_len = -shifts
+        fill_start = (end_idx - start_idx) - fill_len
+        fill_slice = rolled_seq.narrow(dim, fill_start, fill_len)  # type: ignore[arg-type]
+        fill_slice.fill_(fill_value)

         # Write back to the rolled tensor
         rolled_tensor.narrow(dim, start_idx, end_idx - start_idx).copy_(rolled_seq)  # type: ignore[arg-type]

0 commit comments

Comments (0)