8 changes: 8 additions & 0 deletions python/sglang/srt/entrypoints/openai/protocol.py
@@ -339,10 +339,14 @@ class ChatCompletionMessageContentTextPart(BaseModel):
 class ChatCompletionMessageContentImageURL(BaseModel):
     url: str
     detail: Optional[Literal["auto", "low", "high"]] = "auto"
+    max_dynamic_patch: Optional[int] = None
+    min_dynamic_patch: Optional[int] = None
 
 
 class ChatCompletionMessageContentVideoURL(BaseModel):
     url: str
+    max_dynamic_patch: Optional[int] = None
+    min_dynamic_patch: Optional[int] = None
 
 
 class ChatCompletionMessageContentAudioURL(BaseModel):
@@ -516,6 +520,10 @@ class ChatCompletionRequest(BaseModel):
     stream_reasoning: bool = True
     chat_template_kwargs: Optional[Dict] = None
 
+    # SGLang multimodal tiling controls (extensions)
+    max_dynamic_patch: Optional[int] = None
+    min_dynamic_patch: Optional[int] = None
+
     # Custom logit processor for advanced sampling control
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
     custom_params: Optional[Dict] = None
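For reference, the new fields compose like this in an OpenAI-style request. A minimal sketch (the endpoint, model name, and image URL are placeholders; `min_dynamic_patch` follows the same pattern):

```python
# Sketch: exercising the new tiling fields via the OpenAI-compatible API.
# Endpoint, model name, and image URL are illustrative, not from this PR.
import requests

payload = {
    "model": "internvl-example",  # hypothetical model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://example.com/cat.png",
                        # per-image cap on dynamic tiling added by this PR
                        "max_dynamic_patch": 6,
                    },
                },
            ],
        }
    ],
    # request-level value, used when no per-item value is given
    "max_dynamic_patch": 12,
}
resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```

A per-part value inside `image_url` takes precedence over the request-level `max_dynamic_patch`, per the resolution chain in `internvl.py` below.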
32 changes: 32 additions & 0 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -55,6 +55,32 @@
 logger = logging.getLogger(__name__)
 
 
+def _extract_max_dynamic_patch(request: ChatCompletionRequest):
+    img_vals = []
+    vid_vals = []
+    for msg in request.messages or []:
+        content = getattr(msg, "content", None)
+        if not isinstance(content, list):
+            continue
+        for part in content:
+            # Content parts are pydantic models here (dicts were parsed upstream)
+            if getattr(part, "type", None) == "image_url":
+                iu = getattr(part, "image_url", None)
+                mdp = getattr(iu, "max_dynamic_patch", None) if iu else None
+                if mdp is not None:
+                    img_vals.append(int(mdp))
+            elif getattr(part, "type", None) == "video_url":
+                vu = getattr(part, "video_url", None)
+                mdp = getattr(vu, "max_dynamic_patch", None) if vu else None
+                if mdp is not None:
+                    vid_vals.append(int(mdp))
+
+    # TODO(yuan-luo): per-item max_dynamic_patch; until then take the min (most restrictive)
+    img_max_dynamic_patch = min(img_vals) if img_vals else None
+    vid_max_dynamic_patch = min(vid_vals) if vid_vals else None
+    return img_max_dynamic_patch, vid_max_dynamic_patch
+
+
 class OpenAIServingChat(OpenAIServingBase):
     """Handler for /v1/chat/completions requests"""
 
@@ -195,6 +221,9 @@ def _convert_to_internal_request(
         if first_adapter:
             self._validate_lora_enabled(first_adapter)
 
+        img_max_dynamic_patch, vid_max_dynamic_patch = _extract_max_dynamic_patch(
+            request
+        )
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
@@ -219,6 +248,9 @@
             priority=request.priority,
             custom_labels=custom_labels,
             custom_logit_processor=request.custom_logit_processor,
+            image_max_dynamic_patch=img_max_dynamic_patch,
+            video_max_dynamic_patch=vid_max_dynamic_patch,
+            max_dynamic_patch=getattr(request, "max_dynamic_patch", None),
         )
 
         return adapted_request, request
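Since per-item values are not yet threaded through (see the TODO), multiple per-part values are collapsed to the smallest one. A standalone sketch of that reduction, using `SimpleNamespace` stand-ins for the pydantic parts:

```python
# Sketch of the min-reduction in _extract_max_dynamic_patch: when several
# parts specify a value, the smallest (most restrictive) one wins.
from types import SimpleNamespace as NS

parts = [
    NS(type="image_url", image_url=NS(max_dynamic_patch=8)),
    NS(type="image_url", image_url=NS(max_dynamic_patch=4)),
    NS(type="image_url", image_url=NS(max_dynamic_patch=None)),
]
img_vals = [
    p.image_url.max_dynamic_patch
    for p in parts
    if p.type == "image_url" and p.image_url.max_dynamic_patch is not None
]
img_max = min(img_vals) if img_vals else None
assert img_max == 4        # 4 caps all images in the request
assert (min([]) if [] else None) is None  # absent everywhere -> defaults apply
```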
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -259,6 +259,12 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin):
     need_wait_for_image: Optional[bool] = None
     num_items_assigned: Optional[List] = None
 
+    # Multimodal tiling controls (extensions)
+    max_dynamic_patch: Optional[int] = None
+    min_dynamic_patch: Optional[int] = None
+    image_max_dynamic_patch: Optional[int] = None
+    video_max_dynamic_patch: Optional[int] = None
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
36 changes: 32 additions & 4 deletions python/sglang/srt/multimodal/processors/internvl.py
@@ -25,6 +25,7 @@ class InternVLProcessor(BaseMultimodalProcessor):
 
     IMAGENET_MEAN = [0.485, 0.456, 0.406]
     IMAGENET_STD = [0.229, 0.224, 0.225]
+    IMAGE_MAX_NUM = 12
 
     DEFAULT_VIDEO_NUM_FRAMES = 32
     VIDEO_MAX_NUM = 1
@@ -86,6 +87,11 @@ def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
             else None
         )
 
+        self.image_token_id = (
+            tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT)
+            if self.IMG_CONTEXT
+            else None
+        )
         self.num_image_token = int(
             (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
         )
@@ -97,7 +103,7 @@ def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
         # Offset token ids use IMG_CONTEXT / VIDEO_CONTEXT
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self.IMAGE_PLACEHOLDER_TOKEN,
-            image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT),
+            image_token_id=self.image_token_id,
             video_token=self.VIDEO_PLACEHOLDER_TOKEN,
             video_token_id=self.video_token_id,
         ).build(_image_processor)
@@ -122,7 +128,9 @@ def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
         )
 
     @staticmethod
-    def dynamic_preprocess(tensor, image_size=448, max_num=12, use_thumbnail=False):
+    def dynamic_preprocess(
+        tensor, image_size=448, max_num=IMAGE_MAX_NUM, use_thumbnail=False
+    ):
         # Tensor: (C,H,W) float on GPU
         C, H, W = tensor.shape
         aspect_ratio = W / H
@@ -264,6 +272,25 @@ async def process_mm_data_async(
     async def process_qwen_mm_data_async(
         self, image_data, input_text, request_obj, **kwargs
     ):
+
+        img_max_num = (
+            getattr(request_obj, "image_max_dynamic_patch", None)
+            or getattr(request_obj, "max_dynamic_patch", None)
+            or kwargs.get("image_max_dynamic_patch")
+            or kwargs.get("max_dynamic_patch")
+            or self.IMAGE_MAX_NUM
+        )
+        img_max_num = max(1, int(img_max_num))
+
+        vid_max_num = (
+            getattr(request_obj, "video_max_dynamic_patch", None)
+            or getattr(request_obj, "max_dynamic_patch", None)
+            or kwargs.get("video_max_dynamic_patch")
+            or kwargs.get("max_dynamic_patch")
+            or self.VIDEO_MAX_NUM
+        )
+        vid_max_num = max(1, int(vid_max_num))
+
         # Qwen/Qwen3 branch: OpenAI-style placeholders <image>/<video>
         prompt = input_text or ""
         video_data = getattr(request_obj, "video_data", None) or []
@@ -314,7 +341,7 @@ async def process_qwen_mm_data_async(
 
             tensor = (tensor - mean) / std
             tiles = self.dynamic_preprocess(
-                tensor, image_size=448, max_num=12, use_thumbnail=True
+                tensor, image_size=448, max_num=img_max_num, use_thumbnail=True
             )
             pixel_values_list.append(tiles)
             num_patches_list.append(int(tiles.shape[0]))
@@ -374,7 +401,7 @@ async def process_qwen_mm_data_async(
                 tiles = self.dynamic_preprocess(
                     frame_t,
                     image_size=448,
-                    max_num=self.VIDEO_MAX_NUM,
+                    max_num=vid_max_num,
                     use_thumbnail=self.VIDEO_USE_THUMBNAIL,
                 )
                 per_video_tiles.append(tiles)
@@ -394,6 +421,7 @@ async def process_qwen_mm_data_async(
 
         input_text_mid = base_output.input_text or prompt
         input_text_mid = input_text_mid.replace(self.IMAGE_PLACEHOLDER_TOKEN, img_ph)
+        input_text_mid = input_text_mid.replace(self.IMG_CONTEXT, img_ph)
 
         if self.VIDEO_CONTEXT_TOKEN and self.video_token_id is not None:
             input_text_mid = input_text_mid.replace(
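The `or`-chains above resolve `max_num` in precedence order: per-modality request field, then the request-wide field, then kwargs, then the class default. Note that `or` also skips falsy values such as 0, which the trailing `max(1, ...)` clamp makes harmless. A minimal sketch of the same resolution (the helper name is illustrative):

```python
# Sketch of the resolution order used for img_max_num / vid_max_num above.
def resolve_max_num(per_modality, request_wide, kwarg_val, default):
    # `or` falls through on None *and* on 0; the clamp keeps the result >= 1.
    value = per_modality or request_wide or kwarg_val or default
    return max(1, int(value))

assert resolve_max_num(None, None, None, 12) == 12  # class default wins
assert resolve_max_num(6, 12, None, 12) == 6        # per-modality overrides
assert resolve_max_num(0, None, None, 12) == 12     # 0 is falsy: falls through
```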
14 changes: 13 additions & 1 deletion python/sglang/srt/multimodal/processors/qwen_vl.py
@@ -305,10 +305,22 @@ async def process_mm_data_async(
         **kwargs,
     ):
         entry_time = time.perf_counter()
+
+        video_data = getattr(request_obj, "video_data", None) or []
+        video_cfgs = []
+        video_data_norm = []
+        for v in video_data:
+            if isinstance(v, dict):
+                video_data_norm.append(v.get("url") or v.get("path"))
+                video_cfgs.append(v)
+            else:
+                video_data_norm.append(v)
+                video_cfgs.append(None)
+
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            video_data=request_obj.video_data,
+            video_data=video_data_norm,
             audio_data=request_obj.audio_data,
             multimodal_tokens=self.mm_tokens,
         )
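The normalization above accepts either bare URL strings or structured dicts for videos, splitting them into a list of loadable URLs plus a parallel list of per-video configs. A small sketch with illustrative inputs:

```python
# Sketch of the video_data normalization: strings pass through, dicts are
# split into a URL (for loading) and a config entry (for later per-video use).
video_data = [
    "https://example.com/a.mp4",
    {"url": "https://example.com/b.mp4", "max_dynamic_patch": 2},
]
video_data_norm, video_cfgs = [], []
for v in video_data:
    if isinstance(v, dict):
        video_data_norm.append(v.get("url") or v.get("path"))
        video_cfgs.append(v)
    else:
        video_data_norm.append(v)
        video_cfgs.append(None)

assert video_data_norm == ["https://example.com/a.mp4", "https://example.com/b.mp4"]
assert video_cfgs[0] is None and video_cfgs[1]["max_dynamic_patch"] == 2
```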
22 changes: 19 additions & 3 deletions python/sglang/srt/parser/jinja_template_utils.py
@@ -154,18 +154,34 @@ def process_content_for_template_format(
         chunk_type = chunk.get("type")
 
         if chunk_type == "image_url":
+            image_obj = chunk.get("image_url") or {}
+            mdp = image_obj.get("max_dynamic_patch", None)
+            # Also allow flat style: chunk["max_dynamic_patch"]
+            if mdp is None:
+                mdp = chunk.get("max_dynamic_patch", None)
             image_data.append(
                 ImageData(
-                    url=chunk["image_url"]["url"],
-                    detail=chunk["image_url"].get("detail", "auto"),
+                    url=image_obj["url"],
+                    detail=image_obj.get("detail", "auto"),
+                    max_dynamic_patch=mdp,
                 )
             )
             if chunk.get("modalities"):
                 modalities.append(chunk.get("modalities"))
             # Normalize to simple 'image' type for template compatibility
             processed_content_parts.append({"type": "image"})
         elif chunk_type == "video_url":
-            video_data.append(chunk["video_url"]["url"])
+            video_obj = chunk.get("video_url") or {}
+            mdp = video_obj.get("max_dynamic_patch", None)
+            if mdp is None:
+                mdp = chunk.get("max_dynamic_patch", None)
+            # Keep structured info for backend, but template only sees {"type":"video"}
+            video_data.append(
+                {
+                    "url": video_obj["url"],
+                    "max_dynamic_patch": mdp,
+                }
+            )
             if chunk.get("modalities"):
                 modalities.append(chunk.get("modalities"))
             # Normalize to simple 'video' type for template compatibility
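Both spellings of the override are accepted here: nested inside `image_url`/`video_url`, or flat on the chunk. A quick sketch of the nested-then-flat lookup (the helper name and sample chunks are illustrative):

```python
# Sketch of the nested-then-flat lookup for max_dynamic_patch.
def extract_mdp(chunk: dict):
    obj = chunk.get("image_url") or {}
    mdp = obj.get("max_dynamic_patch")
    if mdp is None:  # fall back to the flat spelling
        mdp = chunk.get("max_dynamic_patch")
    return mdp

nested = {"type": "image_url", "image_url": {"url": "u", "max_dynamic_patch": 6}}
flat = {"type": "image_url", "image_url": {"url": "u"}, "max_dynamic_patch": 4}
assert extract_mdp(nested) == 6
assert extract_mdp(flat) == 4
```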
1 change: 1 addition & 0 deletions python/sglang/srt/utils/common.py
@@ -838,6 +838,7 @@ def load_audio(
 class ImageData:
     url: str
     detail: Optional[Literal["auto", "low", "high"]] = "auto"
+    max_dynamic_patch: Optional[int] = None
 
 
 def load_image(