Skip to content

Commit 1599573

Browse files
committed
revert
1 parent 190d7b7 commit 1599573

1 file changed

Lines changed: 3 additions & 7 deletions

File tree

swift/megatron/trainers/grpo_trainer.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,10 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
367367
padding_to = cur_seq_len if template.padding_free else max_seq_len
368368
padding_len = padding_to - experts_seq_len
369369
if padding_len > 0:
370-
last_entry = routed_experts[-1:].expand(padding_len, -1, -1)
371-
padded_tail = torch.cat([routed_experts, last_entry], dim=0)
372370
padding_right = template.padding_side == 'right'
373-
if padding_right:
374-
padding_routed_experts = padded_tail
375-
else:
376-
left_pad = torch.zeros(padding_len, *routed_experts.shape[1:], dtype=routed_experts.dtype)
377-
padding_routed_experts = torch.cat([left_pad, padded_tail], dim=0)
371+
padding_routed_experts = nn.functional.pad(routed_experts,
372+
(0, 0, 0, 0, 0, padding_len) if padding_right else
373+
(0, 0, 0, 0, padding_len, 0), 'constant', 0)
378374
routed_experts_list.append(padding_routed_experts)
379375
if template.padding_free:
380376
global_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)

0 commit comments

Comments
 (0)