@@ -355,7 +355,9 @@ def process(
355355 token_id = token_list [i ] if i < len (token_list ) else None
356356 extra = None
357357 if hidden_states_dict is not None :
358- if "_single" in hidden_states_dict :
358+ if "_full" in hidden_states_dict :
359+ extra = {"hidden_states" : hidden_states_dict ["_full" ]}
360+ elif "_single" in hidden_states_dict :
359361 extra = {"hidden_states" : hidden_states_dict ["_single" ][i ]}
360362 else :
361363 per_req = {}
@@ -382,10 +384,20 @@ def _extract_hidden_states(
382384 """Extract hidden states from model output or side-channel.
383385
384386 Priority:
385- 1. Side-channel (_captured_aux_hidden_states) from hidden capture hooks
386- 2. logits_output.hidden_states (legacy single-tensor path)
387+ 1. Full-sequence side-channel (_captured_full_hidden_states) for
388+ image_gen prefill-only — preserves the full [seq_len, hidden_dim]
389+ tensor so downstream can apply gen_mask.
390+ 2. Side-channel (_captured_aux_hidden_states) from hidden capture hooks
391+ 3. logits_output.hidden_states (legacy single-tensor path)
387392 """
388- # Check side-channel first (set by _hidden_capture hooks)
393+ # Full-sequence capture (set by BailingMoeV2ForCausalLM.forward)
394+ if self ._model is not None :
395+ full_hs = getattr (self ._model , "_captured_full_hidden_states" , None )
396+ if full_hs is not None :
397+ self ._model ._captured_full_hidden_states = None
398+ return {"_full" : full_hs }
399+
400+ # Side-channel from _hidden_capture hooks
389401 if self ._model is not None and self ._capture_hidden_layers :
390402 aux = getattr (self ._model , "_captured_aux_hidden_states" , None )
391403 if aux is not None :
@@ -471,14 +483,18 @@ def update_request(self, request: SchedulerRequest, output: RequestOutput) -> No
471483 )
472484
473485 if req .is_chunked > 0 :
486+ # Accumulate full-sequence hidden states across chunks so
487+ # image_gen prefill-only gets the complete [seq_len, hidden_dim].
488+ if output .extra :
489+ self ._accumulate_hidden_states (data , output .extra )
474490 output .data = None
475491 req .is_chunked -= 1
476492 return
477493
478494 # Transfer captured model outputs (e.g. hidden states) to the
479495 # request data so they're available to downstream pipeline stages.
480496 if output .extra :
481- data . extra_model_outputs . update ( output .extra )
497+ self . _accumulate_hidden_states ( data , output .extra )
482498
483499 token_id = output .data
484500 if token_id is not None :
@@ -499,6 +515,22 @@ def update_request(self, request: SchedulerRequest, output: RequestOutput) -> No
499515 req .finished (),
500516 )
501517
@staticmethod
def _accumulate_hidden_states(data, extra: dict) -> None:
    """Fold ``extra`` into ``data.extra_model_outputs``.

    When ``extra["hidden_states"]`` is a ``torch.Tensor`` it is appended
    along dim 0 to the tensor already stored under that key; if nothing is
    stored yet, or the stored value is not a tensor, the incoming tensor
    simply takes its place. All remaining keys are copied over, replacing
    any existing entries.

    When ``"hidden_states"`` is absent or is not a tensor, ``extra`` is
    merged with ordinary ``dict.update`` semantics.

    Args:
        data: Object exposing a mutable ``extra_model_outputs`` mapping.
        extra: Model outputs to merge into that mapping.
    """
    store = data.extra_model_outputs
    incoming = extra.get("hidden_states")

    # Nothing to concatenate: fall back to a plain overwrite-style merge.
    if not isinstance(incoming, torch.Tensor):
        store.update(extra)
        return

    existing = store.get("hidden_states")
    if isinstance(existing, torch.Tensor):
        # NOTE(review): assumes both tensors agree on every dim except
        # dim 0 (and on device/dtype) -- confirm for all producers.
        incoming = torch.cat([existing, incoming], dim=0)
    store["hidden_states"] = incoming

    # Copy the non-tensor-accumulated entries across unchanged.
    for key, value in extra.items():
        if key != "hidden_states":
            store[key] = value
533+
def is_finished(self, request: SchedulerRequest, output: RequestOutput) -> bool:
    """Report whether the request wrapped by ``request`` has finished.

    The answer is delegated to ``request.data.req.finished()``; ``output``
    is accepted but not consulted.
    """
    underlying = request.data.req
    return underlying.finished()
504536
0 commit comments