22"""Image generation executor for Ming-Omni.
33
44Wraps a DiffusionBackend (SD3 or Z-Image) as a pipeline Executor stage.
5- The executor receives decoded text from the thinker and generates images
6- using its own self-contained diffusion pipeline (text encoder + DiT + VAE).
5+ Supports two conditioning modes:
76
8- This is the Phase 1 implementation using text conditioning.
9- Phase 2 will replace text input with LLM hidden-state conditioning.
7+ 1. **Hidden-state conditioning** (Phase 2) -- when the thinker provides
8+ hidden_states + gen_mask, a :class:`SemanticConditioner` projects them
9+ into condition embeddings for the diffusion model.
10+ 2. **Text-only conditioning** (Phase 1 fallback) -- the executor decodes
11+ thinker output_ids to text and uses the backend's built-in text encoder.
1012"""
1113
1214from __future__ import annotations
@@ -49,11 +51,15 @@ def __init__(
4951 dit_type : str = "zimage" ,
5052 dit_model_path : str | None = None ,
5153 device : str = "cuda" ,
54+ conditioner = None , # SemanticConditioner instance (or None for text-only)
55+ skip_semantic_encoder : bool = False ,
5256 ):
5357 self ._model_path = model_path
5458 self ._dit_type = dit_type
5559 self ._dit_model_path = dit_model_path or model_path
5660 self ._device = device
61+ self ._conditioner = conditioner
62+ self ._skip_semantic_encoder = skip_semantic_encoder
5763
5864 self ._backend : DiffusionBackend | None = None
5965 self ._thinker_tokenizer = None
@@ -75,7 +81,19 @@ def _load_models(self) -> None:
7581 """Load diffusion backend + thinker tokenizer (runs in thread pool)."""
7682 t0 = time .time ()
7783 self ._backend = _create_backend (self ._dit_type )
78- self ._backend .load_models (self ._dit_model_path , torch .device (self ._device ))
84+
85+ # skip_semantic_encoder is only supported by ZImageBackend. Other
86+ # backends (e.g. SD3) use a simpler load_models(path, device) API.
87+ if self ._dit_type == "zimage" and self ._skip_semantic_encoder :
88+ self ._backend .load_models (
89+ self ._dit_model_path ,
90+ torch .device (self ._device ),
91+ skip_semantic_encoder = True ,
92+ )
93+ else :
94+ self ._backend .load_models (
95+ self ._dit_model_path , torch .device (self ._device )
96+ )
7997 logger .info ("[IMG_GEN] Diffusion backend loaded in %.1fs" , time .time () - t0 )
8098
8199 # Load thinker tokenizer for decoding output_ids → text prompt
@@ -91,52 +109,126 @@ def _load_models(self) -> None:
91109 logger .warning ("[IMG_GEN] Could not load thinker tokenizer: %s" , e )
92110
93111 async def add_request (self , payload : StagePayload ) -> None :
94- """Process an image generation request."""
112+ """Process an image generation request.
113+
114+ Two conditioning paths are attempted in order:
115+
116+ 1. **Hidden-state conditioning** -- if a :class:`SemanticConditioner`
117+ is configured and the payload contains ``hidden_states`` +
118+ ``gen_mask`` from the thinker, project them into condition
119+ embeddings and pass directly to the diffusion backend.
120+ 2. **Text-only conditioning** (fallback) -- decode thinker
121+ ``output_ids`` to text and let the backend's built-in text
122+ encoder produce the condition embeddings.
123+ """
95124 request_id = payload .request_id
96125 if request_id in self ._aborted :
97126 return
98127
99- text , params = self ._extract_input (payload )
100- logger .info (
101- "[IMG_GEN] prompt (len=%d): %r, size=%dx%d, steps=%d" ,
102- len (text ) if text else 0 ,
103- text [:200 ] if text else "" ,
104- params .width ,
105- params .height ,
106- params .num_inference_steps ,
107- )
128+ data = payload .data
129+ if not isinstance (data , dict ):
130+ data = {}
108131
109- if not text :
110- result = StagePayload (
111- request_id = request_id ,
112- request = payload .request ,
113- data = {"image_data" : None , "modality" : "image" },
132+ # ------------------------------------------------------------------
133+ # Try hidden-state conditioning (Phase 2)
134+ # ------------------------------------------------------------------
135+ condition_embeds = None
136+ negative_embeds = None
137+
138+ if self ._conditioner is not None :
139+ condition_embeds , negative_embeds = self ._try_condition_from_hidden_states (
140+ data
114141 )
115- await self ._results .put (result )
116- return
117142
118- t0 = time .time ()
119- logger .info ("[IMG_GEN] Starting image generation..." )
120- try :
121- image = await asyncio .to_thread (self ._generate_image , text , params )
122- elapsed = time .time () - t0
143+ if condition_embeds is not None :
144+ # Hidden-state conditioning path
145+ params = self ._extract_params (data , payload .request )
146+ prompt_text = self ._extract_text_for_byt5 (data )
123147 logger .info (
124- "[IMG_GEN] Image generated in %.1fs (%dx%d)" ,
125- elapsed ,
126- image .width ,
127- image .height ,
128- )
129- except Exception as e :
130- logger .error (
131- "[IMG_GEN] ERROR after %.1fs: %s" , time .time () - t0 , e , exc_info = True
148+ "[IMG_GEN] Using hidden-state conditioning, size=%dx%d, steps=%d" ,
149+ params .width ,
150+ params .height ,
151+ params .num_inference_steps ,
132152 )
133- result = StagePayload (
134- request_id = request_id ,
135- request = payload .request ,
136- data = {"image_data" : None , "modality" : "image" , "error" : str (e )},
153+
154+ t0 = time .time ()
155+ try :
156+ image = await asyncio .to_thread (
157+ self ._generate_with_condition_embeds ,
158+ prompt_text ,
159+ params ,
160+ condition_embeds ,
161+ negative_embeds ,
162+ )
163+ elapsed = time .time () - t0
164+ logger .info (
165+ "[IMG_GEN] Image generated in %.1fs (%dx%d)" ,
166+ elapsed ,
167+ image .width ,
168+ image .height ,
169+ )
170+ except Exception as e :
171+ logger .error (
172+ "[IMG_GEN] ERROR after %.1fs: %s" ,
173+ time .time () - t0 ,
174+ e ,
175+ exc_info = True ,
176+ )
177+ result = StagePayload (
178+ request_id = request_id ,
179+ request = payload .request ,
180+ data = {"image_data" : None , "modality" : "image" , "error" : str (e )},
181+ )
182+ await self ._results .put (result )
183+ return
184+ else :
185+ # ------------------------------------------------------------------
186+ # Fallback: text-only conditioning (Phase 1)
187+ # ------------------------------------------------------------------
188+ text , params = self ._extract_input (payload )
189+ logger .info (
190+ "[IMG_GEN] prompt (len=%d): %r, size=%dx%d, steps=%d" ,
191+ len (text ) if text else 0 ,
192+ text [:200 ] if text else "" ,
193+ params .width ,
194+ params .height ,
195+ params .num_inference_steps ,
137196 )
138- await self ._results .put (result )
139- return
197+
198+ if not text :
199+ result = StagePayload (
200+ request_id = request_id ,
201+ request = payload .request ,
202+ data = {"image_data" : None , "modality" : "image" },
203+ )
204+ await self ._results .put (result )
205+ return
206+
207+ t0 = time .time ()
208+ logger .info ("[IMG_GEN] Starting image generation..." )
209+ try :
210+ image = await asyncio .to_thread (self ._generate_image , text , params )
211+ elapsed = time .time () - t0
212+ logger .info (
213+ "[IMG_GEN] Image generated in %.1fs (%dx%d)" ,
214+ elapsed ,
215+ image .width ,
216+ image .height ,
217+ )
218+ except Exception as e :
219+ logger .error (
220+ "[IMG_GEN] ERROR after %.1fs: %s" ,
221+ time .time () - t0 ,
222+ e ,
223+ exc_info = True ,
224+ )
225+ result = StagePayload (
226+ request_id = request_id ,
227+ request = payload .request ,
228+ data = {"image_data" : None , "modality" : "image" , "error" : str (e )},
229+ )
230+ await self ._results .put (result )
231+ return
140232
141233 # Serialize image to PNG bytes for cross-process msgpack transport
142234 buf = io .BytesIO ()
@@ -225,6 +317,159 @@ def _extract_input(self, payload: StagePayload) -> tuple[str, ImageGenParams]:
225317 )
226318 return text , params
227319
320+ # ------------------------------------------------------------------
321+ # Hidden-state conditioning helpers
322+ # ------------------------------------------------------------------
323+
324+ def _try_condition_from_hidden_states (
325+ self , data : dict
326+ ) -> tuple [list [torch .Tensor ] | None , list [torch .Tensor ] | None ]:
327+ """Try to build condition embeddings from thinker hidden states.
328+
329+ Returns ``(condition_embeds, negative_embeds)`` as lists of tensors
330+ (one per batch element), or ``(None, None)`` if the required fields
331+ are missing from *data*.
332+ """
333+ thinker_out = data .get ("thinker_out" , {})
334+ if not isinstance (thinker_out , dict ):
335+ return None , None
336+
337+ extra = thinker_out .get ("extra_model_outputs" , {})
338+ if not isinstance (extra , dict ):
339+ return None , None
340+
341+ # Hidden states captured by SGLang during the thinker forward pass
342+ hidden_states = extra .get ("hidden_states" )
343+ if hidden_states is None :
344+ return None , None
345+
346+ # gen_mask from mm_inputs (set by the preprocessor)
347+ mm_inputs = data .get ("mm_inputs" , {})
348+ image_gen = mm_inputs .get ("image_gen" , {})
349+ gen_mask_list = image_gen .get ("gen_mask" )
350+ if gen_mask_list is None :
351+ return None , None
352+
353+ # Resolve hidden states to a single tensor
354+ if isinstance (hidden_states , dict ):
355+ # Side-channel capture: pick the last (highest) layer
356+ numeric_keys = [
357+ k for k in hidden_states
358+ if isinstance (k , int ) or (isinstance (k , str ) and k .isdigit ())
359+ ]
360+ if not numeric_keys :
361+ return None , None
362+ last_key = max (numeric_keys , key = lambda k : int (k ))
363+ hs = hidden_states [last_key ]
364+ elif isinstance (hidden_states , torch .Tensor ):
365+ hs = hidden_states
366+ else :
367+ return None , None
368+
369+ # Extract query token positions using gen_mask
370+ gen_mask = torch .tensor (gen_mask_list , dtype = torch .bool , device = hs .device )
371+ if hs .dim () == 2 :
372+ # [seq_len, hidden_dim] -> [1, num_query, hidden_dim]
373+ query_hidden = hs [gen_mask ].unsqueeze (0 )
374+ elif hs .dim () == 3 :
375+ # [batch, seq_len, hidden_dim]
376+ query_hidden = hs [:, gen_mask , :]
377+ else :
378+ logger .warning (
379+ "[IMG_GEN] Unexpected hidden_states dim=%d, skipping" , hs .dim ()
380+ )
381+ return None , None
382+
383+ logger .info (
384+ "[IMG_GEN] Projecting hidden states %s through conditioner" ,
385+ list (query_hidden .shape ),
386+ )
387+
388+ # Project through the conditioner: [B, N, 4096] -> [B, N, 2560]
389+ condition_embeds = self ._conditioner .project (query_hidden )
390+ negative_embeds = condition_embeds * 0.0
391+
392+ # Convert to list format expected by DiffusionBackend.generate()
393+ pos_list = list (condition_embeds .unbind (dim = 0 ))
394+ neg_list = list (negative_embeds .unbind (dim = 0 ))
395+
396+ return pos_list , neg_list
397+
def _extract_params(self, data: dict, request) -> ImageGenParams:
    """Extract image generation parameters without text.

    Used by the hidden-state conditioning path, where text is not the
    primary conditioning signal.

    Parameter sources are tried in order, and the first non-empty one wins:

    1. ``data["raw_inputs"]["image_generation"]``
    2. ``data["mm_inputs"]["image_gen"]["image_gen_params"]``
    3. ``request.metadata["image_generation"]``

    Args:
        data: The stage payload dictionary.
        request: The originating request object, or ``None``.

    Returns:
        An ``ImageGenParams`` populated from the chosen source. Defaults
        are 1024x1024, 28 steps, guidance scale 7.0, no seed, and an empty
        negative prompt.
    """

    def _as_dict(value) -> dict:
        # A key may be present but hold None or a non-dict value; treat
        # those the same as a missing key instead of raising AttributeError.
        return value if isinstance(value, dict) else {}

    data = _as_dict(data)

    raw_inputs = _as_dict(data.get("raw_inputs"))
    img_params_dict = _as_dict(raw_inputs.get("image_generation"))

    # Also check mm_inputs (set by the preprocessor).
    if not img_params_dict:
        mm_inputs = _as_dict(data.get("mm_inputs"))
        image_gen = _as_dict(mm_inputs.get("image_gen"))
        img_params_dict = _as_dict(image_gen.get("image_gen_params"))

    if not img_params_dict and request is not None:
        metadata = getattr(request, "metadata", {}) or {}
        img_params_dict = _as_dict(metadata.get("image_generation"))

    width = img_params_dict.get("width", 1024)
    height = img_params_dict.get("height", 1024)

    # A size string such as "1024x1024" overrides width/height. The
    # separator is matched case-insensitively, so "1024X1024" also works.
    size = img_params_dict.get("size")
    if isinstance(size, str) and "x" in size.lower():
        parts = size.lower().split("x")
        try:
            width, height = int(parts[0]), int(parts[1])
        except ValueError:
            # Malformed size string: keep the width/height chosen above.
            pass

    return ImageGenParams(
        width=width,
        height=height,
        num_inference_steps=img_params_dict.get("num_inference_steps", 28),
        guidance_scale=img_params_dict.get("guidance_scale", 7.0),
        seed=img_params_dict.get("seed"),
        negative_prompt=img_params_dict.get("negative_prompt", ""),
    )
438+
439+ def _extract_text_for_byt5 (self , data : dict ) -> str :
440+ """Extract original prompt text for ByT5 text rendering.
441+
442+ When using hidden-state conditioning, the prompt text is still
443+ passed to generate() so that ZImageBackend can extract quoted
444+ text for ByT5 encoding if available.
445+ """
446+ prompt = data .get ("prompt" , {})
447+ if isinstance (prompt , dict ):
448+ return prompt .get ("prompt_text" , "" )
449+ return ""
450+
@torch.no_grad()
def _generate_with_condition_embeds(
    self,
    prompt_text: str,
    params: ImageGenParams,
    condition_embeds: list[torch.Tensor],
    negative_embeds: list[torch.Tensor],
):
    """Generate an image from pre-computed condition embeddings.

    Forwards the prompt text (an empty string when falsy), the generation
    parameters, and the positive/negative embeddings to the loaded
    backend's ``generate()`` and returns whatever it returns.

    Raises:
        RuntimeError: If the diffusion backend has not been loaded.
    """
    backend = self._backend
    if backend is None:
        raise RuntimeError("Diffusion backend not loaded")

    # Normalise a missing/empty prompt to "" before handing it over.
    text = prompt_text if prompt_text else ""
    embed_kwargs = {
        "condition_embeds": condition_embeds,
        "negative_condition_embeds": negative_embeds,
    }
    return backend.generate(text, params, **embed_kwargs)
468+
469+ # ------------------------------------------------------------------
470+ # Text-only generation
471+ # ------------------------------------------------------------------
472+
228473 @torch .no_grad ()
229474 def _generate_image (self , text : str , params : ImageGenParams ):
230475 """Run the diffusion pipeline (called in thread pool)."""
0 commit comments