@@ -257,7 +257,6 @@ def set_forward_context(
257257 batch_descriptor : BatchDescriptor | None = None ,
258258 ubatch_slices : UBatchSlices | None = None ,
259259 slot_mapping : dict [str , torch .Tensor ] | list [dict [str , torch .Tensor ]] | None = None ,
260- additional_kwargs : dict [str , Any ] | None = None ,
261260 skip_compiled : bool = False ,
262261):
263262 """A context manager that stores the current forward context,
@@ -297,7 +296,7 @@ def set_forward_context(
297296 if cudagraph_runtime_mode != CUDAGraphMode .NONE and num_tokens is not None :
298297 batch_descriptor = batch_descriptor or BatchDescriptor (num_tokens = num_tokens )
299298
300- platform_additional_kwargs = current_platform .set_additional_forward_context (
299+ additional_kwargs = current_platform .set_additional_forward_context (
301300 attn_metadata = attn_metadata ,
302301 vllm_config = vllm_config ,
303302 dp_metadata = dp_metadata ,
@@ -307,9 +306,6 @@ def set_forward_context(
307306 batch_descriptor = batch_descriptor ,
308307 ubatch_slices = ubatch_slices ,
309308 )
310- merged_additional_kwargs = dict (platform_additional_kwargs )
311- if additional_kwargs :
312- merged_additional_kwargs .update (additional_kwargs )
313309
314310 forward_context = create_forward_context (
315311 attn_metadata ,
@@ -319,7 +315,7 @@ def set_forward_context(
319315 batch_descriptor ,
320316 ubatch_slices ,
321317 slot_mapping ,
322- merged_additional_kwargs ,
318+ additional_kwargs ,
323319 skip_compiled ,
324320 )
325321
0 commit comments