1111
1212from dataclasses import dataclass , field
1313from types import SimpleNamespace
14- from typing import TypedDict
1514
1615import pytest
1716import torch
2221from vllm .benchmarks .datasets import get_samples
2322from vllm .inputs import TokensPrompt
2423from vllm .platforms import current_platform
24+ from vllm .v1 .attention .backends .registry import AttentionBackendEnum
2525from vllm .v1 .attention .selector import AttentionSelectorConfig
2626from vllm .v1 .metrics .reader import Counter , Vector
2727
2828
29- class AcceptanceMetrics (TypedDict ):
30- acceptance_length : float
31- acceptance_lengths_per_pos : list [float ]
32- num_drafts : int
33- num_accepted_tokens : int
34-
35-
@dataclass
class Eagle3ModelConfig:
    """Configuration for one EAGLE3 verifier/drafter model pair under test.

    Bundles the model pair with the acceptance-length thresholds the test
    asserts against, plus any attention backends that must be skipped for
    this particular model.
    """

    # Name/path of the target (verifier) model.
    verifier: str
    # Name/path of the EAGLE3 draft model.
    drafter: str
    # Minimum overall acceptance length the test must reach.
    expected_acceptance_length: float
    # Expected acceptance rate per speculative position; an empty list
    # means per-position rates are not checked for this model.
    expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
    # Human-readable identifier (used as the pytest parameter id).
    id: str = ""
    # Backends that are incompatible with this model (will be skipped)
    excluded_backends: set[AttentionBackendEnum] = field(default_factory=set)
4338
4439
4540# Model configurations for EAGLE3 acceptance length tests.
@@ -66,6 +61,9 @@ class Eagle3ModelConfig:
6661 expected_acceptance_length = 2.56 ,
6762 expected_acceptance_lengths_per_pos = [0.7165 , 0.5120 , 0.3337 ],
6863 id = "gpt-oss-20b-eagle3" ,
64+ # FLASHINFER incompatible: gpt-oss-20b uses sink attention which
65+ # FLASHINFER does not support ("sink setting not supported")
66+ excluded_backends = {AttentionBackendEnum .FLASHINFER },
6967 ),
7068]
7169
@@ -81,11 +79,10 @@ class Eagle3ModelConfig:
8179
8280
# Backends excluded from testing due to significantly different behavior.
# get_available_attention_backends() filters the platform's valid backends
# against this set, so these never appear in the test parametrization.
EXCLUDED_BACKENDS = {AttentionBackendEnum.FLEX_ATTENTION}
8583
8684
8785def get_available_attention_backends () -> list [str ]:
88- """Get list of available attention backends for the current platform."""
8986 if not hasattr (current_platform , "get_valid_backends" ):
9087 return ["FLASH_ATTN" ]
9188
@@ -112,18 +109,15 @@ def get_available_attention_backends() -> list[str]:
112109 return [
113110 backend .name
114111 for backend , _ in valid_backends
115- if backend . name not in EXCLUDED_BACKENDS
112+ if backend not in EXCLUDED_BACKENDS
116113 ]
117114
118115
119- def get_attention_backend_params () -> list [pytest .param ]:
120- """Generate pytest params for available attention backends."""
121- available = get_available_attention_backends ()
122- return [pytest .param (backend , id = backend .lower ()) for backend in available ]
def get_attention_backend_params() -> list[str]:
    """Return the attention backend names used to parametrize the tests."""
    backend_names = get_available_attention_backends()
    return backend_names
123118
124119
def get_tp_size_params() -> list[pytest.param]:
    """Generate pytest params for TP sizes that fit the available GPUs.

    Filters the module-level TP_SIZES down to sizes no larger than the
    number of visible CUDA devices.
    """
    # Treat a CUDA-less environment as a single device so tp=1 is still emitted.
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    return [pytest.param(tp, id=f"tp{tp}") for tp in TP_SIZES if tp <= num_gpus]
129123
@@ -157,8 +151,7 @@ def get_mt_bench_prompts(
157151 return prompt_ids
158152
159153
160- def extract_acceptance_metrics (metrics , num_spec_tokens : int ) -> AcceptanceMetrics :
161- """Extract acceptance length metrics from LLM metrics."""
154+ def extract_acceptance_metrics (metrics , num_spec_tokens : int ) -> dict :
162155 num_drafts = 0
163156 num_accepted_tokens = 0
164157 acceptance_counts = [0 ] * num_spec_tokens
@@ -185,12 +178,12 @@ def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> AcceptanceMetri
185178 count / num_drafts if num_drafts > 0 else 0.0 for count in acceptance_counts
186179 ]
187180
188- return AcceptanceMetrics (
189- acceptance_length = acceptance_length ,
190- acceptance_lengths_per_pos = acceptance_lengths_per_pos ,
191- num_drafts = num_drafts ,
192- num_accepted_tokens = num_accepted_tokens ,
193- )
181+ return {
182+ " acceptance_length" : acceptance_length ,
183+ " acceptance_lengths_per_pos" : acceptance_lengths_per_pos ,
184+ " num_drafts" : num_drafts ,
185+ " num_accepted_tokens" : num_accepted_tokens ,
186+ }
194187
195188
196189@large_gpu_mark (min_gb = 40 )
@@ -208,7 +201,11 @@ def test_eagle3_acceptance_length(
208201 attention_backend : str ,
209202 monkeypatch : pytest .MonkeyPatch ,
210203):
211- """Test EAGLE3 acceptance length does not regress."""
204+ # Skip if this backend is incompatible with the model
205+ backend_enum = AttentionBackendEnum [attention_backend ]
206+ if backend_enum in model_config .excluded_backends :
207+ pytest .skip (f"{ attention_backend } is incompatible with { model_config .id } " )
208+
212209 with monkeypatch .context () as m :
213210 m .setenv ("VLLM_ALLOW_INSECURE_SERIALIZATION" , "1" )
214211 m .setenv ("VLLM_ATTENTION_BACKEND" , attention_backend )
0 commit comments