
Commit d77b84d

Checkpoint
Signed-off-by: Divyansh Singhvi <divyanshsinghvi@gmail.com>
1 parent 4d18960 commit d77b84d

17 files changed: +1977, -402 lines changed


tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml

Lines changed: 2 additions & 4 deletions
@@ -23,8 +23,6 @@ stage_args:
     hf_config_name: thinker_config
     tensor_parallel_size: 2
     load_format: dummy
-    final_output: true
-    final_output_type: text
     is_comprehension: true
     default_sampling_params:
       temperature: 0.4
@@ -55,8 +53,8 @@ stage_args:
     load_format: dummy
     engine_input_source: [0]
     custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
-    # final_output: true
-    # final_output_type: text
+    final_output: true
+    final_output_type: text
     default_sampling_params:
       temperature: 0.9
       top_k: 50
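Net effect of this config change: final_output and final_output_type move from the thinker stage to the talker stage, so the talker now emits the run's final text. A minimal sketch for checking which stage carries the flag, assuming PyYAML is available and that stage_args is a list of per-stage mappings (an assumption based on the diff context, not verified against the full file):

# Hypothetical check: which stage in the CI config now carries final_output?
# Assumes a top-level `stage_args` list of per-stage mappings.
import yaml

with open("tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml") as f:
    cfg = yaml.safe_load(f)

for idx, stage in enumerate(cfg.get("stage_args", [])):
    if stage.get("final_output"):
        print(f"stage {idx} emits final output as {stage.get('final_output_type')}")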

vllm_omni/core/sched/omni_generation_scheduler.py

Lines changed: 12 additions & 4 deletions
@@ -1,6 +1,7 @@
 import time
 from collections import defaultdict
 
+from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.request_queue import create_request_queue
 from vllm.v1.core.sched.scheduler import (
@@ -16,6 +17,8 @@
 from vllm_omni.core.sched.output import OmniNewRequestData
 from vllm_omni.outputs import OmniModelRunnerOutput
 
+logger = init_logger(__name__)
+
 
 class OmniGenerationScheduler(VLLMScheduler):
     def schedule(self) -> SchedulerOutput:
@@ -185,7 +188,7 @@ def update_from_output(
                 # request is aborted while the model is executing it (e.g.,
                 # in pipeline parallelism).
                 continue
-
+            logger.info(f"Diffusion request completed: {req_id} {model_runner_output.req_id_to_index}")
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []

@@ -248,9 +251,14 @@ def update_from_output(
             # Convert pooler_output tensor to dict format expected by OmniEngineCoreOutput
             pooling_output_dict = None
             if pooler_output is not None:
-                # Wrap tensor in dict to match OmniEngineCoreOutput.pooling_output type
-                # which expects Optional[dict[str, torch.Tensor]]
-                pooling_output_dict = {"model_outputs": pooler_output}
+                # If pooler_output is already a dict (from stages that output multiple tensors),
+                # preserve it directly. Otherwise, wrap tensor in dict.
+                if isinstance(pooler_output, dict):
+                    logger.info(f"[DEBUG scheduler] pooler_output is dict with keys: {list(pooler_output.keys())}")
+                    pooling_output_dict = pooler_output
+                else:
+                    logger.info("[DEBUG scheduler] pooler_output is tensor, wrapping as model_outputs")
+                    pooling_output_dict = {"model_outputs": pooler_output}
             if new_token_ids or pooler_output is not None or kv_transfer_params:
                 # Add EngineCoreOutput for this Request.
                 outputs[request.client_index].append(
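The scheduler change above stops flattening dict-valued pooler outputs: a stage that already returns multiple named tensors keeps its keys, while a bare tensor is still wrapped under "model_outputs". A standalone sketch of that rule (illustrative only, not the actual OmniGenerationScheduler code):

# Sketch of the wrapping rule: preserve dict outputs, wrap bare tensors.
from typing import Optional, Union

import torch


def to_pooling_output_dict(
    pooler_output: Optional[Union[torch.Tensor, dict[str, torch.Tensor]]],
) -> Optional[dict[str, torch.Tensor]]:
    if pooler_output is None:
        return None
    if isinstance(pooler_output, dict):
        # Already keyed by the producing stage (e.g. {"latents": ..., "audio": ...}).
        return pooler_output
    # Single tensor: wrap to match the dict type expected downstream.
    return {"model_outputs": pooler_output}


print(list(to_pooling_output_dict(torch.zeros(2, 4)).keys()))               # ['model_outputs']
print(list(to_pooling_output_dict({"latents": torch.zeros(2, 4)}).keys()))  # ['latents']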

vllm_omni/engine/output_processor.py

Lines changed: 39 additions & 1 deletion
@@ -111,6 +111,10 @@ def _to_cpu(x):
             return x
 
         if isinstance(payload, dict):
+            logger.info(
+                f"[DEBUG add_multimodal_tensor] payload is dict with keys: {list(payload.keys())}, "
+                f"mm_type={self.mm_type}"
+            )
             incoming: Dict[str, Any] = {}
             # Optional remap: if producer used "model_outputs" or "hidden", rename to mm_type
             # to keep a consistent key namespace per engine_core_output_type.
@@ -125,8 +129,10 @@ def _to_cpu(x):
                     incoming[k] = {str(sk): _to_cpu(sv) for sk, sv in v.items()}
                 else:
                     incoming[k] = _to_cpu(v)
+            logger.info(f"[DEBUG add_multimodal_tensor] incoming dict has keys: {list(incoming.keys())}")
         else:
             key = self.mm_type or "hidden"
+            logger.info(f"[DEBUG add_multimodal_tensor] payload is NOT dict, wrapping as: {{'{key}': payload}}")
             incoming = {key: _to_cpu(payload)}
 
         if self.mm_accumulated is None:
@@ -380,6 +386,14 @@ def process_outputs(
             # 2.5) Accumulate multimodal tensors in RequestState
             try:
                 mm_type = (getattr(eco, "output_type", self.engine_core_output_type) or "").lower()
+                logger.info(
+                    f"[DEBUG process_outputs] req_id={req_id}, mm_type={mm_type}, "
+                    f"pooling_output type: {type(pooling_output)}"
+                )
+                if isinstance(pooling_output, dict):
+                    logger.info(
+                        f"[DEBUG process_outputs] pooling_output is dict with keys: {list(pooling_output.keys())}"
+                    )
                 if pooling_output is not None and isinstance(req_state, OmniRequestState):
                     req_state.add_multimodal_tensor(pooling_output, mm_type)
             except Exception:
@@ -497,9 +511,27 @@ def _process_text_image_output(self, eco: EngineCoreOutput) -> None:
     def _process_latents_output(self, eco: EngineCoreOutput) -> None:
         """Ensure latent tensors are surfaced via pooling_output."""
         if eco.pooling_output is None:
+            # DEBUG: Log what we're processing
+            mm = getattr(eco, "multimodal_outputs", None)
+            logger.info(f"[DEBUG _process_latents_output] multimodal_outputs type: {type(mm)}")
+            if isinstance(mm, dict):
+                logger.info(f"[DEBUG _process_latents_output] multimodal_outputs keys: {list(mm.keys())}")
+
             tensor = self._extract_from_multimodal_outputs(eco, keys=("latent", "latents", "z", "posterior"))
+            logger.info(
+                f"[DEBUG _process_latents_output] extracted tensor type: {type(tensor)}, "
+                f"is dict: {isinstance(tensor, dict)}"
+            )
             if tensor is not None:
                 eco.pooling_output = tensor
+                logger.info("[DEBUG _process_latents_output] set eco.pooling_output to extracted tensor")
+        else:
+            # pooling_output already set (likely from scheduler with full dict)
+            logger.info(f"[DEBUG _process_latents_output] pooling_output already set, type: {type(eco.pooling_output)}")
+            if isinstance(eco.pooling_output, dict):
+                logger.info(
+                    f"[DEBUG _process_latents_output] pooling_output dict keys: {list(eco.pooling_output.keys())}"
+                )
 
     def _process_audio_output(self, eco: EngineCoreOutput) -> None:
         """Ensure audio tensors are surfaced via pooling_output."""
@@ -532,9 +564,15 @@ def _extract_from_multimodal_outputs(self, eco: EngineCoreOutput, keys: tuple[st
         for k in keys:
             v = mm.get(k)
             if isinstance(v, torch.Tensor):
+                logger.info(f"[DEBUG _extract_from_multimodal_outputs] Found key '{k}' in multimodal_outputs")
                 return v
         # Try the first tensor in the dict as a fallback
-        for v in mm.values():
+        logger.info(f"[DEBUG _extract_from_multimodal_outputs] No matching keys {keys}, using fallback (first tensor)")
+        for k, v in mm.items():
             if isinstance(v, torch.Tensor):
+                logger.info(
+                    f"[DEBUG _extract_from_multimodal_outputs] Fallback: extracted first"
+                    f" tensor with key '{k}', shape: {v.shape}"
+                )
                 return v
         return None
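The extraction helper prefers the named keys and only then falls back to the first tensor value in the dict; the new logging records which path was taken. A self-contained sketch of that lookup order (a stand-in, not the actual _extract_from_multimodal_outputs method):

# Sketch of the lookup order: preferred keys first, then first tensor value.
from typing import Optional

import torch


def extract_tensor(
    mm: Optional[dict],
    keys: tuple[str, ...] = ("latent", "latents", "z", "posterior"),
) -> Optional[torch.Tensor]:
    if not isinstance(mm, dict):
        return None
    for k in keys:
        v = mm.get(k)
        if isinstance(v, torch.Tensor):
            return v
    # Fallback: first tensor value, whatever its key happens to be.
    for v in mm.values():
        if isinstance(v, torch.Tensor):
            return v
    return None


print(extract_tensor({"other": torch.ones(1), "latents": torch.zeros(3)}).shape)  # torch.Size([3])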

vllm_omni/entrypoints/omni_llm.py

Lines changed: 1 addition & 4 deletions
@@ -351,10 +351,7 @@ def _run_generation(
             req_id = result.get("request_id")
             if "error" in result:
                 logger.error(
-                    "Stage %s error on request %s: %s",
-                    stage_id,
-                    req_id,
-                    result["error"],
+                    "Stage %s error on request %s: %s %s", stage_id, req_id, result["error"], result["error_tb"]
                 )
                 continue
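The collapsed logger.error call keeps lazy %-style formatting and appends the stage-side traceback (result["error_tb"]) after the error string. A small sketch with hypothetical values:

# Hypothetical values; shows the %-style call with the extra traceback argument.
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("omni_llm_example")

stage_id, req_id = 1, "req-0"
result = {"error": "ValueError: bad shape", "error_tb": "Traceback (most recent call last): ..."}
logger.error(
    "Stage %s error on request %s: %s %s", stage_id, req_id, result["error"], result["error_tb"]
)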

vllm_omni/entrypoints/omni_stage.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@
 import multiprocessing as mp
 import os
 import sys
+import traceback as _traceback
 from typing import Any
 
 from vllm.inputs import TextPrompt
@@ -720,12 +721,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
             )
         except Exception as e:
             _logging.getLogger(__name__).exception("[Stage-%s] Failed on batch %s: %s", stage_id, batch_request_ids, e)
+            _tb = _traceback.format_exc()
             for rid in batch_request_ids:
                 out_q.put(
                     {
                         "request_id": rid,
                         "stage_id": stage_id,
                         "error": str(e),
+                        "error_tb": _tb,
                     }
                 )
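On the stage side, traceback.format_exc() is captured once per failed batch and attached to every per-request error payload, which is what _run_generation logs above. A simplified, self-contained sketch of that producer path (the real code loops over batch_request_ids inside the stage process):

# Simplified producer-side sketch: capture the traceback and ship it with the error.
import queue
import traceback

out_q: "queue.Queue[dict]" = queue.Queue()

try:
    raise ValueError("bad shape")
except Exception as e:
    tb = traceback.format_exc()
    out_q.put({"request_id": "req-0", "stage_id": 1, "error": str(e), "error_tb": tb})

print(out_q.get()["error_tb"].splitlines()[0])  # Traceback (most recent call last):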
