Skip to content

Commit a1235ee

Browse files
authored
Support for Whisper model in vLLM (#6)
1 parent cc81794 commit a1235ee

7 files changed

Lines changed: 347 additions & 43 deletions

File tree

examples/optimum/run_whisper.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2025 Rebellions Inc. All rights reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import asyncio
15+
16+
import fire
17+
from datasets import load_dataset
18+
from transformers import AutoTokenizer
19+
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
20+
21+
22+
def generate_prompts(batch_size: int, model_id: str):
    """Build Whisper multi-modal prompts from a noisy LibriSpeech dataset.

    Args:
        batch_size: Number of prompts (one audio sample each) to build.
        model_id: Currently unused; kept for interface stability with callers.

    Returns:
        A list of vLLM prompt dicts, each with the Whisper decoder start
        token and one audio waveform paired with its sampling rate.
    """
    # NOTE(review): split="40" is unusual for `datasets` — a split name
    # ("test") or slice ("test[:40]") is typically expected; confirm intent.
    dataset = load_dataset("distil-whisper/librispeech_asr-noise",
                           "test-pub-noise",
                           split="40")

    messages = [{
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {
            "audio": (dataset[i]["audio"]["array"],
                      dataset[i]["audio"]["sampling_rate"])
        },
    } for i in range(batch_size)]

    return messages
37+
38+
async def generate(engine: AsyncLLMEngine, tokenizer, request_id, request):
    """Submit one request to the async engine and return its final output.

    Args:
        engine: Running vLLM async engine.
        tokenizer: Tokenizer providing ``eos_token_id`` for the stop list.
        request_id: Unique ID for this request within the engine.
        request: vLLM prompt dict (text prompt + multi-modal audio data).

    Returns:
        The last ``RequestOutput`` streamed by the engine (the complete
        generation), or ``None`` if the generator yields nothing.
    """
    results_generator = engine.generate(
        request,
        SamplingParams(temperature=0,
                       ignore_eos=False,
                       skip_special_tokens=True,
                       stop_token_ids=[tokenizer.eos_token_id],
                       max_tokens=448),
        request_id,
    )

    # The engine streams incremental outputs; only the final one matters here.
    final_output = None
    async for request_output in results_generator:
        final_output = request_output
    return final_output
54+
55+
async def main(
    batch_size: int,
    max_seq_len: int,
    num_input_prompt: int,
    model_id: str,
):
    """Transcribe audio prompts concurrently through a vLLM async engine.

    Args:
        batch_size: Maximum number of sequences batched by the engine.
        max_seq_len: Token budget used for model length, batched tokens,
            and block size alike.
        num_input_prompt: Number of audio prompts to submit.
        model_id: Model path or hub ID for both the engine and tokenizer.
    """
    engine_args = AsyncEngineArgs(model=model_id,
                                  device="auto",
                                  max_num_seqs=batch_size,
                                  max_num_batched_tokens=max_seq_len,
                                  max_model_len=max_seq_len,
                                  block_size=max_seq_len,
                                  limit_mm_per_prompt={"audio": 1})

    engine = AsyncLLMEngine.from_engine_args(engine_args)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = generate_prompts(num_input_prompt, model_id)

    # Fan out all requests as concurrent tasks; the engine batches internally.
    futures = []
    for request_id, request in enumerate(inputs):
        futures.append(
            asyncio.create_task(
                generate(engine, tokenizer, request_id, request)))

    results = await asyncio.gather(*futures)

    for i, result in enumerate(results):
        output = result.outputs[0].text
        print(
            f"===================== Output {i} ==============================")
        print(output)
        print(
            "===============================================================\n"
        )
90+
91+
def entry_point(
    batch_size: int = 4,
    max_seq_len: int = 448,
    num_input_prompt: int = 1,
    model_id: str = "/whisper-base-b4-wo-token-timestamps",
):
    """CLI entry point: run the async transcription demo to completion.

    Args:
        batch_size: Engine batch size (max concurrent sequences).
        max_seq_len: Sequence-length budget forwarded to the engine.
        num_input_prompt: Number of audio prompts to transcribe.
        model_id: Compiled Whisper model path or hub ID.
    """
    # asyncio.run() replaces the get_event_loop()/run_until_complete()
    # pattern, whose implicit loop creation is deprecated since Python 3.10.
    asyncio.run(
        main(
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            num_input_prompt=num_input_prompt,
            model_id=model_id,
        ))


if __name__ == "__main__":
    fire.Fire(entry_point)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ datasets
55
qwen_vl_utils
66
transformers==4.51.3
77
vllm==0.9.1
8-
optimum-rbln>=0.8.1
8+
optimum-rbln>=0.8.2a4

vllm_rbln/model_executor/models/optimum/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RBLNOptimumLlavaNextForConditionalGeneration)
3131
from .qwen2_5_vl import ( # noqa: F401
3232
RBLNOptimumQwen2_5_VLForConditionalGeneration)
33+
from .whisper import RBLNOptimumWhisperForConditionalGeneration # noqa: F401
3334

3435
logger = init_logger(__name__)
3536

vllm_rbln/model_executor/models/optimum/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
("blip2", "RBLNBlip2ForConditionalGeneration"),
5959
"Gemma3ForConditionalGeneration": ("gemma3",
6060
"RBLNGemma3ForConditionalGeneration"),
61+
"WhisperForConditionalGeneration": ("whisper",
62+
"RBLNWhisperForConditionalGeneration"),
6163
}
6264

6365
_RBLN_EMBEDDING_MODELS = {

vllm_rbln/model_executor/models/optimum/gemma3.py

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from dataclasses import dataclass
15-
from typing import Any, Dict, List, Optional, Tuple
15+
from typing import Any, Dict, List, Optional, Tuple, cast
1616

1717
import torch
1818
from vllm.config import ModelConfig, SchedulerConfig
@@ -21,7 +21,8 @@
2121
Gemma3ImagePixelInputs)
2222

2323
from .base import ModelInputForRBLN, version_error
24-
from .model_base import RBLNOptimumDecoderMixin, RBLNOptimumModelBase
24+
from .model_base import (RBLNOptimumDecoderMixin, RBLNOptimumDictTableMixin,
25+
RBLNOptimumModelBase)
2526

2627
logger = init_logger(__name__)
2728

@@ -34,7 +35,8 @@ class SlidingWindowEntry:
3435

3536

3637
class RBLNOptimumGemma3ForConditionalGeneration(RBLNOptimumModelBase,
37-
RBLNOptimumDecoderMixin):
38+
RBLNOptimumDecoderMixin,
39+
RBLNOptimumDictTableMixin):
3840

3941
def __init__(
4042
self,
@@ -120,49 +122,37 @@ def select_local_block_table_value(
120122
running_requests_ids: list[str],
121123
finished_requests_ids: list[str],
122124
) -> Tuple[list[int], list[int], list[torch.Tensor]]:
123-
if is_prompt:
124-
# Generate attention mask without padding
125-
attention_mask = torch.ones_like(input_ids).squeeze(0)
126-
127-
# Determine sliding_window_table_id
128-
# FIXME:
129-
# finished_requests_ids is typed as list[str],
130-
# but used as list[int].
131-
if finished_requests_ids:
132-
first_id = finished_requests_ids[0]
133-
local_table_id = self.sliding_window_table[
134-
first_id].local_table_id
135-
136-
for request_id in finished_requests_ids:
137-
self.sliding_window_table.pop(request_id)
138-
else:
139-
used_ids = {
140-
v.local_table_id
141-
for v in self.sliding_window_table.values()
142-
}
143-
available_ids = set(range(self.decoder_batch_size)) - used_ids
144-
assert len(available_ids) > 0
145-
local_table_id = min(available_ids)
146-
147-
if len(self.sliding_window_table) > self.decoder_batch_size:
148-
raise ValueError(
149-
"Sliding window table size must not exceed the batch size."
150-
)
151125

152-
return [local_table_id], [], [attention_mask]
126+
get_extra_values_fn = None
127+
attention_mask = None
153128

129+
if is_prompt:
130+
attention_mask = torch.ones_like(input_ids).squeeze(0)
154131
else:
155-
local_table_ids: List[int] = []
156-
padded_cache_lengths: List[int] = []
157-
attention_masks: List[torch.Tensor] = []
132+
get_extra_values_fn = lambda entry: (
133+
entry.padded_cache_length,
134+
entry.attention_mask,
135+
)
158136

159-
for request_id in running_requests_ids:
160-
sliding_window = self.sliding_window_table[request_id]
161-
local_table_ids.append(sliding_window.local_table_id)
162-
padded_cache_lengths.append(sliding_window.padded_cache_length)
163-
attention_masks.append(sliding_window.attention_mask)
137+
result = self.get_table_mapping_values(
138+
self.sliding_window_table,
139+
self.decoder_batch_size,
140+
is_prompt,
141+
finished_requests_ids,
142+
running_requests_ids,
143+
get_entry_fn=lambda entry: entry.local_table_id,
144+
get_extra_values_fn=get_extra_values_fn,
145+
)
164146

165-
return local_table_ids, padded_cache_lengths, attention_masks
147+
if is_prompt:
148+
result = cast(list[int], result)
149+
table_ids = result
150+
return table_ids, [], [attention_mask]
151+
else:
152+
result = cast(Tuple[list[int], list[int], list[torch.Tensor]],
153+
result)
154+
table_ids, padded_cache_lengths, attention_masks = result
155+
return table_ids, padded_cache_lengths, attention_masks
166156

167157
def get_pixel_values(self, model_input: ModelInputForRBLN):
168158
image_input = None

vllm_rbln/model_executor/models/optimum/model_base.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import os
1616
from functools import cache
1717
from pathlib import Path
18-
from typing import Optional
18+
from typing import Any, Callable, Dict, Optional, Tuple, Union
1919

2020
import optimum.rbln
2121
import torch
@@ -238,3 +238,64 @@ def select_lower_bounded_batch_size(self, original_batch_size: int,
238238
decoder_batch_sizes: tuple):
239239
index = bisect.bisect_left(decoder_batch_sizes, original_batch_size)
240240
return decoder_batch_sizes[index]
241+
242+
243+
class RBLNOptimumDictTableMixin:
    """
    Mixin for models using a request-ID keyed table implemented as a dictionary.
    """

    def get_table_mapping_values(
        self,
        table_mapping: Dict[str, Any],
        decoder_batch_size: int,
        is_prompt: bool,
        finished_requests_ids: list[str],
        running_requests_ids: list[str],
        get_entry_fn: Optional[Callable[[Any], Any]] = None,
        get_extra_values_fn: Optional[Callable[[Any],
                                               Union[Any, Tuple[Any,
                                                                ...]]]] = None,
    ) -> Union[list[int], Tuple[list[int], ...]]:
        """Resolve per-request table IDs (and optional extra values).

        Prefill (``is_prompt=True``): returns a single-element list holding
        the table ID to use — either recycled from the first finished
        request (all finished entries are evicted from the table), or the
        smallest ID not currently in use.

        Decode (``is_prompt=False``): returns the table IDs of the running
        requests; when ``get_extra_values_fn`` is given, also returns one
        list per extra value, column-wise.

        Args:
            table_mapping: Request-ID keyed dict of table entries (mutated:
                finished entries are popped during prefill).
            decoder_batch_size: Total number of table slots available.
            is_prompt: True for the prefill path, False for decode.
            finished_requests_ids: IDs whose entries may be recycled.
            running_requests_ids: IDs to look up on the decode path.
            get_entry_fn: Optional extractor mapping an entry to its table
                ID; when omitted, the entry itself is the ID.
            get_extra_values_fn: Optional extractor returning one extra
                value or a tuple of extra values per entry.
        """
        if is_prompt:
            if finished_requests_ids:
                # Recycle the slot of the first finished request, then
                # evict every finished entry from the table.
                first_id = finished_requests_ids[0]
                first_entry = table_mapping[first_id]
                table_id = get_entry_fn(
                    first_entry) if get_entry_fn else first_entry

                for request_id in finished_requests_ids:
                    table_mapping.pop(request_id)
            else:
                used_ids = {
                    get_entry_fn(v) if get_entry_fn else v
                    for v in table_mapping.values()
                }
                available_ids = set(range(decoder_batch_size)) - used_ids
                assert available_ids, "No available table IDs"
                table_id = min(available_ids)
            return [table_id]

        table_ids = []
        extra_values = []

        for request_id in running_requests_ids:
            entry = table_mapping[request_id]
            table_id = get_entry_fn(entry) if get_entry_fn else entry
            table_ids.append(table_id)

            if get_extra_values_fn:
                result = get_extra_values_fn(entry)
                # Normalize a single extra value into a 1-tuple so the
                # transpose below works uniformly.
                if not isinstance(result, tuple):
                    result = (result, )
                extra_values.append(result)

        if get_extra_values_fn:
            # Transpose row-wise tuples [(a1, b1), (a2, b2)] into
            # column-wise lists ([a1, a2], [b1, b2]).
            extra_values_lists: list[list[Any]] = [
                list(col) for col in zip(*extra_values)
            ]
            return (table_ids, *extra_values_lists)
        return table_ids

    def clear_table_mapping(self, table_mapping: Dict[str, Any]):
        """Remove every entry from the request-ID keyed table."""
        table_mapping.clear()

0 commit comments

Comments
 (0)