Skip to content

Commit 310246d

Browse files
janhilgard and claude committed
fix: MLLM image processing — exclude_none for Jinja template, error handling
- Use model_dump(exclude_none=True) for MLLM messages: the Qwen3VL Jinja template checks `'image_url' in item` — null keys from Pydantic model_dump() falsely triggered extra <|image_pad|> tokens, causing an "index out of bounds" crash in the processor
- Add per-request error handling in MLLM batch preprocessing: failed requests now get an immediate finish_reason="error" instead of an infinite retry loop (previously retried 5756 times over 300 s before timing out)
- Handle error responses in the MLLM scheduler to properly clean up and return an error status to the client

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bf78650 commit 310246d

File tree

3 files changed

+68
-6
lines changed

3 files changed

+68
-6
lines changed

vllm_mlx/mllm_batch_generator.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,9 @@ def __init__(
348348
# Statistics
349349
self._stats = MLLMBatchStats()
350350

351+
# Error responses for requests that failed during preprocessing
352+
self._pending_error_responses: List[MLLMBatchResponse] = []
353+
351354
# Vision embedding cache for repeated images
352355
self.vision_cache = VisionEmbeddingCache(
353356
max_pixel_entries=vision_cache_size,
@@ -621,9 +624,35 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
621624

622625
tic = time.perf_counter()
623626

624-
# Preprocess all requests
627+
# Preprocess all requests (per-request error handling)
628+
failed_requests = []
625629
for req in requests:
626-
self._preprocess_request(req)
630+
try:
631+
self._preprocess_request(req)
632+
except Exception as e:
633+
logger.error(
634+
f"Failed to preprocess request {req.request_id}: "
635+
f"{type(e).__name__}: {e}"
636+
)
637+
failed_requests.append(req)
638+
639+
# Remove failed requests from batch and create error responses
640+
if failed_requests:
641+
for req in failed_requests:
642+
requests.remove(req)
643+
self._pending_error_responses.append(
644+
MLLMBatchResponse(
645+
uid=req.uid,
646+
request_id=req.request_id,
647+
token=0,
648+
logprobs=mx.zeros(1),
649+
finish_reason="error",
650+
)
651+
)
652+
653+
if not requests:
654+
# All requests failed
655+
return None
627656

628657
total_prompt_tokens = sum(
629658
req.input_ids.size if req.input_ids is not None else 1 for req in requests
@@ -768,10 +797,16 @@ def _next(self) -> List[MLLMBatchResponse]:
768797
self.active_batch = new_batch
769798
prompt_processing = True
770799

800+
# Collect any pending error responses (from failed preprocessing)
801+
error_responses = []
802+
if self._pending_error_responses:
803+
error_responses = list(self._pending_error_responses)
804+
self._pending_error_responses.clear()
805+
771806
# Generate next token for active batch
772807
batch = self.active_batch
773808
if batch is None:
774-
return []
809+
return error_responses
775810

776811
y, logprobs = batch.y, batch.logprobs
777812
batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
@@ -840,7 +875,7 @@ def _next(self) -> List[MLLMBatchResponse]:
840875
self.active_batch = None
841876

842877
self._stats.generation_tokens += len(responses)
843-
return responses
878+
return error_responses + responses
844879

845880
def next(self) -> List[MLLMBatchResponse]:
846881
"""

vllm_mlx/mllm_scheduler.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,27 @@ def _process_batch_responses(
446446
if request is None:
447447
continue
448448

449+
# Handle error responses from failed preprocessing
450+
if response.finish_reason == "error":
451+
output = RequestOutput(
452+
request_id=request_id,
453+
new_token_ids=[],
454+
new_text="",
455+
output_token_ids=[],
456+
prompt_tokens=0,
457+
completion_tokens=0,
458+
finished=True,
459+
finish_reason="error",
460+
)
461+
request.status = RequestStatus.FINISHED_ABORTED
462+
request.output_text = ""
463+
request.finish_reason = "error"
464+
finished_ids.add(request_id)
465+
self.num_requests_processed += 1
466+
logger.warning(f"Request {request_id} failed during preprocessing")
467+
outputs.append(output)
468+
continue
469+
449470
# Append token to request
450471
request.output_tokens.append(response.token)
451472
request.num_output_tokens = len(request.output_tokens)

vllm_mlx/server.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,10 +1385,16 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
13851385
# For MLLM models, keep original messages with embedded images
13861386
# (MLLM.chat() extracts images from message content internally)
13871387
if engine.is_mllm:
1388-
# Convert Pydantic messages to dicts preserving full content
1388+
# Convert Pydantic messages to dicts preserving full content.
1389+
# exclude_none=True is critical: Qwen3VL Jinja template checks
1390+
# 'image_url' in item — null keys would falsely trigger image tokens.
13891391
messages = []
13901392
for msg in request.messages:
1391-
msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
1393+
msg_dict = (
1394+
msg.model_dump(exclude_none=True)
1395+
if hasattr(msg, "model_dump")
1396+
else {k: v for k, v in dict(msg).items() if v is not None}
1397+
)
13921398
messages.append(msg_dict)
13931399
images, videos = [], [] # MLLM extracts these from messages
13941400
logger.debug(f"MLLM: Processing {len(messages)} messages")

0 commit comments

Comments (0)