Commit b2531e3

Support for LlavaForConditionalGeneration models (#31)

1 parent c538cfa commit b2531e3

5 files changed

Lines changed: 322 additions & 2 deletions

examples/optimum/run_pixtral.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio

import fire
from datasets import load_dataset
from transformers import AutoProcessor, AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


def generate_prompts(batch_size: int, model_id: str):
    # Pair each ChartQA question with its chart image in chat format.
    dataset = load_dataset("HuggingFaceM4/ChartQA",
                           split="train").shuffle(seed=42)
    processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
    messages = [[
        {
            "role": "user",
            "content": [
                {
                    "type": "image"
                },
                {
                    "type": "text",
                    "text": dataset[i]["query"],
                },
            ],
        },
    ] for i in range(batch_size)]
    images = [[dataset[i]["image"]] for i in range(batch_size)]
    texts = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    inputs = [{
        "prompt": text,
        "multi_modal_data": {
            "image": image
        }
    } for text, image in zip(texts, images)]
    labels = [dataset[i]["label"] for i in range(batch_size)]
    return inputs, labels


async def generate(engine: AsyncLLMEngine, tokenizer, request_id, request):
    results_generator = engine.generate(
        request,
        SamplingParams(temperature=0,
                       ignore_eos=False,
                       skip_special_tokens=True,
                       stop_token_ids=[tokenizer.eos_token_id],
                       max_tokens=500),
        str(request_id),
    )

    # Drain the stream; the final item carries the completed output.
    final_output = None
    async for request_output in results_generator:
        final_output = request_output
    return final_output


async def main(
    batch_size: int,
    max_seq_len: int,
    kvcache_partition_len: int,
    num_input_prompt: int,
    model_id: str,
):
    # block_size maps to the KV-cache partition length on RBLN devices.
    engine_args = AsyncEngineArgs(model=model_id,
                                  device="auto",
                                  max_num_seqs=batch_size,
                                  max_num_batched_tokens=max_seq_len,
                                  max_model_len=max_seq_len,
                                  block_size=kvcache_partition_len)

    engine = AsyncLLMEngine.from_engine_args(engine_args)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs, labels = generate_prompts(num_input_prompt, model_id)

    futures = []
    for request_id, request in enumerate(inputs):
        futures.append(
            asyncio.create_task(
                generate(engine, tokenizer, request_id, request)))

    results = await asyncio.gather(*futures)
    for i, (result, label) in enumerate(zip(results, labels)):
        label_str = str(label)
        output = result.outputs[0].text

        print("=" * 80)
        print(f"[{i}] Label:")
        print(f"{label_str}\n")
        print(f"[{i}] Model Output:")
        print(output)
        print("=" * 80 + "\n")


def entry_point(
    batch_size: int = 4,
    max_seq_len: int = 131072,
    kvcache_partition_len: int = 16384,
    num_input_prompt: int = 4,
    model_id: str = "/pixtral-12b-b4",
):
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        main(
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            kvcache_partition_len=kvcache_partition_len,
            num_input_prompt=num_input_prompt,
            model_id=model_id,
        ))


if __name__ == "__main__":
    fire.Fire(entry_point)
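
Because the script hands `entry_point` to `fire.Fire`, every keyword argument doubles as a command-line flag (for example `--batch_size` or `--model_id`). A minimal programmatic sketch of the same invocation, assuming the script is importable as `run_pixtral` and that the placeholder path points at a model compiled for RBLN devices (neither path nor import is part of this commit):

import asyncio

from run_pixtral import main

# Hypothetical driver; "/path/to/pixtral-12b-compiled" is a placeholder,
# not a path taken from the commit.
asyncio.run(
    main(
        batch_size=4,
        max_seq_len=131072,
        kvcache_partition_len=16384,
        num_input_prompt=4,
        model_id="/path/to/pixtral-12b-compiled",
    ))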

vllm_rbln/model_executor/models/optimum/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 from .encoder_decoder import RBLNOptimumEncoderDecoder
 from .gemma3 import RBLNOptimumGemma3ForConditionalGeneration  # noqa: F401
 from .idefics3 import RBLNOptimumIdefics3ForConditionalGeneration  # noqa: F401
+from .llava import RBLNOptimumLlavaForConditionalGeneration  # noqa: F401
 from .llava_next import (  # noqa: F401
     RBLNOptimumLlavaNextForConditionalGeneration)
 from .model_base import RBLNOptimumDictTableMixin

vllm_rbln/model_executor/models/optimum/base.py

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@
                                        "RBLNGemma3ForConditionalGeneration"),
     "WhisperForConditionalGeneration": ("whisper",
                                         "RBLNWhisperForConditionalGeneration"),
+    "LlavaForConditionalGeneration": ("llava",
+                                      "RBLNLlavaForConditionalGeneration"),
 }

 _RBLN_EMBEDDING_MODELS = {
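
The registry maps a Hugging Face architecture name to a (module, class-name) pair rather than to the class itself, which keeps the per-model imports lazy. The resolver is not shown in this diff; a minimal sketch of the usual pattern, where `resolve_rbln_model_cls` and the dict name are illustrative assumptions:

import importlib

# Illustrative registry entry mirroring the addition above; the real dict
# in base.py may use a different name.
_RBLN_SUPPORTED_ARCHS = {
    "LlavaForConditionalGeneration": ("llava",
                                      "RBLNLlavaForConditionalGeneration"),
}


def resolve_rbln_model_cls(architecture: str):
    # Import the submodule on demand and pull the wrapper class from it.
    module_name, class_name = _RBLN_SUPPORTED_ARCHS[architecture]
    module = importlib.import_module(
        f"vllm_rbln.model_executor.models.optimum.{module_name}")
    return getattr(module, class_name)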
vllm_rbln/model_executor/models/optimum/llava.py

Lines changed: 185 additions & 0 deletions

@@ -0,0 +1,185 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union

import torch
import vllm.envs as env
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.llava import (LlavaImageEmbeddingInputs,
                                              LlavaImageInputs,
                                              LlavaImagePixelInputs,
                                              PixtralHFImagePixelInputs)
from vllm.model_executor.models.utils import flatten_bn

from .base import ModelInputForRBLN, version_error
from .model_base import RBLNOptimumDecoderMixin, RBLNOptimumModelBase

logger = init_logger(__name__)


class RBLNOptimumLlavaForConditionalGeneration(RBLNOptimumModelBase,
                                               RBLNOptimumDecoderMixin):

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        super().__init__(vllm_config=vllm_config)
        self.setup_decoder_mixin(
            attn_impl=self.attn_impl,
            padding_value=self.padding_value,
            vocab_size=self.model_config.get_vocab_size,
            use_multiple_decoder=getattr(
                self.model.rbln_config.language_model,
                "use_multiple_decoder", False),
            default_batch_size=self.scheduler_config.max_num_seqs,
            decoder_batch_sizes=self.model.rbln_config.language_model.
            decoder_batch_sizes,
        )

    def _forward(
        self,
        is_prefill: bool,
        block_tables: torch.Tensor,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        cache_position: Union[List[torch.Tensor],
                              torch.Tensor] = None,  # vllm keyword argument
        **kwargs,
    ):
        if is_prefill:
            # Vision features are merged into the text embeddings before
            # the prefill decoder runs.
            inputs_embeds = self.model._preprocess_prefill(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
            )
            if self.model.language_model.prefill_decoder is None:
                raise version_error

            logits = self.model.language_model.prefill_decoder(
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                block_tables=block_tables,
            ).logits
        else:
            if self.model.language_model.decoder is None:
                raise version_error

            logits = self.model.language_model.decoder(
                input_ids=input_ids,
                cache_position=cache_position,
                block_tables=block_tables,
            ).logits

        return logits

    def forward(self, model_input: ModelInputForRBLN,
                **kwargs) -> torch.Tensor:
        input_ids = model_input.input_tokens
        cache_position = model_input.input_positions
        block_tables = model_input.block_tables

        if env.VLLM_USE_V1:
            is_prompt = model_input.is_prompt
        else:
            is_prompt = model_input.sampling_metadata.num_prompts > 0

        request_nums = input_ids.shape[0]
        # Default to no image inputs; decode steps and text-only prompts
        # carry no multi-modal kwargs.
        pixel_values = None
        image_sizes = None
        if model_input.multi_modal_kwargs:
            image_input = self._parse_and_validate_image_input(
                **model_input.multi_modal_kwargs)
            if image_input is not None:
                if image_input["type"] == "pixel_values":
                    pixel_values = image_input["pixel_values"]
                elif image_input["type"] == "pixel_values_pixtral":
                    pixel_values = image_input["pixel_values"]
                    # Pixtral images keep their native resolution, so the
                    # (height, width) pair is recovered from the tensor shape.
                    image_sizes = torch.tensor(
                        pixel_values.shape[-2:]).unsqueeze(0)

        kwargs = self.preprocess_for_decoder(
            is_prompt,
            block_tables,
            input_ids,
            cache_position,
        )
        input_ids = kwargs.pop("input_ids")
        cache_position = kwargs.pop("cache_position")
        block_tables = kwargs.pop("block_tables")
        if not is_prompt:
            padded_batch_size = kwargs.pop("padded_batch_size",
                                           self.decoder_batch_size)
            self.model.language_model.decoder = \
                self.model.language_model.decoders[padded_batch_size]

        logits = self._forward(
            is_prefill=is_prompt,
            block_tables=block_tables,
            input_ids=input_ids,
            cache_position=cache_position,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
        )

        # Decode batches are padded to a compiled size; drop the padding.
        if not is_prompt:
            logits = logits[:request_nums]
        return logits

    def _parse_and_validate_image_input(
            self, **kwargs: Any) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Pixtral
            if hasattr(self.model.rbln_config.vision_tower, "max_image_size"):
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values, concat=True),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")
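
The parser distinguishes two pixel-value layouts: standard LLaVA towers consume fixed-size image batches, while Pixtral-HF checkpoints (detected here via the `max_image_size` attribute on the vision-tower config) keep each image at its native resolution, so `forward` rebuilds `image_sizes` from the tensor shape. A small sketch of that derivation with dummy tensors (shapes chosen for illustration only):

import torch

# Standard LLaVA: fixed-size batch, (num_images, channels, height, width).
llava_pixels = torch.randn(2, 3, 336, 336)

# Pixtral-HF: native resolution, so (height, width) is recovered from the
# last two dimensions, exactly as in the prefill path above.
pixtral_pixels = torch.randn(3, 512, 768)
image_sizes = torch.tensor(pixtral_pixels.shape[-2:]).unsqueeze(0)
print(image_sizes)  # tensor([[512, 768]])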

vllm_rbln/model_executor/models/optimum/llava_next.py

Lines changed: 2 additions & 2 deletions
@@ -139,7 +139,7 @@ def forward(self, model_input: ModelInputForRBLN,
                 **model_input.multi_modal_kwargs)
             if image_input is not None:
                 assert image_input["type"] == "pixel_values"
-                pixel_values = image_input["data"]
+                pixel_values = image_input["pixel_values"]
                 image_sizes = image_input["image_sizes"]
             else:
                 pixel_values = None
@@ -193,7 +193,7 @@ def _parse_and_validate_image_input(

         return LlavaNextImagePixelInputs(
             type="pixel_values",
-            data=flatten_bn(pixel_values),
+            pixel_values=flatten_bn(pixel_values),
             image_sizes=flatten_bn(image_sizes),
         )