@@ -267,12 +267,12 @@ def _prepare_prompt(
                                        dtype=torch.long,
                                        device=self.device)

-        logger.info("[RBLN] model input builder, prepare_prompt")
-        logger.info("\t padded input_tokens = %s", input_tokens)
-        logger.info("\t padded input_positions = %s", input_positions)
-        logger.info("\t input_block_ids = %s", input_block_ids)
-        logger.info("\t seq_lens = %s", data.seq_lens)
-        logger.info("\t query_lens = %s", data.query_lens)
+        logger.debug("[RBLN] model input builder, prepare_prompt")
+        logger.debug("\t padded input_tokens = %s", input_tokens)
+        logger.debug("\t padded input_positions = %s", input_positions)
+        logger.debug("\t input_block_ids = %s", input_block_ids)
+        logger.debug("\t seq_lens = %s", data.seq_lens)
+        logger.debug("\t query_lens = %s", data.query_lens)
         return (input_tokens, input_positions, input_block_ids)

     def _prepare_decode(
@@ -340,12 +340,12 @@ def _prepare_decode(
                                        dtype=torch.long,
                                        device=self.device)

-        logger.info("[RBLN] model input builder, prepare_decode")
-        logger.info("\t padded input_tokens = %s", data.input_tokens)
-        logger.info("\t padded input_positions = %s", data.input_positions)
-        logger.info("\t input_block_ids = %s", input_block_ids)
-        logger.info("\t seq_lens = %s", data.seq_lens)
-        logger.info("\t query_lens = %s", data.query_lens)
+        logger.debug("[RBLN] model input builder, prepare_decode")
+        logger.debug("\t padded input_tokens = %s", data.input_tokens)
+        logger.debug("\t padded input_positions = %s", data.input_positions)
+        logger.debug("\t input_block_ids = %s", input_block_ids)
+        logger.debug("\t seq_lens = %s", data.seq_lens)
+        logger.debug("\t query_lens = %s", data.query_lens)

         assert input_tokens.shape[0] == self.max_num_seqs
         assert input_positions.shape[0] == self.max_num_seqs
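The input-builder dumps in these two hunks now log at debug level, so they no longer clutter default INFO output. A minimal sketch of how they could still be surfaced when troubleshooting, assuming vLLM's VLLM_LOGGING_LEVEL environment variable (the model name here is a placeholder, not part of this change):

    # Hypothetical debugging session: raise the log level before vLLM is
    # imported so its loggers are configured at DEBUG and the [RBLN]
    # prepare_prompt / prepare_decode dumps become visible again.
    import os

    os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration
    print(llm.generate("Hello, RBLN!"))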
@@ -520,10 +520,10 @@ def model_wrapper(
                 model_output = model_output[:, selected_token_indices]
                 logits = self.compute_logits_model.compute_logits(
                     model_output, None)
+                return logits
             else:
                 # non last rank create intermediate tensors, bypass it
-                logits = model_output
-            return logits
+                return model_output

         if self.model_config.enforce_eager or not envs.RBLN_COMPILE_MODEL:
             self.model_executable = model_wrapper
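For reference, a minimal standalone sketch of the control flow this hunk converges on (simplified names, not the actual RBLN wrapper): the last pipeline-parallel rank slices the selected tokens and returns logits directly, while every other rank returns its intermediate output untouched.

    # Sketch only: the real wrapper uses self.compute_logits_model and the
    # RBLN-compiled executable; here both are passed in as plain callables.
    def wrapper_sketch(model_output, is_last_rank, selected_token_indices,
                       compute_logits):
        if is_last_rank:
            # Keep only the token positions sampling needs, then map hidden
            # states to vocabulary logits and return them right away.
            model_output = model_output[:, selected_token_indices]
            return compute_logits(model_output, None)
        # Non-last ranks bypass logits computation and forward their
        # intermediate tensors to the next pipeline stage as-is.
        return model_output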
@@ -583,9 +583,9 @@ def prepare_model_input(

         is_prompt = seq_group_metadata_list[
             0].is_prompt if seq_group_metadata_list else None
-        logger.info("[RBLN] num_requests = %d", len(seq_group_metadata_list))
-        logger.info("[RBLN] input_ids = %s", model_input.input_tokens)
-        logger.info("[RBLN] positions = %s", model_input.input_positions)
+        logger.debug("[RBLN] num_requests = %d", len(seq_group_metadata_list))
+        logger.debug("[RBLN] input_ids = %s", model_input.input_tokens)
+        logger.debug("[RBLN] positions = %s", model_input.input_positions)
         return dataclasses.replace(model_input,
                                    sampling_metadata=sampling_metadata,
                                    virtual_engine=virtual_engine,
@@ -594,12 +594,12 @@ def prepare_model_input(
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: ModelInputForRebel,
+        model_input: ModelInputForRebelWithSamplingMetadata,
         kv_caches: Optional[List[torch.Tensor]] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
         previous_hidden_states: Optional[torch.Tensor] = None,
-    ) -> Optional[SamplerOutput]:
+    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         assert kv_caches is not None
         if num_steps > 1:
             raise ValueError(
@@ -613,6 +613,7 @@ def execute_model(
         assert model_input.attn_metadata is not None
         token_indices = None
         if get_pp_group().is_last_rank:
+            assert model_input.sampling_metadata is not None
             num_prefills = model_input.attn_metadata.num_prefills
             selected_token_indices = \
                 model_input.sampling_metadata.selected_token_indices
@@ -633,30 +634,29 @@ def execute_model(
         if model_input.attn_metadata is not None:
             model_input.attn_metadata.kv_caches = kv_caches

-        hidden_states = self.model_executable(
+        logits_or_intermediate_states = self.model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             intermediate_tensors=intermediate_tensors,
             selected_token_indices=token_indices,
             **execute_model_kwargs,
         )

-        if get_pp_group().is_last_rank:
-            # Gather logits for TP
-            logits_processor = self.compute_logits_model.logits_processor
-            hidden_states = logits_processor._gather_logits(hidden_states)
-            hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+        if get_pp_group().is_last_rank:
+            # Gather logits for TP
+            logits_processor = self.compute_logits_model.logits_processor
+            logits = logits_processor._gather_logits(
+                logits_or_intermediate_states)
+            logits = logits.view(-1, logits.size(-1))

-        if not get_pp_group().is_last_rank:
-            intermediate_states = hidden_states
+        else:
+            intermediate_states = logits_or_intermediate_states
             assert isinstance(intermediate_states, IntermediateTensors)
             return intermediate_states

         # Compute the logits. -> moved to model executable
-        if num_prefills > 0 and len_token_indices != 0:
-            logits = hidden_states
-        else:
-            logits = hidden_states[selected_token_indices]
+        if not (num_prefills > 0 and len_token_indices != 0):
+            logits = logits[selected_token_indices]

         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
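Taken together, the execute_model changes make the pipeline-parallel split explicit: non-last ranks return their IntermediateTensors untouched, and only the last rank gathers, reshapes, and (when needed) slices the logits. A condensed sketch of that branch, with illustrative helper names standing in for logits_processor._gather_logits and the prefill bookkeeping:

    # Sketch only: `gather_logits` stands in for the tensor-parallel gather;
    # `skip_slice` stands in for the prefill case where the executable has
    # already applied selected_token_indices.
    def route_executable_output(output, is_last_rank, gather_logits,
                                selected_token_indices, skip_slice):
        if not is_last_rank:
            # IntermediateTensors pass straight through to the next stage.
            return output
        logits = gather_logits(output)             # gather TP shards
        logits = logits.view(-1, logits.size(-1))  # flatten to (tokens, vocab)
        if not skip_slice:
            # Decode path: pick out the positions sampling actually uses.
            logits = logits[selected_token_indices]
        return logits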