@@ -203,18 +203,15 @@ def forward(
         attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
         CAUSAL,
         all_reduce_params: Optional[AllReduceParams] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
-        assert lora_params is None, "LORA is not supported for Llama4Attention"
         if self.use_rope:
             return super().forward(
                 position_ids=position_ids,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 attention_mask=attention_mask,
                 all_reduce_params=all_reduce_params,
-                lora_params=lora_params,
                 **kwargs,
             )
         else:
@@ -481,7 +478,6 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         spec_metadata: Optional[SpecMetadata] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         # Only enable min-latency mode on Blackwell
@@ -506,7 +502,6 @@ def forward(
             attn_metadata=attn_metadata,
             all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_attn_allreduce),
-            lora_params=lora_params,
             **kwargs,
         )
 
@@ -547,7 +542,6 @@ def forward(
             final_all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_feed_forward_allreduce),
             cutlass_min_latency_mode=cutlass_min_latency_mode,
-            lora_params=lora_params,
         )
 
         if spec_metadata is not None:
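
The net effect of these hunks is that lora_params is no longer threaded through as an explicit parameter (and the assert blocking LoRA in Llama4Attention.forward is gone): if a caller still supplies it, it now travels through the retained **kwargs instead. A minimal sketch of that pass-through, using hypothetical stand-in classes rather than this file's real API:

# Minimal sketch (hypothetical classes, not the real model code): with the
# explicit lora_params parameter removed, a caller-supplied value still
# reaches the innermost forward() via **kwargs.
class Attention:
    def forward(self, hidden_states, **kwargs):
        # The base implementation can look up lora_params itself if present.
        return hidden_states, kwargs.get("lora_params")

class DecoderLayer:
    def __init__(self):
        self.self_attn = Attention()

    def forward(self, hidden_states, **kwargs):
        # No named lora_params parameter here anymore; kwargs are forwarded
        # untouched, mirroring the **kwargs lines the diff keeps.
        return self.self_attn.forward(hidden_states, **kwargs)

layer = DecoderLayer()
_, lora = layer.forward("h", lora_params={"rank": 8})
assert lora == {"rank": 8}  # lora_params arrived via **kwargs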