Skip to content

Commit c40099a

Browse files
authored
Merge pull request #9 from modal-labs/timmy/mrope-positions-on-gpu
remove mrope position sync
2 parents 1dd6ffa + 6264614 commit c40099a

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

python/sglang/srt/model_executor/forward_batch_info.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -513,24 +513,23 @@ def _compute_mrope_positions(
513513
for batch_idx in range(batch_size):
514514
mm_input = batch.multimodal_inputs[batch_idx]
515515
if self.forward_mode.is_decode():
516-
mrope_position_deltas = (
517-
[0]
518-
if mm_input is None
519-
else flatten_nested_list(mm_input.mrope_position_delta.tolist())
520-
)
521-
next_input_positions = []
522-
for mrope_position_delta in mrope_position_deltas:
523-
# batched deltas needs to be processed separately
524-
# Convert list of lists to tensor with shape [3, seq_len]
525-
next_input_positions += [
526-
MRotaryEmbedding.get_next_input_positions(
527-
mrope_position_delta,
528-
int(self.seq_lens[batch_idx]) - 1,
529-
int(self.seq_lens[batch_idx]),
530-
)
531-
]
532516
# 3 * N
533-
mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
517+
if mm_input is None:
518+
mrope_positions_list[batch_idx] = torch.full(
519+
(3, 1),
520+
self.seq_lens[batch_idx] - 1,
521+
dtype=torch.int64,
522+
device=model_runner.device,
523+
)
524+
else:
525+
mrope_position_deltas = (
526+
mm_input.mrope_position_delta
527+
.flatten()
528+
.to(model_runner.device, non_blocking=True)
529+
)
530+
mrope_positions_list[batch_idx] = (
531+
mrope_position_deltas + self.seq_lens[batch_idx] - 1
532+
).unsqueeze(0).repeat(3, 1)
534533
elif self.forward_mode.is_extend():
535534
extend_seq_len, extend_prefix_len = (
536535
batch.extend_seq_lens[batch_idx],

0 commit comments

Comments (0)