
Commit 022012a

Support Phi-4 Multi-Modal (text + vision only) (#6494)
1 parent 681e7af commit 022012a

8 files changed: +650 −6 lines

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -228,5 +228,8 @@ compile_commands.json
 
 1
 
+# Autoenv
+.env.leave
+
 # Rust lib
 Cargo.lock

python/pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
 
 [project.optional-dependencies]
 runtime_common = [
+    "blobfile==3.0.0",
     "compressed-tensors",
     "datasets",
     "fastapi",
@@ -38,12 +39,12 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
+    "scipy",
     "torchao==0.9.0",
     "transformers==4.51.1",
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.19",
-    "blobfile==3.0.0"
 ]
 
 srt = [

python/sglang/srt/configs/model_config.py

Lines changed: 1 addition & 0 deletions
@@ -552,6 +552,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
+    "Phi4MMForCausalLM",
 ]
 
 
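For context (not part of the diff): SGLang matches the `architectures` field of a checkpoint's Hugging Face config against lists like the one extended above to decide whether it is a multimodal generation model. A minimal sketch of that kind of check; the model id `microsoft/Phi-4-multimodal-instruct` and the need for `trust_remote_code` are assumptions, not taken from the commit:

    from transformers import AutoConfig

    # Hypothetical lookup mirroring the list above.
    cfg = AutoConfig.from_pretrained(
        "microsoft/Phi-4-multimodal-instruct", trust_remote_code=True
    )
    print(cfg.architectures)  # expected to contain "Phi4MMForCausalLM"

    multimodal_gen_archs = {
        "Qwen2_5_VLForConditionalGeneration",
        "KimiVLForConditionalGeneration",
        "InternVLChatModel",
        "Phi4MMForCausalLM",
    }
    print(any(a in multimodal_gen_archs for a in cfg.architectures))  # True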
python/sglang/srt/conversation.py

Lines changed: 20 additions & 0 deletions
@@ -661,6 +661,20 @@ def generate_chat_conv(
     )
 )
 
+# TODO (lifuhuang): Refactor BaseMultimodalProcessor to support the default image token "<|image_{index}|>" in the future.
+register_conv_template(
+    Conversation(
+        name="phi-4-mm",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|system|>{system_message}<|end|>",
+        roles=("<|user|>", "<|assistant|>"),
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        sep="<|end|>",
+        stop_str="<|end|>",
+        image_token="<|endoftext10|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="chatml",
@@ -945,3 +959,9 @@ def match_openbmb_minicpm(model_path: str):
 def match_moonshot_kimivl(model_path: str):
     if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
         return "kimi-vl"
+
+
+@register_conv_template_matching_function
+def match_phi_4_mm(model_path: str):
+    if "phi-4-multimodal" in model_path.lower():
+        return "phi-4-mm"
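To make the template concrete, here is a rough illustration (not from the commit) of the prompt shape it produces for a single-image turn. The actual string is assembled by the Conversation machinery in sglang.srt.conversation and the image token is expanded by the multimodal processor, so treat this only as an approximation of the NO_COLON_SINGLE layout:

    system = (
        "You are a helpful language and vision assistant. You are able to understand "
        "the visual content that the user provides, and assist the user with a "
        "variety of tasks using natural language."
    )
    user_msg = "<|endoftext10|>What is in this image?"  # image placeholder + question

    prompt = (
        f"<|system|>{system}<|end|>"  # system_template
        f"<|user|>{user_msg}<|end|>"  # role tag + message + sep
        "<|assistant|>"               # generation continues after the assistant tag
    )
    print(prompt)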
Lines changed: 87 additions & 0 deletions (new file)
@@ -0,0 +1,87 @@
+import logging
+from typing import List, Union
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.phi4mmvllm import Phi4MMForCausalLM
+
+logger = logging.getLogger(__name__)
+
+_IMAGE_SPECIAL_TOKEN = "<|endoftext10|>"
+_IMAGE_SPECIAL_TOKEN_ID = 200010
+
+
+class Phi4MMImageProcessor(BaseMultimodalProcessor):
+    models = [Phi4MMForCausalLM]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_IMAGE_SPECIAL_TOKEN,
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        audio_data = request_obj.audio_data
+
+        if not image_data and not audio_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        if audio_data:
+            logger.warning(
+                "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
+            )
+            audio_data = []
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            image_data=image_data,
+            multimodal_tokens=self.multimodal_tokens,
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+            audios=base_output.audios,
+        )
+
+        input_ids = res["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids,
+            mm_token_id=_IMAGE_SPECIAL_TOKEN_ID,
+        )
+
+        items = [
+            MultimodalDataItem(
+                pixel_values=res["input_image_embeds"],
+                image_sizes=res["image_sizes"],
+                image_emb_mask=res["image_attention_mask"],
+                image_offsets=image_offsets,
+                modality=Modality.IMAGE,
+            )
+        ]
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "im_token_id": _IMAGE_SPECIAL_TOKEN_ID,
+        }
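The processor relies on get_mm_items_offset to find where the image placeholder tokens sit in the tokenized prompt. As a hedged illustration only (the real helper in BaseMultimodalProcessor may use a different algorithm and return format), locating the contiguous runs of the image token id 200010 could look like this:

    import torch

    def find_token_runs(input_ids: torch.Tensor, token_id: int):
        """Return (start, end) index pairs (inclusive) of consecutive token_id runs."""
        positions = (input_ids == token_id).nonzero(as_tuple=True)[0].tolist()
        runs = []
        for pos in positions:
            if runs and pos == runs[-1][1] + 1:
                runs[-1] = (runs[-1][0], pos)   # extend the current run
            else:
                runs.append((pos, pos))         # start a new run
        return runs

    ids = torch.tensor([1, 200010, 200010, 200010, 5, 6, 200010, 7])
    print(find_token_runs(ids, 200010))  # [(1, 3), (6, 6)]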

python/sglang/srt/models/minicpmv.py

Lines changed: 22 additions & 5 deletions
@@ -20,6 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
+
 from functools import partial
 from typing import (
     Any,
@@ -386,6 +387,7 @@ def __init__(
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        require_post_norm: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -398,20 +400,35 @@ def __init__(
             quant_config=quant_config,
             prefix=add_prefix("encoder", prefix),
         )
-        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.post_layernorm = (
+            nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+            if require_post_norm
+            else nn.Identity()
+        )
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings
 
-    def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
-        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # shape: (batch_size,)
+    def compute_cu_seqlens(
+        self,
+        tgt_sizes: Optional[torch.Tensor] = None,
+        atch_attention_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:
+        # shape: (batch_size,)
+        if tgt_sizes is not None:
+            patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+        else:
+            patch_len = atch_attention_mask[:, :, 0].sum(dim=1) * atch_attention_mask[
+                :, 0, :
+            ].sum(dim=1)
+
         cu_seqlens = torch.cat(
             [
                 torch.tensor([0], device=patch_len.device, dtype=torch.int32),
                 torch.cumsum(patch_len, dim=0, dtype=torch.int32),
             ],
             dim=0,
-        ).to(tgt_sizes.device)
+        ).to(patch_len.device)
         return cu_seqlens
 
     def forward(
@@ -425,7 +442,7 @@ def forward(
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
+        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask)
         encoder_outputs = self.encoder(
             hidden_states,
             cu_seqlens=cu_seqlens,
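The compute_cu_seqlens change above accepts either per-image patch grid sizes or a boolean patch attention mask. A small self-contained check (an illustration, not code from the commit) that both paths give the same per-image patch counts and cumulative sequence lengths, assuming the mask's True region is the top-left tgt_sizes rectangle:

    import torch

    tgt_sizes = torch.tensor([[3, 4], [2, 5]])     # (batch, 2): patch grid (h, w) per image
    mask = torch.zeros(2, 6, 6, dtype=torch.bool)  # (batch, max_h, max_w) patch attention mask
    mask[0, :3, :4] = True
    mask[1, :2, :5] = True

    patch_len_a = tgt_sizes[:, 0] * tgt_sizes[:, 1]
    patch_len_b = mask[:, :, 0].sum(dim=1) * mask[:, 0, :].sum(dim=1)
    assert torch.equal(patch_len_a, patch_len_b)   # tensor([12, 10]) from both paths

    cu_seqlens = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(patch_len_a, dim=0, dtype=torch.int32),
        ]
    )
    print(cu_seqlens)  # tensor([ 0, 12, 22], dtype=torch.int32)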
