Skip to content

Commit 0ee9153

Browse files
author
luoyuan.luo
committed
Refactor sglang omni framework to support e2e diffusion
1 parent 72cfd51 commit 0ee9153

25 files changed

Lines changed: 1910 additions & 602 deletions

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ dependencies = [
5656
"torchaudio",
5757
# Gradio playground
5858
"gradio>=4.0.0",
59+
# Ming-Omni
60+
"diffusers==0.37.1",
5961
# flash-attn: install separately via prebuilt wheel (see install instructions below)
6062
]
6163

sglang_omni/engines/omni/factory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,11 @@ def create_sglang_ar_engine(
330330
output_proc = SGLangOutputProcessor(
331331
capture_hidden=capture_hidden,
332332
capture_hidden_layers=capture_hidden_layers,
333-
model=model_worker.model_runner.model if capture_hidden_layers else None,
333+
model=(
334+
model_worker.model_runner.model
335+
if (capture_hidden or capture_hidden_layers)
336+
else None
337+
),
334338
)
335339

336340
if stream_adapter is None:

sglang_omni/engines/omni/runtime/sglang_ar.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,9 @@ def process(
355355
token_id = token_list[i] if i < len(token_list) else None
356356
extra = None
357357
if hidden_states_dict is not None:
358-
if "_single" in hidden_states_dict:
358+
if "_full" in hidden_states_dict:
359+
extra = {"hidden_states": hidden_states_dict["_full"]}
360+
elif "_single" in hidden_states_dict:
359361
extra = {"hidden_states": hidden_states_dict["_single"][i]}
360362
else:
361363
per_req = {}
@@ -382,10 +384,20 @@ def _extract_hidden_states(
382384
"""Extract hidden states from model output or side-channel.
383385
384386
Priority:
385-
1. Side-channel (_captured_aux_hidden_states) from hidden capture hooks
386-
2. logits_output.hidden_states (legacy single-tensor path)
387+
1. Full-sequence side-channel (_captured_full_hidden_states) for
388+
image_gen prefill-only — preserves the full [seq_len, hidden_dim]
389+
tensor so downstream can apply gen_mask.
390+
2. Side-channel (_captured_aux_hidden_states) from hidden capture hooks
391+
3. logits_output.hidden_states (legacy single-tensor path)
387392
"""
388-
# Check side-channel first (set by _hidden_capture hooks)
393+
# Full-sequence capture (set by BailingMoeV2ForCausalLM.forward)
394+
if self._model is not None:
395+
full_hs = getattr(self._model, "_captured_full_hidden_states", None)
396+
if full_hs is not None:
397+
self._model._captured_full_hidden_states = None
398+
return {"_full": full_hs}
399+
400+
# Side-channel from _hidden_capture hooks
389401
if self._model is not None and self._capture_hidden_layers:
390402
aux = getattr(self._model, "_captured_aux_hidden_states", None)
391403
if aux is not None:
@@ -471,14 +483,18 @@ def update_request(self, request: SchedulerRequest, output: RequestOutput) -> No
471483
)
472484

473485
if req.is_chunked > 0:
486+
# Accumulate full-sequence hidden states across chunks so
487+
# image_gen prefill-only gets the complete [seq_len, hidden_dim].
488+
if output.extra:
489+
self._accumulate_hidden_states(data, output.extra)
474490
output.data = None
475491
req.is_chunked -= 1
476492
return
477493

478494
# Transfer captured model outputs (e.g. hidden states) to the
479495
# request data so they're available to downstream pipeline stages.
480496
if output.extra:
481-
data.extra_model_outputs.update(output.extra)
497+
self._accumulate_hidden_states(data, output.extra)
482498

483499
token_id = output.data
484500
if token_id is not None:
@@ -499,6 +515,22 @@ def update_request(self, request: SchedulerRequest, output: RequestOutput) -> No
499515
req.finished(),
500516
)
501517

518+
@staticmethod
def _accumulate_hidden_states(data, extra: dict) -> None:
    """Merge ``extra`` into ``data.extra_model_outputs``.

    Every key other than ``"hidden_states"`` is copied over, overwriting any
    previous value. A tensor-valued ``"hidden_states"`` is appended along
    dim 0 to the tensor already stored under that key (if there is one), so
    repeated calls build up a single tensor. A non-tensor ``"hidden_states"``
    value simply overwrites, like any other key.

    When appending, a 1-D tensor is treated as one row: it is promoted to
    ``[1, N]`` first. Two 1-D vectors therefore stack into ``[2, N]`` rather
    than being flattened into ``[2 * N]``, and a 1-D vector can be appended
    to an already accumulated 2-D tensor instead of raising inside
    ``torch.cat``. The first tensor stored is kept exactly as received.

    Args:
        data: Object exposing a mutable ``extra_model_outputs`` mapping.
        extra: Model outputs of one step; this mapping is not modified.
    """
    outputs = data.extra_model_outputs
    new_hs = extra.get("hidden_states")

    # Nothing to accumulate: plain dict merge (also covers a missing key).
    if not isinstance(new_hs, torch.Tensor):
        outputs.update(extra)
        return

    prev_hs = outputs.get("hidden_states")
    if isinstance(prev_hs, torch.Tensor):
        # Align ranks so dim 0 is always the "row" axis being extended.
        if prev_hs.dim() == 1:
            prev_hs = prev_hs.unsqueeze(0)
        if new_hs.dim() == 1:
            new_hs = new_hs.unsqueeze(0)
        new_hs = torch.cat([prev_hs, new_hs], dim=0)
    outputs["hidden_states"] = new_hs

    # Remaining keys keep plain overwrite semantics.
    outputs.update({k: v for k, v in extra.items() if k != "hidden_states"})
533+
502534
def is_finished(self, request: SchedulerRequest, output: RequestOutput) -> bool:
503535
return request.data.req.finished()
504536

sglang_omni/engines/omni/runtime/thinker_forward.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def thinker_forward_omni(
5050
input_deepstack_embeds=ds_input,
5151
)
5252

53+
if getattr(forward_batch, "capture_hidden_mode", None) is not None:
54+
outer_model._captured_full_hidden_states = hidden_states.clone()
55+
5356
return outer_model.logits_processor(
5457
forward_batch.input_ids,
5558
hidden_states,

sglang_omni/models/ming_omni/components/image_gen_executor.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
import torch
2323

2424
from sglang_omni.executors.interface import Executor
25-
from sglang_omni.models.ming_omni.diffusion.backend import DiffusionBackend, ImageGenParams
25+
from sglang_omni.models.ming_omni.diffusion.backend import (
26+
DiffusionBackend,
27+
ImageGenParams,
28+
)
2629
from sglang_omni.proto import StagePayload
2730

2831
logger = logging.getLogger(__name__)
@@ -91,14 +94,14 @@ def _load_models(self) -> None:
9194
skip_semantic_encoder=True,
9295
)
9396
else:
94-
self._backend.load_models(
95-
self._dit_model_path, torch.device(self._device)
96-
)
97+
self._backend.load_models(self._dit_model_path, torch.device(self._device))
9798
logger.info("[IMG_GEN] Diffusion backend loaded in %.1fs", time.time() - t0)
9899

99100
# Load thinker tokenizer for decoding output_ids → text prompt
100101
try:
101-
from sglang_omni.models.ming_omni.components.common import load_ming_tokenizer
102+
from sglang_omni.models.ming_omni.components.common import (
103+
load_ming_tokenizer,
104+
)
102105

103106
self._thinker_tokenizer = load_ming_tokenizer(self._model_path)
104107
logger.info(
@@ -276,7 +279,9 @@ def _extract_input(self, payload: StagePayload) -> tuple[str, ImageGenParams]:
276279
if isinstance(thinker_out, dict):
277280
output_ids = thinker_out.get("output_ids", [])
278281
if output_ids and self._thinker_tokenizer is not None:
279-
text = self._thinker_tokenizer.decode(output_ids, skip_special_tokens=True)
282+
text = self._thinker_tokenizer.decode(
283+
output_ids, skip_special_tokens=True
284+
)
280285

281286
# Fallback: pre-decoded text
282287
if not text:
@@ -354,7 +359,8 @@ def _try_condition_from_hidden_states(
354359
if isinstance(hidden_states, dict):
355360
# Side-channel capture: pick the last (highest) layer
356361
numeric_keys = [
357-
k for k in hidden_states
362+
k
363+
for k in hidden_states
358364
if isinstance(k, int) or (isinstance(k, str) and k.isdigit())
359365
]
360366
if not numeric_keys:

sglang_omni/models/ming_omni/components/preprocessor.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,17 @@ def __init__(self, model_path: str, conditioner=None):
164164
# Lazy-init image processor
165165
self._image_processor = None
166166

167+
# Image generation conditioner (optional)
168+
self._conditioner = conditioner
169+
if conditioner is not None:
170+
self._image_patch_token_id = conditioner.image_patch_token
171+
self._image_start_token_id = conditioner.image_start_token
172+
self._image_end_token_id = conditioner.image_end_token
173+
else:
174+
self._image_patch_token_id = None
175+
self._image_start_token_id = None
176+
self._image_end_token_id = None
177+
167178
def _get_image_processor(self):
168179
"""Lazy-init Qwen2VLImageProcessor (same processor as Ming-Omni uses)."""
169180
if self._image_processor is None:
@@ -199,17 +210,6 @@ def _process_images(
199210
)
200211
return pixel_values, image_grid_thw, token_counts
201212

202-
# Image generation conditioner (optional)
203-
self._conditioner = conditioner
204-
if conditioner is not None:
205-
self._image_patch_token_id = conditioner.image_patch_token
206-
self._image_start_token_id = conditioner.image_start_token
207-
self._image_end_token_id = conditioner.image_end_token
208-
else:
209-
self._image_patch_token_id = None
210-
self._image_start_token_id = None
211-
self._image_end_token_id = None
212-
213213
async def __call__(self, payload: StagePayload) -> StagePayload:
214214
"""Process a chat completion request into pipeline state."""
215215
request = payload.request
@@ -340,26 +340,20 @@ async def __call__(self, payload: StagePayload) -> StagePayload:
340340
gen_mask = None
341341

342342
if is_image_gen:
343-
num_query_tokens = sum(
344-
s * s for s in self._conditioner.img_gen_scales
345-
)
343+
num_query_tokens = sum(s * s for s in self._conditioner.img_gen_scales)
346344

347345
suffix_ids = (
348346
[self._image_start_token_id]
349347
+ [self._image_patch_token_id] * num_query_tokens
350348
+ [self._image_end_token_id]
351349
)
352350
suffix_tensor = torch.tensor([suffix_ids], dtype=torch.long)
353-
input_ids_tensor = torch.cat(
354-
[input_ids_tensor, suffix_tensor], dim=1
355-
)
351+
input_ids_tensor = torch.cat([input_ids_tensor, suffix_tensor], dim=1)
356352
attention_mask = torch.ones_like(input_ids_tensor)
357353

358354
# gen_mask: 0 for text, 1 for query tokens, 0 for start/end markers
359355
text_len = len(input_ids)
360-
gen_mask = torch.zeros(
361-
input_ids_tensor.shape[1], dtype=torch.long
362-
)
356+
gen_mask = torch.zeros(input_ids_tensor.shape[1], dtype=torch.long)
363357
gen_mask[text_len + 1 : text_len + 1 + num_query_tokens] = 1
364358

365359
prompt: PromptInputs = {

sglang_omni/models/ming_omni/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
AGGREGATE_STAGE,
1717
AUDIO_STAGE,
1818
DECODE_STAGE,
19-
IMAGE_STAGE,
2019
IMAGE_GEN_STAGE,
20+
IMAGE_STAGE,
2121
PREPROCESSING_STAGE,
2222
TALKER_STAGE,
2323
THINKER_STAGE,

sglang_omni/models/ming_omni/diffusion/backend.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from abc import ABC, abstractmethod
77
from dataclasses import dataclass
8-
from typing import Any
98

109
import torch
1110
from PIL import Image
@@ -51,4 +50,3 @@ def generate(
5150

5251
def unload(self) -> None:
5352
"""Release GPU memory."""
54-
pass

sglang_omni/models/ming_omni/diffusion/bailing_moe_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
"""Bailing MoE model configuration"""
44
from transformers.configuration_utils import PretrainedConfig
55

6+
67
class BailingMoeV2Config(PretrainedConfig):
78
model_type = "bailing_moe_v2"
9+
810
def __init__(
911
self,
1012
vocab_size=30592,
@@ -84,5 +86,7 @@ def __init__(
8486
self.partial_rotary_factor = partial_rotary_factor
8587
self.router_type = router_type
8688
self.use_interleaved_frame_timestamp = use_interleaved_frame_timestamp
87-
super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
89+
super().__init__(
90+
pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
91+
)
8892
self._attn_implementation = _attn_implementation

0 commit comments

Comments
 (0)