@@ -513,24 +513,23 @@ def _compute_mrope_positions(
513513 for batch_idx in range (batch_size ):
514514 mm_input = batch .multimodal_inputs [batch_idx ]
515515 if self .forward_mode .is_decode ():
516- mrope_position_deltas = (
517- [0 ]
518- if mm_input is None
519- else flatten_nested_list (mm_input .mrope_position_delta .tolist ())
520- )
521- next_input_positions = []
522- for mrope_position_delta in mrope_position_deltas :
523- # batched deltas needs to be processed separately
524- # Convert list of lists to tensor with shape [3, seq_len]
525- next_input_positions += [
526- MRotaryEmbedding .get_next_input_positions (
527- mrope_position_delta ,
528- int (self .seq_lens [batch_idx ]) - 1 ,
529- int (self .seq_lens [batch_idx ]),
530- )
531- ]
532516 # 3 * N
533- mrope_positions_list [batch_idx ] = torch .cat (next_input_positions , dim = 1 )
517+ if mm_input is None :
518+ mrope_positions_list [batch_idx ] = torch .full (
519+ (3 , 1 ),
520+ self .seq_lens [batch_idx ] - 1 ,
521+ dtype = torch .int64 ,
522+ device = model_runner .device ,
523+ )
524+ else :
525+ mrope_position_deltas = (
526+ mm_input .mrope_position_delta
527+ .flatten ()
528+ .to (model_runner .device , non_blocking = True )
529+ )
530+ mrope_positions_list [batch_idx ] = (
531+ mrope_position_deltas + self .seq_lens [batch_idx ] - 1
532+ ).unsqueeze (0 ).repeat (3 , 1 )
534533 elif self .forward_mode .is_extend ():
535534 extend_seq_len , extend_prefix_len = (
536535 batch .extend_seq_lens [batch_idx ],
0 commit comments