1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from dataclasses import dataclass
15- from typing import Any , Dict , List , Optional , Tuple
15+ from typing import Any , Dict , List , Optional , Tuple , cast
1616
1717import torch
1818from vllm .config import ModelConfig , SchedulerConfig
2121 Gemma3ImagePixelInputs )
2222
2323from .base import ModelInputForRBLN , version_error
24- from .model_base import RBLNOptimumDecoderMixin , RBLNOptimumModelBase
24+ from .model_base import (RBLNOptimumDecoderMixin , RBLNOptimumDictTableMixin ,
25+ RBLNOptimumModelBase )
2526
2627logger = init_logger (__name__ )
2728
@@ -34,7 +35,8 @@ class SlidingWindowEntry:
3435
3536
3637class RBLNOptimumGemma3ForConditionalGeneration (RBLNOptimumModelBase ,
37- RBLNOptimumDecoderMixin ):
38+ RBLNOptimumDecoderMixin ,
39+ RBLNOptimumDictTableMixin ):
3840
3941 def __init__ (
4042 self ,
@@ -120,49 +122,37 @@ def select_local_block_table_value(
120122 running_requests_ids : list [str ],
121123 finished_requests_ids : list [str ],
122124 ) -> Tuple [list [int ], list [int ], list [torch .Tensor ]]:
123- if is_prompt :
124- # Generate attention mask without padding
125- attention_mask = torch .ones_like (input_ids ).squeeze (0 )
126-
127- # Determine sliding_window_table_id
128- # FIXME:
129- # finished_requests_ids is typed as list[str],
130- # but used as list[int].
131- if finished_requests_ids :
132- first_id = finished_requests_ids [0 ]
133- local_table_id = self .sliding_window_table [
134- first_id ].local_table_id
135-
136- for request_id in finished_requests_ids :
137- self .sliding_window_table .pop (request_id )
138- else :
139- used_ids = {
140- v .local_table_id
141- for v in self .sliding_window_table .values ()
142- }
143- available_ids = set (range (self .decoder_batch_size )) - used_ids
144- assert len (available_ids ) > 0
145- local_table_id = min (available_ids )
146-
147- if len (self .sliding_window_table ) > self .decoder_batch_size :
148- raise ValueError (
149- "Sliding window table size must not exceed the batch size."
150- )
151125
152- return [local_table_id ], [], [attention_mask ]
126+ get_extra_values_fn = None
127+ attention_mask = None
153128
129+ if is_prompt :
130+ attention_mask = torch .ones_like (input_ids ).squeeze (0 )
154131 else :
155- local_table_ids : List [int ] = []
156- padded_cache_lengths : List [int ] = []
157- attention_masks : List [torch .Tensor ] = []
132+ get_extra_values_fn = lambda entry : (
133+ entry .padded_cache_length ,
134+ entry .attention_mask ,
135+ )
158136
159- for request_id in running_requests_ids :
160- sliding_window = self .sliding_window_table [request_id ]
161- local_table_ids .append (sliding_window .local_table_id )
162- padded_cache_lengths .append (sliding_window .padded_cache_length )
163- attention_masks .append (sliding_window .attention_mask )
137+ result = self .get_table_mapping_values (
138+ self .sliding_window_table ,
139+ self .decoder_batch_size ,
140+ is_prompt ,
141+ finished_requests_ids ,
142+ running_requests_ids ,
143+ get_entry_fn = lambda entry : entry .local_table_id ,
144+ get_extra_values_fn = get_extra_values_fn ,
145+ )
164146
165- return local_table_ids , padded_cache_lengths , attention_masks
147+ if is_prompt :
148+ result = cast (list [int ], result )
149+ table_ids = result
150+ return table_ids , [], [attention_mask ]
151+ else :
152+ result = cast (Tuple [list [int ], list [int ], list [torch .Tensor ]],
153+ result )
154+ table_ids , padded_cache_lengths , attention_masks = result
155+ return table_ids , padded_cache_lengths , attention_masks
166156
167157 def get_pixel_values (self , model_input : ModelInputForRBLN ):
168158 image_input = None
0 commit comments