Skip to content

Commit 72cfd51

Browse files
author
luoyuan.luo
committed
Refactor semantic conditioner to reuse thinker in large scale
1 parent a169243 commit 72cfd51

7 files changed

Lines changed: 802 additions & 60 deletions

File tree

sglang_omni/engines/omni/runtime/sglang_ar.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,11 @@ def update_request(self, request: SchedulerRequest, output: RequestOutput) -> No
475475
req.is_chunked -= 1
476476
return
477477

478+
# Transfer captured model outputs (e.g. hidden states) to the
479+
# request data so they're available to downstream pipeline stages.
480+
if output.extra:
481+
data.extra_model_outputs.update(output.extra)
482+
478483
token_id = output.data
479484
if token_id is not None:
480485
req.output_ids.append(token_id)

sglang_omni/models/ming_omni/components/image_gen_executor.py

Lines changed: 286 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
"""Image generation executor for Ming-Omni.
33
44
Wraps a DiffusionBackend (SD3 or Z-Image) as a pipeline Executor stage.
5-
The executor receives decoded text from the thinker and generates images
6-
using its own self-contained diffusion pipeline (text encoder + DiT + VAE).
5+
Supports two conditioning modes:
76
8-
This is the Phase 1 implementation using text conditioning.
9-
Phase 2 will replace text input with LLM hidden-state conditioning.
7+
1. **Hidden-state conditioning** (Phase 2) -- when the thinker provides
8+
hidden_states + gen_mask, a :class:`SemanticConditioner` projects them
9+
into condition embeddings for the diffusion model.
10+
2. **Text-only conditioning** (Phase 1 fallback) -- the executor decodes
11+
thinker output_ids to text and uses the backend's built-in text encoder.
1012
"""
1113

1214
from __future__ import annotations
@@ -49,11 +51,15 @@ def __init__(
4951
dit_type: str = "zimage",
5052
dit_model_path: str | None = None,
5153
device: str = "cuda",
54+
conditioner=None, # SemanticConditioner instance (or None for text-only)
55+
skip_semantic_encoder: bool = False,
5256
):
5357
self._model_path = model_path
5458
self._dit_type = dit_type
5559
self._dit_model_path = dit_model_path or model_path
5660
self._device = device
61+
self._conditioner = conditioner
62+
self._skip_semantic_encoder = skip_semantic_encoder
5763

5864
self._backend: DiffusionBackend | None = None
5965
self._thinker_tokenizer = None
@@ -75,7 +81,19 @@ def _load_models(self) -> None:
7581
"""Load diffusion backend + thinker tokenizer (runs in thread pool)."""
7682
t0 = time.time()
7783
self._backend = _create_backend(self._dit_type)
78-
self._backend.load_models(self._dit_model_path, torch.device(self._device))
84+
85+
# skip_semantic_encoder is only supported by ZImageBackend. Other
86+
# backends (e.g. SD3) use a simpler load_models(path, device) API.
87+
if self._dit_type == "zimage" and self._skip_semantic_encoder:
88+
self._backend.load_models(
89+
self._dit_model_path,
90+
torch.device(self._device),
91+
skip_semantic_encoder=True,
92+
)
93+
else:
94+
self._backend.load_models(
95+
self._dit_model_path, torch.device(self._device)
96+
)
7997
logger.info("[IMG_GEN] Diffusion backend loaded in %.1fs", time.time() - t0)
8098

8199
# Load thinker tokenizer for decoding output_ids → text prompt
@@ -91,52 +109,126 @@ def _load_models(self) -> None:
91109
logger.warning("[IMG_GEN] Could not load thinker tokenizer: %s", e)
92110

93111
async def add_request(self, payload: StagePayload) -> None:
94-
"""Process an image generation request."""
112+
"""Process an image generation request.
113+
114+
Two conditioning paths are attempted in order:
115+
116+
1. **Hidden-state conditioning** -- if a :class:`SemanticConditioner`
117+
is configured and the payload contains ``hidden_states`` +
118+
``gen_mask`` from the thinker, project them into condition
119+
embeddings and pass directly to the diffusion backend.
120+
2. **Text-only conditioning** (fallback) -- decode thinker
121+
``output_ids`` to text and let the backend's built-in text
122+
encoder produce the condition embeddings.
123+
"""
95124
request_id = payload.request_id
96125
if request_id in self._aborted:
97126
return
98127

99-
text, params = self._extract_input(payload)
100-
logger.info(
101-
"[IMG_GEN] prompt (len=%d): %r, size=%dx%d, steps=%d",
102-
len(text) if text else 0,
103-
text[:200] if text else "",
104-
params.width,
105-
params.height,
106-
params.num_inference_steps,
107-
)
128+
data = payload.data
129+
if not isinstance(data, dict):
130+
data = {}
108131

109-
if not text:
110-
result = StagePayload(
111-
request_id=request_id,
112-
request=payload.request,
113-
data={"image_data": None, "modality": "image"},
132+
# ------------------------------------------------------------------
133+
# Try hidden-state conditioning (Phase 2)
134+
# ------------------------------------------------------------------
135+
condition_embeds = None
136+
negative_embeds = None
137+
138+
if self._conditioner is not None:
139+
condition_embeds, negative_embeds = self._try_condition_from_hidden_states(
140+
data
114141
)
115-
await self._results.put(result)
116-
return
117142

118-
t0 = time.time()
119-
logger.info("[IMG_GEN] Starting image generation...")
120-
try:
121-
image = await asyncio.to_thread(self._generate_image, text, params)
122-
elapsed = time.time() - t0
143+
if condition_embeds is not None:
144+
# Hidden-state conditioning path
145+
params = self._extract_params(data, payload.request)
146+
prompt_text = self._extract_text_for_byt5(data)
123147
logger.info(
124-
"[IMG_GEN] Image generated in %.1fs (%dx%d)",
125-
elapsed,
126-
image.width,
127-
image.height,
128-
)
129-
except Exception as e:
130-
logger.error(
131-
"[IMG_GEN] ERROR after %.1fs: %s", time.time() - t0, e, exc_info=True
148+
"[IMG_GEN] Using hidden-state conditioning, size=%dx%d, steps=%d",
149+
params.width,
150+
params.height,
151+
params.num_inference_steps,
132152
)
133-
result = StagePayload(
134-
request_id=request_id,
135-
request=payload.request,
136-
data={"image_data": None, "modality": "image", "error": str(e)},
153+
154+
t0 = time.time()
155+
try:
156+
image = await asyncio.to_thread(
157+
self._generate_with_condition_embeds,
158+
prompt_text,
159+
params,
160+
condition_embeds,
161+
negative_embeds,
162+
)
163+
elapsed = time.time() - t0
164+
logger.info(
165+
"[IMG_GEN] Image generated in %.1fs (%dx%d)",
166+
elapsed,
167+
image.width,
168+
image.height,
169+
)
170+
except Exception as e:
171+
logger.error(
172+
"[IMG_GEN] ERROR after %.1fs: %s",
173+
time.time() - t0,
174+
e,
175+
exc_info=True,
176+
)
177+
result = StagePayload(
178+
request_id=request_id,
179+
request=payload.request,
180+
data={"image_data": None, "modality": "image", "error": str(e)},
181+
)
182+
await self._results.put(result)
183+
return
184+
else:
185+
# ------------------------------------------------------------------
186+
# Fallback: text-only conditioning (Phase 1)
187+
# ------------------------------------------------------------------
188+
text, params = self._extract_input(payload)
189+
logger.info(
190+
"[IMG_GEN] prompt (len=%d): %r, size=%dx%d, steps=%d",
191+
len(text) if text else 0,
192+
text[:200] if text else "",
193+
params.width,
194+
params.height,
195+
params.num_inference_steps,
137196
)
138-
await self._results.put(result)
139-
return
197+
198+
if not text:
199+
result = StagePayload(
200+
request_id=request_id,
201+
request=payload.request,
202+
data={"image_data": None, "modality": "image"},
203+
)
204+
await self._results.put(result)
205+
return
206+
207+
t0 = time.time()
208+
logger.info("[IMG_GEN] Starting image generation...")
209+
try:
210+
image = await asyncio.to_thread(self._generate_image, text, params)
211+
elapsed = time.time() - t0
212+
logger.info(
213+
"[IMG_GEN] Image generated in %.1fs (%dx%d)",
214+
elapsed,
215+
image.width,
216+
image.height,
217+
)
218+
except Exception as e:
219+
logger.error(
220+
"[IMG_GEN] ERROR after %.1fs: %s",
221+
time.time() - t0,
222+
e,
223+
exc_info=True,
224+
)
225+
result = StagePayload(
226+
request_id=request_id,
227+
request=payload.request,
228+
data={"image_data": None, "modality": "image", "error": str(e)},
229+
)
230+
await self._results.put(result)
231+
return
140232

141233
# Serialize image to PNG bytes for cross-process msgpack transport
142234
buf = io.BytesIO()
@@ -225,6 +317,159 @@ def _extract_input(self, payload: StagePayload) -> tuple[str, ImageGenParams]:
225317
)
226318
return text, params
227319

320+
# ------------------------------------------------------------------
321+
# Hidden-state conditioning helpers
322+
# ------------------------------------------------------------------
323+
324+
def _try_condition_from_hidden_states(
325+
self, data: dict
326+
) -> tuple[list[torch.Tensor] | None, list[torch.Tensor] | None]:
327+
"""Try to build condition embeddings from thinker hidden states.
328+
329+
Returns ``(condition_embeds, negative_embeds)`` as lists of tensors
330+
(one per batch element), or ``(None, None)`` if the required fields
331+
are missing from *data*.
332+
"""
333+
thinker_out = data.get("thinker_out", {})
334+
if not isinstance(thinker_out, dict):
335+
return None, None
336+
337+
extra = thinker_out.get("extra_model_outputs", {})
338+
if not isinstance(extra, dict):
339+
return None, None
340+
341+
# Hidden states captured by SGLang during the thinker forward pass
342+
hidden_states = extra.get("hidden_states")
343+
if hidden_states is None:
344+
return None, None
345+
346+
# gen_mask from mm_inputs (set by the preprocessor)
347+
mm_inputs = data.get("mm_inputs", {})
348+
image_gen = mm_inputs.get("image_gen", {})
349+
gen_mask_list = image_gen.get("gen_mask")
350+
if gen_mask_list is None:
351+
return None, None
352+
353+
# Resolve hidden states to a single tensor
354+
if isinstance(hidden_states, dict):
355+
# Side-channel capture: pick the last (highest) layer
356+
numeric_keys = [
357+
k for k in hidden_states
358+
if isinstance(k, int) or (isinstance(k, str) and k.isdigit())
359+
]
360+
if not numeric_keys:
361+
return None, None
362+
last_key = max(numeric_keys, key=lambda k: int(k))
363+
hs = hidden_states[last_key]
364+
elif isinstance(hidden_states, torch.Tensor):
365+
hs = hidden_states
366+
else:
367+
return None, None
368+
369+
# Extract query token positions using gen_mask
370+
gen_mask = torch.tensor(gen_mask_list, dtype=torch.bool, device=hs.device)
371+
if hs.dim() == 2:
372+
# [seq_len, hidden_dim] -> [1, num_query, hidden_dim]
373+
query_hidden = hs[gen_mask].unsqueeze(0)
374+
elif hs.dim() == 3:
375+
# [batch, seq_len, hidden_dim]
376+
query_hidden = hs[:, gen_mask, :]
377+
else:
378+
logger.warning(
379+
"[IMG_GEN] Unexpected hidden_states dim=%d, skipping", hs.dim()
380+
)
381+
return None, None
382+
383+
logger.info(
384+
"[IMG_GEN] Projecting hidden states %s through conditioner",
385+
list(query_hidden.shape),
386+
)
387+
388+
# Project through the conditioner: [B, N, 4096] -> [B, N, 2560]
389+
condition_embeds = self._conditioner.project(query_hidden)
390+
negative_embeds = condition_embeds * 0.0
391+
392+
# Convert to list format expected by DiffusionBackend.generate()
393+
pos_list = list(condition_embeds.unbind(dim=0))
394+
neg_list = list(negative_embeds.unbind(dim=0))
395+
396+
return pos_list, neg_list
397+
398+
def _extract_params(self, data: dict, request) -> ImageGenParams:
    """Build :class:`ImageGenParams` from request data, ignoring prompt text.

    Used by the hidden-state conditioning path, where text is not the
    primary conditioning signal.  The parameter dict is looked up in
    priority order: ``raw_inputs["image_generation"]`` →
    ``mm_inputs["image_gen"]["image_gen_params"]`` →
    ``request.metadata["image_generation"]``.
    """
    params: dict = {}

    raw_inputs = data.get("raw_inputs")
    if isinstance(raw_inputs, dict):
        params = raw_inputs.get("image_generation", {})

    # Fall back to mm_inputs (set by the preprocessor)
    if not params:
        params = (
            data.get("mm_inputs", {})
            .get("image_gen", {})
            .get("image_gen_params", {})
        )

    # Last resort: parameters attached to the request metadata
    if not params and request is not None:
        metadata = getattr(request, "metadata", {}) or {}
        params = metadata.get("image_generation", {})

    width = params.get("width", 1024)
    height = params.get("height", 1024)

    # A "size" string like "1024x1024" overrides width/height when parseable
    size = params.get("size")
    if isinstance(size, str) and "x" in size:
        w_str, h_str, *_rest = size.split("x")
        try:
            width, height = int(w_str), int(h_str)
        except ValueError:
            pass  # malformed size string: keep the values resolved above

    return ImageGenParams(
        width=width,
        height=height,
        num_inference_steps=params.get("num_inference_steps", 28),
        guidance_scale=params.get("guidance_scale", 7.0),
        seed=params.get("seed"),
        negative_prompt=params.get("negative_prompt", ""),
    )
439+
def _extract_text_for_byt5(self, data: dict) -> str:
440+
"""Extract original prompt text for ByT5 text rendering.
441+
442+
When using hidden-state conditioning, the prompt text is still
443+
passed to generate() so that ZImageBackend can extract quoted
444+
text for ByT5 encoding if available.
445+
"""
446+
prompt = data.get("prompt", {})
447+
if isinstance(prompt, dict):
448+
return prompt.get("prompt_text", "")
449+
return ""
450+
451+
@torch.no_grad()
452+
def _generate_with_condition_embeds(
453+
self,
454+
prompt_text: str,
455+
params: ImageGenParams,
456+
condition_embeds: list[torch.Tensor],
457+
negative_embeds: list[torch.Tensor],
458+
):
459+
"""Run the diffusion pipeline with pre-computed condition embeddings."""
460+
if self._backend is None:
461+
raise RuntimeError("Diffusion backend not loaded")
462+
return self._backend.generate(
463+
prompt_text or "",
464+
params,
465+
condition_embeds=condition_embeds,
466+
negative_condition_embeds=negative_embeds,
467+
)
468+
469+
# ------------------------------------------------------------------
470+
# Text-only generation
471+
# ------------------------------------------------------------------
472+
228473
@torch.no_grad()
229474
def _generate_image(self, text: str, params: ImageGenParams):
230475
"""Run the diffusion pipeline (called in thread pool)."""

0 commit comments

Comments
 (0)