@@ -111,7 +111,8 @@ async def generate(

         try:
             time_start = time.perf_counter()
-            for idx in range(len(request.multimodal_inputs)):
+            multimodal_inputs = request.multimodal_inputs or []
+            for idx in range(len(multimodal_inputs)):
                 if not request.multimodal_inputs[idx].multimodal_input.image_url:
                     raise ValueError("image_url is required for the encode worker.")

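Note: multimodal_inputs is now Optional on vLLMMultimodalRequest (see protocol.py below), so every consumer needs a None guard before iterating or indexing. A minimal sketch of the guard pattern in isolation, with the group type left untyped since the real MultiModalGroup class lives in protocol.py:

    from typing import Any, List, Optional

    def validate_image_urls(multimodal_inputs: Optional[List[Any]]) -> None:
        # "or []" normalizes None to an empty list, so the loop body
        # (and the indexing inside it) only runs when inputs are present.
        groups = multimodal_inputs or []
        for idx in range(len(groups)):
            if not groups[idx].multimodal_input.image_url:
                raise ValueError("image_url is required for the encode worker.")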
@@ -4,7 +4,6 @@
 import copy
 import logging
 import os
-from collections import defaultdict

 import safetensors
 import torch
@@ -80,8 +79,18 @@ async def generate(self, request: vLLMMultimodalRequest, context):
         # values prevent incorrect prefix cache matches between different images.
         multi_modal_data = None
         if is_qwen_vl_model(self.config.model):
+            # Extract image_grid_thw and embeddings_shape from multimodal_inputs
+            image_grid_thw = []
+            embeddings_shape = None
+            for mi in (request.multimodal_inputs or []):
+                if mi.image_grid_thw:
+                    image_grid_thw.extend(mi.image_grid_thw)
+                if mi.embeddings_shape and embeddings_shape is None:
+                    embeddings_shape = mi.embeddings_shape
             multi_modal_data = construct_qwen_decode_mm_data(
-                request.image_grid_thw, request.embeddings_shape, request.request_id
+                image_grid_thw if image_grid_thw else None,
+                embeddings_shape,
+                request.request_id
             )

         gen = self.engine_client.generate(
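Note: image_grid_thw and embeddings_shape previously lived on the request itself; they now live on each MultiModalGroup, so the Qwen decode path aggregates them across groups. A standalone sketch of that aggregation, assuming only that each group exposes an optional image_grid_thw list and an optional embeddings_shape:

    from typing import Any, List, Optional, Tuple

    def collect_qwen_mm_metadata(
        groups: Optional[List[Any]],
    ) -> Tuple[Optional[List[Any]], Optional[Any]]:
        # Concatenate grid metadata across all groups and keep the first
        # embeddings_shape seen, mirroring the loop in the hunk above.
        image_grid_thw: List[Any] = []
        embeddings_shape = None
        for mi in groups or []:
            if mi.image_grid_thw:
                image_grid_thw.extend(mi.image_grid_thw)
            if mi.embeddings_shape and embeddings_shape is None:
                embeddings_shape = mi.embeddings_shape
        return (image_grid_thw or None), embeddings_shape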
@@ -169,8 +178,8 @@ async def generate(self, request: vLLMMultimodalRequest, context):
         request = vLLMMultimodalRequest.model_validate(request)
         logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")

-        multi_modal_data = defaultdict(list)
-        for mi in request.multimodal_inputs:
+        multi_modal_data = {}
+        for mi in (request.multimodal_inputs or []):
             # ECConnector consumer mode: vLLM loads embeddings automatically from disk
             # We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
             if self.config.ec_consumer_mode:
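Note: the switch from defaultdict(list) to a plain dict is load-bearing. With defaultdict, merely reading multi_modal_data["image"] inserts an empty list under that key, the old == [] sentinel only made sense for list values, and the branches below also store a tensor or a dict under the same key. Explicit membership tests keep "not yet set" unambiguous. A small demonstration of the pitfall:

    from collections import defaultdict

    dd = defaultdict(list)
    _ = dd["image"]           # a plain read silently inserts "image": []
    assert "image" in dd      # the key now exists despite no append

    d = {}
    assert "image" not in d   # a plain dict keeps absence observable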
@@ -181,13 +190,17 @@ async def generate(self, request: vLLMMultimodalRequest, context):
                 # Use PIL image loading - vLLM will detect it's already in EC cache
                 # and load from disk instead of reprocessing
                 if mi.multimodal_input.image_url:
+                    if "image" not in multi_modal_data:
+                        multi_modal_data["image"] = []
                     multi_modal_data["image"].append(
                         await self.image_loader.load_image(
                             mi.multimodal_input.image_url
                         )
                     )
                 elif mi.multimodal_input.video_url:
                     # For video, load as image placeholder (vLLM will use EC cache)
+                    if "image" not in multi_modal_data:
+                        multi_modal_data["image"] = []
                     multi_modal_data["image"].append(
                         await self.image_loader.load_image(
-                            request.multimodal_input.video_url
+                            mi.multimodal_input.video_url
@@ -232,6 +245,8 @@ async def generate(self, request: vLLMMultimodalRequest, context):
                         self.EMBEDDINGS_DTYPE,
                         video_numpy=video_numpy,
                     )
+                    if "video" not in multi_modal_data:
+                        multi_modal_data["video"] = []
                     multi_modal_data["video"].append(mm_data["video"])
                 else:
                     mm_data = construct_mm_data(
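Note: the repeated init-before-append blocks are the manual form of dict.setdefault. A compact equivalent for list-valued keys, shown only as an illustration (the explicit checks in the diff also cover keys that later hold a tensor or a dict, where a list default from setdefault would be wrong):

    multi_modal_data = {}
    # setdefault returns the existing list, or inserts and returns a new
    # one, collapsing the check-init-append sequence into a single call.
    multi_modal_data.setdefault("video", []).append("frame-0")
    multi_modal_data.setdefault("video", []).append("frame-1")
    assert multi_modal_data["video"] == ["frame-0", "frame-1"]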
@@ -241,7 +256,7 @@ async def generate(self, request: vLLMMultimodalRequest, context):
                         image_grid_thw=mi.image_grid_thw,
                     )
                     if isinstance(mm_data["image"], dict):
-                        if multi_modal_data["image"] == []:
+                        if "image" not in multi_modal_data:
                             multi_modal_data["image"] = mm_data["image"]
                         else:
                             # [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
@@ -261,22 +276,30 @@ async def generate(self, request: vLLMMultimodalRequest, context):
                     else:
                         logger.info(f"Get embedding of shape {mm_data['image'].shape}")
                         # [gluo FIXME] embedding with multiple images?
-                        if multi_modal_data["image"] == []:
+                        if "image" not in multi_modal_data:
                             multi_modal_data["image"] = mm_data["image"]
                         else:
                             multi_modal_data["image"] = torch.cat(
                                 (multi_modal_data["image"], mm_data["image"])
                             )
             else:
                 # Use PIL image instead of image embeddings
+                if "image" not in multi_modal_data:
+                    multi_modal_data["image"] = []
                 multi_modal_data["image"].append(
                     await self.image_loader.load_image(mi.multimodal_input.image_url)
                 )

-        # Remove the image features from the request as they are not required
-        request.multimodal_inputs = None
+        # Clear heavy data from multimodal_inputs but keep metadata (image_grid_thw, embeddings_shape)
+        # needed by Decode Worker for Qwen VL models
+        for mi in (request.multimodal_inputs or []):
+            if mi.multimodal_input:
+                mi.multimodal_input.image_url = None
+                mi.multimodal_input.video_url = None
+            mi.serialized_request = None

-        logger.info(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
+        mm_size = len(multi_modal_data.get('image', [])) if not isinstance(multi_modal_data.get('image'), dict) else 1
+        logger.info(f"Prepared multimodal data size: {mm_size}")
         logger.info(f"{multi_modal_data}")

         # Deepcopy the request to avoid modifying the original
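Note: by this point multi_modal_data["image"] can be a list of PIL images, a single embedding tensor, or a Qwen-specific dict, so the size logging has to branch on type. The same logic, factored into a hypothetical helper for clarity:

    def mm_payload_size(multi_modal_data: dict) -> int:
        # dict payloads (Qwen embeddings plus grid metadata) count as one
        # item; a missing key counts as zero; lists report their length
        # and stacked tensors their leading dimension.
        image = multi_modal_data.get("image")
        if image is None:
            return 0
        if isinstance(image, dict):
            return 1
        return len(image)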
components/src/dynamo/vllm/multimodal_utils/model.py (17 additions, 1 deletion)
@@ -77,9 +77,25 @@ def normalize_model_name(model_name: str) -> str:
         return f"{org}/{model}"

     # Handle local directory paths - extract the last directory name
+    # and try to infer organization from known model patterns
     path = Path(model_name)
     if path.exists() and path.is_dir():
-        return path.name
+        name = path.name
+        # Try to infer org from known model patterns
+        if "Qwen" in name:
+            return f"Qwen/{name}"
+        if "llava" in name.lower():
+            return f"llava-hf/{name}"
+        return name
+
+    # If path doesn't exist but looks like a local path, still try to extract name
+    if model_name.startswith("/"):
+        name = Path(model_name).name
+        if "Qwen" in name:
+            return f"Qwen/{name}"
+        if "llava" in name.lower():
+            return f"llava-hf/{name}"
+        return name

     # If no pattern matches, return the original name
     return model_name
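Note: illustrative inputs and outputs implied by the branches above; the paths are hypothetical, and the first case assumes the directory actually exists:

    normalize_model_name("/models/Qwen2.5-VL-7B-Instruct")
    # -> "Qwen/Qwen2.5-VL-7B-Instruct"   (existing local dir, org inferred)
    normalize_model_name("/mnt/ckpts/llava-1.5-7b-hf")
    # -> "llava-hf/llava-1.5-7b-hf"      (absolute path that does not exist)
    normalize_model_name("custom-model")
    # -> "custom-model"                  (no pattern matched, returned as-is)

One asymmetry worth noting: the Qwen check is case-sensitive while the llava check is not.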
components/src/dynamo/vllm/multimodal_utils/protocol.py (5 additions, 2 deletions)
@@ -86,7 +86,10 @@ def parse_sampling_params(cls, v: Any) -> SamplingParams:
         if isinstance(v, str):
             v = json.loads(v)
         if isinstance(v, dict):
-            return SamplingParams(**v)
+            # Filter out private/internal fields that shouldn't be passed to constructor
+            # These fields are serialized but cause issues when deserialized (e.g., sets become lists)
+            filtered = {k: val for k, val in v.items() if not k.startswith("_")}
+            return SamplingParams(**filtered)
         return v

     @field_serializer("sampling_params")
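Note: per the comment in the hunk, SamplingParams serializes some underscore-prefixed internal fields, and replaying those through the constructor causes trouble (a set that round-trips through JSON comes back as a list). Filtering to public keys restores a clean constructor payload. A generic sketch of the pattern with plain JSON, independent of vLLM:

    import json

    def public_fields(d: dict) -> dict:
        # Drop underscore-prefixed keys before reconstructing an object:
        # they are internal state, not constructor arguments.
        return {k: v for k, v in d.items() if not k.startswith("_")}

    payload = json.loads('{"temperature": 0.7, "top_k": 40, "_internal": [1, 2]}')
    assert public_fields(payload) == {"temperature": 0.7, "top_k": 40}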
@@ -156,7 +159,7 @@ class MultiModalGroup(BaseModel):

 class vLLMMultimodalRequest(vLLMGenerateRequest):
     model_config = ConfigDict(arbitrary_types_allowed=True)
-    multimodal_inputs: List[MultiModalGroup] = Field(default_factory=list)
+    multimodal_inputs: Optional[List[MultiModalGroup]] = Field(default_factory=list)


 class VLLMNativeEncoderRequest(BaseModel):
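Note: with the field Optional, a deserialized request can carry None rather than an empty list, which is exactly why the worker loops above iterate over (request.multimodal_inputs or []). A minimal pydantic sketch of the contract, using toy models rather than the real classes:

    from typing import List, Optional
    from pydantic import BaseModel, Field

    class Group(BaseModel):
        image_url: Optional[str] = None

    class Request(BaseModel):
        multimodal_inputs: Optional[List[Group]] = Field(default_factory=list)

    # default_factory yields [] when the field is omitted, but an explicit
    # null in the payload still validates to None, so consumers must guard.
    assert Request().multimodal_inputs == []
    assert Request.model_validate({"multimodal_inputs": None}).multimodal_inputs is None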