@@ -166,11 +166,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
166166 parallel_config = vllm_config .parallel_config
167167 scheduler_config = vllm_config .scheduler_config
168168 if cls .is_torch_compile :
169- if parallel_config .worker_cls == "auto" :
170- parallel_config .worker_cls = \
171- "vllm_rbln.worker.worker.RBLNWorker"
172- scheduler_config .scheduler_cls = \
173- "vllm_rbln.core.scheduler.RBLNScheduler"
169+ if envs .VLLM_USE_V1 :
170+ if parallel_config .worker_cls == "auto" :
171+ parallel_config .worker_cls = (
172+ "vllm_rbln.v1.worker.rbln_worker.RBLNWorker" )
173+ scheduler_config .scheduler_cls = (
174+ "vllm_rbln.v1.core.rbln_scheduler.RBLNScheduler" )
175+ else :
176+ if parallel_config .worker_cls == "auto" :
177+ parallel_config .worker_cls = (
178+ "vllm_rbln.worker.worker.RBLNWorker" )
179+ scheduler_config .scheduler_cls = (
180+ "vllm_rbln.core.scheduler.RBLNScheduler" )
174181 else :
175182 if envs .VLLM_USE_V1 :
176183 if parallel_config .worker_cls == "auto" :
@@ -204,6 +211,24 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
204211 "block_size must be configured for RBLN backend" )
205212 cache_config .enable_prefix_caching = False
206213
214+ if envs .VLLM_USE_V1 and cls .is_torch_compile :
215+ from vllm .config import CompilationLevel
216+
217+ if (vllm_config .compilation_config .level
218+ != CompilationLevel .NO_COMPILATION ):
219+ logger .info ("RBLN doesn't @support_torch_compile decorator" )
220+ vllm_config .compilation_config .level = (
221+ CompilationLevel .NO_COMPILATION )
222+ if (len (vllm_config .compilation_config .custom_ops ) == 1
223+ and vllm_config .compilation_config .custom_ops [0 ]
224+ == "none" ):
225+ vllm_config .compilation_config .custom_ops = []
226+
227+ if not model_config .disable_cascade_attn :
228+ logger .info ("The cascade attention is disabled"
229+ " because RBLN does not support it" )
230+ model_config .disable_cascade_attn = True
231+
207232 @classmethod
208233 def get_attn_backend_cls (
209234 cls ,
@@ -215,13 +240,16 @@ def get_attn_backend_cls(
215240 use_v1 : bool ,
216241 use_mla : bool ,
217242 ) -> str :
218- attn_backend_cls = (
219- "vllm_rbln.attention.backends.flash_attention.RBLNAttentionBackend"
220- )
243+ if envs .VLLM_USE_V1 :
244+ attn_backend_cls = ("vllm_rbln.v1.attention.backends."
245+ "flash_attention.RBLNAttentionBackend" )
246+ else :
247+ attn_backend_cls = ("vllm_rbln.attention.backends."
248+ "flash_attention.RBLNAttentionBackend" )
221249 logger .info ("Using RBLN Attention Backend: %s" , attn_backend_cls )
222250
223251 return attn_backend_cls
224252
@classmethod
def supports_v1(cls, model_config: "ModelConfig") -> bool:
    """Report whether this platform can run the V1 engine.

    Args:
        model_config: Configuration of the model being served. It is
            accepted for interface compatibility and is not consulted.

    Returns:
        Always ``True``; the answer does not depend on
        ``cls.is_torch_compile`` or on ``model_config``.
    """
    return True
0 commit comments