 from typing import List, Optional, Union
 
 import torch
-from vllm.config import ModelConfig, SchedulerConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 
 from .base import ModelInputForRBLN, version_error
@@ -28,17 +28,15 @@ class RBLNOptimumEncoderDecoder(RBLNOptimumModelBase, RBLNOptimumDecoderMixin):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        scheduler_config: SchedulerConfig,
+        vllm_config: VllmConfig,
     ) -> None:
-        super().__init__(model_config=model_config,
-                         scheduler_config=scheduler_config)
+        super().__init__(vllm_config=vllm_config)
         # encoder length used for encoder_decoder architecture
         self.enc_lengths = [0] * self.batch_size
         self.setup_decoder_mixin(
             attn_impl=self.attn_impl,
             padding_value=self.padding_value,
-            vocab_size=model_config.get_vocab_size,
+            vocab_size=self.model_config.get_vocab_size,
             use_multiple_decoder=False,
             default_batch_size=self.scheduler_config.max_num_seqs,
             decoder_batch_sizes=[self.batch_size],
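The two hunks above migrate the class to vLLM's aggregated VllmConfig: the constructor now takes a single config object, and the sub-configs are read from attributes (self.model_config, self.scheduler_config) that the base class is expected to populate. A minimal sketch of that pattern, assuming a hypothetical base class (not the plugin's actual RBLNOptimumModelBase):

    from vllm.config import VllmConfig

    class HypotheticalModelBase:
        def __init__(self, vllm_config: VllmConfig) -> None:
            # VllmConfig bundles the per-subsystem configs; unpack the
            # ones this class needs so subclasses can read them as
            # attributes instead of loose constructor parameters.
            self.model_config = vllm_config.model_config
            self.scheduler_config = vllm_config.scheduler_config
            # Matches the diff: the decoder batch size is bounded by the
            # scheduler's max_num_seqs.
            self.batch_size = self.scheduler_config.max_num_seqs

This also explains the vocab_size change above: with no standalone model_config parameter left in __init__, the value must come from self.model_config.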
@@ -115,7 +113,8 @@ def _forward(
 
         return logits
 
-    def forward(self, model_input: ModelInputForRBLN) -> torch.Tensor:
+    def forward(self, model_input: ModelInputForRBLN,
+                **kwargs) -> torch.Tensor:
         input_ids = model_input.input_tokens
         cache_position = model_input.input_positions
         is_prompt = model_input.sampling_metadata.num_prompts > 0
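The last hunk widens forward to accept arbitrary extra keyword arguments, presumably so that call sites which pass additional keywords (none are named in this diff) no longer raise TypeError; unknown keywords are simply absorbed. A self-contained sketch of the effect, with a hypothetical keyword for illustration:

    import torch

    class Sketch:
        # Before: forward(self, model_input) rejects unknown keywords.
        # After: **kwargs tolerates and ignores them.
        def forward(self, model_input, **kwargs) -> torch.Tensor:
            return torch.zeros(1)

    # future_runner_arg is hypothetical; the call succeeds because
    # **kwargs swallows it.
    Sketch().forward(model_input=None, future_runner_arg=0)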