Commit a802389

model: support NVILA and NVILA Lite (sgl-project#10399)
1 parent 0103f37 commit a802389

8 files changed

Lines changed: 581 additions & 334 deletions

docs/supported_models/multimodal_language_models.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -44,6 +44,7 @@ in the GitHub search bar.
 | **GLM-4.5V** (106B) / **GLM-4.1V** (9B) | `zai-org/GLM-4.5V` | GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning | Use `--chat-template glm-4v` |
 | **DotsVLM** (General/OCR) | `rednote-hilab/dots.vlm1.inst` | RedNote's vision-language model built on a 1.2B vision encoder and the DeepSeek V3 LLM, featuring a NaViT vision encoder trained from scratch with dynamic-resolution support and enhanced OCR capabilities through structured image-data training. | |
 | **DotsVLM-OCR** | `rednote-hilab/dots.ocr` | Specialized OCR variant of DotsVLM optimized for optical character recognition tasks with enhanced text extraction and document understanding capabilities. | Don't use `--trust-remote-code` |
+| **NVILA** (8B, 15B, Lite-2B, Lite-8B, Lite-15B) | `Efficient-Large-Model/NVILA-8B` | NVILA explores the full-stack efficiency of multimodal model design, achieving cheaper training, faster deployment, and better performance. | Use `--chat-template chatml` |

 ## Usage Notes
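```

To try the newly supported model end to end, a minimal offline sketch follows. It assumes sglang's documented `sgl.Engine` offline API; the prompt, image URL, and sampling parameters are illustrative placeholders and are not part of this commit.

```python
# Sketch under assumptions: sglang's offline Engine API; the prompt, image URL,
# and sampling parameters below are illustrative placeholders.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="Efficient-Large-Model/NVILA-8B",
        chat_template="chatml",  # per the docs row added above
    )
    out = llm.generate(
        prompt="<image>\nDescribe this image.",
        image_data="https://example.com/cat.jpg",  # placeholder URL
        sampling_params={"max_new_tokens": 64},
    )
    print(out["text"])
    llm.shutdown()
```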

python/sglang/srt/configs/model_config.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -914,12 +914,13 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
     "Phi4MMForCausalLM",
-    "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "POINTSV15ChatModel",
     "DotsVLMForCausalLM",
     "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
+    "NVILAForConditionalGeneration",
+    "NVILALiteForConditionalGeneration",
     "DeepseekOCRForCausalLM",
 ]
```

python/sglang/srt/models/nvila.py

Lines changed: 355 additions & 0 deletions
A new file; since every line is an addition (`@@ -0,0 +1,355 @@`), it is shown below as plain Python. First, the imports, the projector input width, and a composite config pairing a Qwen2 text config with a SigLIP vision config:

```python
import itertools
import math
from collections.abc import Iterable
from typing import Any

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel

import sglang.srt.managers.mm_utils as mm_utils
import sglang.srt.model_loader.weight_utils as weight_utils
import sglang.srt.utils as utils
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM

# Width of the projector input: the S2 path below stacks len(scales) = 3 copies
# of the vision hidden size channel-wise (3 * 1152).
MM_HIDDEN_SIZE = 3456


class NVILAConfig(PretrainedConfig):
    model_type = "nvila"
    sub_configs = {
        "text_config": Qwen2Config,
        "vision_config": SiglipVisionConfig,
    }
    _auto_class = "AutoConfig"

    def __init__(
        self,
        *,
        text_config: dict[str, Any] | None = None,
        vision_config: dict[str, Any] | None = None,
        image_token_id: int | None = None,
        video_token_id: int | None = None,
        **kwargs,
    ):
        self.text_config = (
            Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
        )
        self.vision_config = (
            SiglipVisionConfig(**vision_config)
            if vision_config is not None
            else SiglipVisionConfig()
        )

        self.image_token_id = image_token_id if image_token_id is not None else -1
        self.video_token_id = video_token_id if video_token_id is not None else -1

        super().__init__(**kwargs)
```
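As a quick illustration of how the composite config behaves (a sketch: it assumes the imports and `NVILAConfig` above are in scope, and the field values and token ids below are made up, not NVILA's real ones):

```python
# Illustration only: these field values and token ids are made up.
config = NVILAConfig(
    text_config={"hidden_size": 3584},    # forwarded to Qwen2Config
    vision_config={"hidden_size": 1152},  # forwarded to SiglipVisionConfig
    image_token_id=151649,
    video_token_id=151650,
)
assert isinstance(config.text_config, Qwen2Config)
assert isinstance(config.vision_config, SiglipVisionConfig)
assert config.model_type == "nvila"
```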
Next, the multimodal projector: a 2x2 spatial downsample that folds each token neighborhood into channels, followed by LayerNorm and a two-layer MLP into the LLM hidden size:

```python
class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        batch_size, sequence_length, hidden_size = x.shape

        feat_size = math.isqrt(sequence_length)

        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)

        pad_after = feat_size % 2
        if pad_after > 0:
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size = feat_size + pad_after

        features = features.reshape(
            batch_size, feat_size // 2, 2, feat_size // 2, 2, hidden_size
        )
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        features = features.reshape(batch_size, -1, 4 * hidden_size)

        return features


class NVILAMultiModalProjector(nn.Module):
    def __init__(self, config: NVILAConfig):
        super().__init__()

        self.layers = nn.Sequential(
            NVILAMultiModalProjectorDownsampleBlock(),
            nn.LayerNorm(MM_HIDDEN_SIZE * 4),
            nn.Linear(MM_HIDDEN_SIZE * 4, config.text_config.hidden_size),
            nn.GELU(),
            nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)
```
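A shape check for the downsample block (a sketch with arbitrary sizes; it assumes the class above is in scope). An odd 27x27 token grid is padded to 28x28, then each 2x2 neighborhood is folded into channels, quartering the token count:

```python
import torch

block = NVILAMultiModalProjectorDownsampleBlock()
x = torch.randn(2, 729, 64)  # batch=2, 27*27 tokens, 64 channels (arbitrary)
y = block(x)
# 27 is odd, so the grid is padded to 28x28 -> (28/2)^2 = 196 tokens of 4*64 channels.
assert y.shape == (2, 196, 256)
```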
The model itself wires the SigLIP tower, the projector, and a Qwen2 LLM together, and implements the multi-scale (dynamic S2) image-feature path:

```python
class NVILAForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: NVILAConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.mm_projector = NVILAMultiModalProjector(config)
        self.llm = Qwen2ForCausalLM(
            config=config.text_config,
            quant_config=quant_config,
            prefix=utils.add_prefix("llm", prefix),
        )

    def forward(
        self,
        input_ids: Tensor,
        positions: Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        output = mm_utils.general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
                Modality.VIDEO: self.get_image_feature,
            },
            get_embedding=get_embedding,
            positions=positions,
        )

        assert isinstance(output, LogitsProcessorOutput)

        return output

    def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
        block_sizes = (
            list(
                itertools.chain.from_iterable(
                    x.block_sizes for x in mm_input if hasattr(x, "block_sizes")
                )
            )
            or None
        )
        pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)

        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
            pixel_values.to(
                device=self.vision_tower.device, dtype=self.vision_tower.dtype
            ),
            output_hidden_states=True,
        )
        assert vision_tower_output.hidden_states is not None

        # Take the penultimate encoder layer's hidden states as vision features.
        vision_features: Tensor = vision_tower_output.hidden_states[-2]

        # Merge per-crop features into one multi-scale feature map per image.
        vision_features_list, block_sizes = merge_features_for_dynamic_s2(
            vision_features,
            block_sizes=(
                block_sizes
                if block_sizes is not None
                else [None] * vision_features.shape[0]
            ),
            resize_output_to_scale_idx=-1,
            scales=[448, 896, 1344],
        )

        # Re-split into tiles so the projector sees fixed-size token grids.
        vision_features_list = [
            split_chessboard(x, block_size[0], block_size[1])
            for x, block_size in zip(vision_features_list, block_sizes)
        ]

        vision_features = torch.cat(
            [einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list]
        )

        vision_features = self.mm_projector(vision_features)

        # Reassemble the projected tiles back into one feature map per image.
        vision_features_list = list(
            vision_features.split(
                [block_size[0] * block_size[1] for block_size in block_sizes], dim=0
            )
        )
        vision_features_list = [
            merge_chessboard(x, block_size[0], block_size[1])
            for x, block_size in zip(vision_features_list, block_sizes)
        ]

        vision_features = torch.stack(
            [einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list]
        )

        vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")

        return vision_features

    def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if name.startswith("llm."):
                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
            else:
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", weight_utils.default_weight_loader
                )
                weight_loader(param, loaded_weight)

    def pad_input_ids(
        self, input_ids: list[int], mm_inputs: MultimodalInputs
    ) -> list[int]:
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)
```
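The dynamic-S2 bookkeeping is easiest to see with concrete numbers. For `scales=[448, 896, 1344]`, one image tiled at the largest scale consumes 1 crop at 448, 4 at 896, and then `h * w` crops at 1344 from the flattened batch (a sketch; the `(3, 3)` tiling below is hypothetical):

```python
# Mirrors the block-count bookkeeping in merge_features_for_dynamic_s2 below.
scales = [448, 896, 1344]
block_size = (3, 3)  # hypothetical h x w tiling at the largest scale

blocks = [(s // scales[0]) ** 2 for s in scales[:-1]]  # [1, 4] for the lower scales
blocks.append(block_size[0] * block_size[1])           # 9 at the last scale
assert blocks == [1, 4, 9] and sum(blocks) == 14       # 14 encoder crops in total
```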
Finally, the chessboard split/merge helpers used by the S2 path, and the entry-point registration:

```python
def merge_chessboard(x, num_split_h, num_split_w):
    """
    x: b * n * c (a flattened square token grid) or b * c * h * w
    out: b * c * h * w
    Assuming x contains num_split_h * num_split_w sub-squares concatenated along
    the batch dimension, merge the sub-squares back into the original whole square.
    """
    B = x.shape[0]
    if x.dim() == 3:
        N = x.shape[1]
        x = einops.rearrange(
            x, "b (h w) c -> b c h w", h=math.isqrt(N), w=math.isqrt(N)
        )

    assert B % (num_split_h * num_split_w) == 0
    b = B // (num_split_h * num_split_w)

    x_merge = torch.cat(
        [
            torch.cat(
                [
                    x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b]
                    for j in range(num_split_w)
                ],
                dim=-1,
            )
            for i in range(num_split_h)
        ],
        dim=-2,
    )

    return x_merge


def merge_features_for_dynamic_s2(
    image_features, block_sizes, *, scales, resize_output_to_scale_idx
):
    image_features_each_image = []
    new_block_sizes = []
    block_cnt = 0
    for block_size_each_image in block_sizes:
        if block_size_each_image is None:
            # Single-crop image: replicate its features across all scales.
            cur_features = image_features[block_cnt : block_cnt + 1]
            cur_features = einops.rearrange(
                cur_features,
                "1 (h w) c -> 1 c h w",
                h=math.isqrt(cur_features.shape[1]),
            )
            cur_features = cur_features.repeat(1, len(scales), 1, 1)
            image_features_each_image.append(cur_features)
            new_block_sizes.append((1, 1))
            block_cnt += 1
        else:
            cur_features_each_scale = []
            for scale in scales[:-1]:
                num_blocks_this_scale = (scale // scales[0]) ** 2
                cur_features_each_scale.append(
                    merge_chessboard(
                        image_features[block_cnt : block_cnt + num_blocks_this_scale],
                        num_split_h=scale // scales[0],
                        num_split_w=scale // scales[0],
                    )
                )  # 1 * C * H * W
                block_cnt += num_blocks_this_scale
            num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
            cur_features_each_scale.append(
                merge_chessboard(
                    image_features[block_cnt : block_cnt + num_blocks_last_scale],
                    num_split_h=block_size_each_image[0],
                    num_split_w=block_size_each_image[1],
                )
            )  # 1 * C * H * W
            block_cnt += num_blocks_last_scale

            # Resize and concatenate features from different scales.
            output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
            cur_features = torch.cat(
                [
                    F.interpolate(
                        cur_features_each_scale[i].to(torch.float32),
                        size=output_size,
                        mode="area",
                    ).to(cur_features_each_scale[i].dtype)
                    for i in range(len(cur_features_each_scale))
                ],
                dim=1,
            )

            image_features_each_image.append(cur_features)

            if (
                resize_output_to_scale_idx == len(scales) - 1
                or resize_output_to_scale_idx == -1
            ):
                new_block_sizes.append(block_size_each_image)
            else:
                new_block_sizes.append(
                    (
                        scales[resize_output_to_scale_idx] // scales[0],
                        scales[resize_output_to_scale_idx] // scales[0],
                    )
                )

    assert block_cnt == len(
        image_features
    ), f"The number of blocks ({block_cnt}) does not match the length of image_features ({len(image_features)})!"

    return image_features_each_image, new_block_sizes


def split_chessboard(x, num_split_h, num_split_w):
    """
    x: b * c * h * w
    out: (b * num_split_h * num_split_w) * c * (h / num_split_h) * (w / num_split_w)
    Divide x into num_split_h * num_split_w sub-squares and concatenate them
    along the batch dimension.
    """
    B, C, H, W = x.shape
    assert H % num_split_h == 0 and W % num_split_w == 0
    h, w = H // num_split_h, W // num_split_w
    x_split = torch.cat(
        [
            x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w]
            for i in range(num_split_h)
            for j in range(num_split_w)
        ],
        dim=0,
    )
    return x_split


EntryClass = [NVILAForConditionalGeneration]
```
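A round-trip sanity check for the two chessboard helpers (a sketch; it assumes the functions above are in scope and uses arbitrary sizes):

```python
import torch

x = torch.randn(1, 8, 6, 4)  # b * c * h * w, divisible by the split factors
tiles = split_chessboard(x, num_split_h=3, num_split_w=2)
assert tiles.shape == (6, 8, 2, 2)  # 3*2 tiles stacked along the batch dim
merged = merge_chessboard(tiles, num_split_h=3, num_split_w=2)
assert torch.equal(merged, x)  # split followed by merge restores the input
```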
