Skip to content

Commit 484b9af

Browse files
Fixes: gemma3's block_table scheduling bug due to padding (#28)
* fix shape of pixel_values
* pre-commit
* forward with padded
* change the parameter and register gemma3
* allow kwargs for register vllm model
* pre-commit
* import error
* revert examples
* RBLNGemma3MultiModalProcessor's location
* pre-commit
* clean
* clean
* pre-commit
* pre-commit
* pre-commit
* fix param
* fix pre-commit
* fixme comment
* fix pre-commit
* rebase
* fix typo

---------

Co-authored-by: rebel-eunji <eunji.lee@rebellions.ai>
1 parent 1a13e2c commit 484b9af

16 files changed

Lines changed: 184 additions & 124 deletions

File tree

vllm_rbln/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,3 +30,7 @@ def register_model():
3030
)
3131
ModelRegistry.register_model("T5EncoderModel",
3232
"optimum.rbln:RBLNT5EncoderModel")
33+
ModelRegistry.register_model(
34+
"Gemma3ForConditionalGeneration",
35+
"vllm_rbln.model_executor.models.optimum.gemma3:RBLNOptimumGemma3ForConditionalGeneration"
36+
)

vllm_rbln/model_executor/model_loader/rbln_model_loader.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,13 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import torch.nn as nn
15-
from vllm.config import ModelConfig, SchedulerConfig
15+
from vllm.config import VllmConfig
1616

1717
from vllm_rbln.model_executor.models.optimum import load_model
1818

1919

20-
def get_optimum_model(
21-
model_config: ModelConfig,
22-
scheduler_config: SchedulerConfig,
23-
) -> nn.Module:
24-
return load_model(model_config, scheduler_config)
20+
def get_optimum_model(vllm_config: VllmConfig, ) -> nn.Module:
21+
return load_model(vllm_config)

vllm_rbln/model_executor/models/optimum/__init__.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
"""Utilities for selecting and loading rbln models."""
1515

1616
import torch.nn as nn
17-
from vllm.config import ModelConfig, SchedulerConfig
17+
from vllm.config import VllmConfig
1818
from vllm.logger import init_logger
1919

2020
from .base import (_RBLN_MULTIMODAL_MODELS, ModelInputForRBLN,
@@ -41,31 +41,26 @@
4141
}
4242

4343

44-
def load_model(
45-
model_config: ModelConfig,
46-
scheduler_config: SchedulerConfig,
47-
) -> nn.Module:
44+
def load_model(vllm_config: VllmConfig) -> nn.Module:
45+
model_config = vllm_config.model_config
46+
4847
if is_multi_modal(model_config.hf_config):
4948
architectures = getattr(model_config.hf_config, "architectures", [])
5049
if architectures[0] in _RBLN_OPTIMUM_MULTIMODAL_MODELS:
5150
rbln_model_arch = _RBLN_OPTIMUM_MULTIMODAL_MODELS[architectures[0]]
52-
rbln_model = rbln_model_arch(model_config=model_config,
53-
scheduler_config=scheduler_config)
51+
rbln_model = rbln_model_arch(vllm_config)
5452
else:
5553
raise NotImplementedError(
5654
f"Model architectures {architectures} are "
5755
f"not supported on RBLN Optimum for now. "
5856
"Supported multimodal architectures: "
5957
f"{list(_RBLN_OPTIMUM_MULTIMODAL_MODELS.keys())}")
6058
elif is_enc_dec_arch(model_config.hf_config):
61-
rbln_model = RBLNOptimumEncoderDecoder(
62-
model_config=model_config, scheduler_config=scheduler_config)
59+
rbln_model = RBLNOptimumEncoderDecoder(vllm_config)
6360
elif is_pooling_arch(model_config.hf_config):
64-
rbln_model = RBLNOptimumForEncoderModel(
65-
model_config=model_config, scheduler_config=scheduler_config)
61+
rbln_model = RBLNOptimumForEncoderModel(vllm_config)
6662
else:
67-
rbln_model = RBLNOptimumForCausalLM(model_config=model_config,
68-
scheduler_config=scheduler_config)
63+
rbln_model = RBLNOptimumForCausalLM(vllm_config)
6964
return rbln_model.eval()
7065

7166

vllm_rbln/model_executor/models/optimum/blip2.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, Optional
1515

1616
import torch
17-
from vllm.config import ModelConfig, SchedulerConfig
17+
from vllm.config import VllmConfig
1818
from vllm.logger import init_logger
1919
from vllm.model_executor.models.blip2 import (Blip2ImageEmbeddingInputs,
2020
Blip2ImageInputs,
@@ -32,23 +32,22 @@ class RBLNOptimumBlip2ForConditionalGeneration(RBLNOptimumModelBase,
3232

3333
def __init__(
3434
self,
35-
model_config: ModelConfig,
36-
scheduler_config: SchedulerConfig,
35+
vllm_config: VllmConfig,
3736
) -> None:
38-
super().__init__(model_config=model_config,
39-
scheduler_config=scheduler_config)
37+
super().__init__(vllm_config=vllm_config)
4038
self.setup_decoder_mixin(
4139
attn_impl=self.attn_impl,
4240
padding_value=self.padding_value,
43-
vocab_size=model_config.get_vocab_size,
41+
vocab_size=self.model_config.get_vocab_size,
4442
use_multiple_decoder=getattr(self.model.rbln_config.language_model,
4543
"use_multiple_decoder", False),
4644
default_batch_size=self.scheduler_config.max_num_seqs,
4745
decoder_batch_sizes=self.model.rbln_config.language_model.
4846
decoder_batch_sizes,
4947
)
5048

51-
def forward(self, model_input: ModelInputForRBLN) -> torch.Tensor:
49+
def forward(self, model_input: ModelInputForRBLN,
50+
**kwargs) -> torch.Tensor:
5251
input_ids = model_input.input_tokens
5352
cache_position = model_input.input_positions
5453
block_tables = model_input.block_tables

vllm_rbln/model_executor/models/optimum/decoder_only.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import torch
15-
from vllm.config import ModelConfig, SchedulerConfig
15+
from vllm.config import VllmConfig
1616
from vllm.logger import init_logger
1717

1818
from .base import ModelInputForRBLN, version_error
@@ -25,25 +25,21 @@ class RBLNOptimumForCausalLM(RBLNOptimumModelBase, RBLNOptimumDecoderMixin):
2525

2626
def __init__(
2727
self,
28-
model_config: ModelConfig,
29-
scheduler_config: SchedulerConfig,
30-
**kwargs,
28+
vllm_config: VllmConfig,
3129
) -> None:
32-
super().__init__(
33-
model_config=model_config,
34-
scheduler_config=scheduler_config,
35-
)
30+
super().__init__(vllm_config=vllm_config)
3631
self.setup_decoder_mixin(
3732
attn_impl=self.attn_impl,
3833
padding_value=self.padding_value,
39-
vocab_size=model_config.get_vocab_size,
34+
vocab_size=self.model_config.get_vocab_size,
4035
use_multiple_decoder=getattr(self.model.rbln_config,
4136
"use_multiple_decoder", False),
4237
default_batch_size=self.scheduler_config.max_num_seqs,
4338
decoder_batch_sizes=self.model.rbln_config.decoder_batch_sizes,
4439
)
4540

46-
def forward(self, model_input: ModelInputForRBLN) -> torch.Tensor:
41+
def forward(self, model_input: ModelInputForRBLN,
42+
**kwargs) -> torch.Tensor:
4743
input_ids = model_input.input_tokens
4844
cache_position = model_input.input_positions
4945
block_tables = model_input.block_tables

vllm_rbln/model_executor/models/optimum/encoder.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from typing import Optional
1616

1717
import torch
18-
from vllm.config import ModelConfig, PoolerConfig, SchedulerConfig
18+
from vllm.config import PoolerConfig, VllmConfig
1919
from vllm.logger import init_logger
2020
from vllm.model_executor.layers.pooler import Pooler, PoolingType
2121
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
@@ -31,11 +31,10 @@ class RBLNOptimumForEncoderModel(RBLNOptimumModelBase):
3131

3232
def __init__(
3333
self,
34-
model_config: ModelConfig,
35-
scheduler_config: SchedulerConfig,
34+
vllm_config: VllmConfig,
3635
) -> None:
37-
super().__init__(model_config, scheduler_config)
38-
self._pooler = self._build_pooler(model_config.pooler_config)
36+
super().__init__(vllm_config=vllm_config)
37+
self._pooler = self._build_pooler(self.model_config.pooler_config)
3938

4039
def is_classification_arch(self):
4140
architectures = getattr(
@@ -94,7 +93,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Optional[Pooler]:
9493
)
9594
return None
9695

97-
def forward(self, model_input: ModelInputForRBLN) -> torch.Tensor:
96+
def forward(self, model_input: ModelInputForRBLN,
97+
**kwargs) -> torch.Tensor:
9898
input_ids, token_type_ids, positions = self.preprocess(
9999
model_input.input_tokens,
100100
model_input.token_type_ids,

vllm_rbln/model_executor/models/optimum/encoder_decoder.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from typing import List, Optional, Union
1515

1616
import torch
17-
from vllm.config import ModelConfig, SchedulerConfig
17+
from vllm.config import VllmConfig
1818
from vllm.logger import init_logger
1919

2020
from .base import ModelInputForRBLN, version_error
@@ -28,17 +28,15 @@ class RBLNOptimumEncoderDecoder(RBLNOptimumModelBase, RBLNOptimumDecoderMixin):
2828

2929
def __init__(
3030
self,
31-
model_config: ModelConfig,
32-
scheduler_config: SchedulerConfig,
31+
vllm_config: VllmConfig,
3332
) -> None:
34-
super().__init__(model_config=model_config,
35-
scheduler_config=scheduler_config)
33+
super().__init__(vllm_config=vllm_config)
3634
# encoder length used for encoder_decoder architecture
3735
self.enc_lengths = [0] * self.batch_size
3836
self.setup_decoder_mixin(
3937
attn_impl=self.attn_impl,
4038
padding_value=self.padding_value,
41-
vocab_size=model_config.get_vocab_size,
39+
vocab_size=self.model_config.get_vocab_size,
4240
use_multiple_decoder=False,
4341
default_batch_size=self.scheduler_config.max_num_seqs,
4442
decoder_batch_sizes=[self.batch_size],
@@ -115,7 +113,8 @@ def _forward(
115113

116114
return logits
117115

118-
def forward(self, model_input: ModelInputForRBLN) -> torch.Tensor:
116+
def forward(self, model_input: ModelInputForRBLN,
117+
**kwargs) -> torch.Tensor:
119118
input_ids = model_input.input_tokens
120119
cache_position = model_input.input_positions
121120
is_prompt = model_input.sampling_metadata.num_prompts > 0

0 commit comments

Comments (0)