Skip to content

Commit 5db9a9a

Browse files
committed
Cleanups
Signed-off-by: rahul-tuli <rtuli@redhat.com>
1 parent ec01d6c commit 5db9a9a

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

tests/v1/spec_decode/test_acceptance_length.py

Lines changed: 22 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111

1212
from dataclasses import dataclass, field
1313
from types import SimpleNamespace
14-
from typing import TypedDict
1514

1615
import pytest
1716
import torch
@@ -22,24 +21,20 @@
2221
from vllm.benchmarks.datasets import get_samples
2322
from vllm.inputs import TokensPrompt
2423
from vllm.platforms import current_platform
24+
from vllm.v1.attention.backends.registry import AttentionBackendEnum
2525
from vllm.v1.attention.selector import AttentionSelectorConfig
2626
from vllm.v1.metrics.reader import Counter, Vector
2727

2828

29-
class AcceptanceMetrics(TypedDict):
30-
acceptance_length: float
31-
acceptance_lengths_per_pos: list[float]
32-
num_drafts: int
33-
num_accepted_tokens: int
34-
35-
3629
@dataclass
3730
class Eagle3ModelConfig:
3831
verifier: str
3932
drafter: str
4033
expected_acceptance_length: float
4134
expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
4235
id: str = ""
36+
# Backends that are incompatible with this model (will be skipped)
37+
excluded_backends: set[AttentionBackendEnum] = field(default_factory=set)
4338

4439

4540
# Model configurations for EAGLE3 acceptance length tests.
@@ -66,6 +61,9 @@ class Eagle3ModelConfig:
6661
expected_acceptance_length=2.56,
6762
expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.3337],
6863
id="gpt-oss-20b-eagle3",
64+
# FLASHINFER incompatible: gpt-oss-20b uses sink attention which
65+
# FLASHINFER does not support ("sink setting not supported")
66+
excluded_backends={AttentionBackendEnum.FLASHINFER},
6967
),
7068
]
7169

@@ -81,11 +79,10 @@ class Eagle3ModelConfig:
8179

8280

8381
# Backends excluded from testing due to significantly different behavior
84-
EXCLUDED_BACKENDS = {"FLEX_ATTENTION"}
82+
EXCLUDED_BACKENDS = {AttentionBackendEnum.FLEX_ATTENTION}
8583

8684

8785
def get_available_attention_backends() -> list[str]:
88-
"""Get list of available attention backends for the current platform."""
8986
if not hasattr(current_platform, "get_valid_backends"):
9087
return ["FLASH_ATTN"]
9188

@@ -112,18 +109,15 @@ def get_available_attention_backends() -> list[str]:
112109
return [
113110
backend.name
114111
for backend, _ in valid_backends
115-
if backend.name not in EXCLUDED_BACKENDS
112+
if backend not in EXCLUDED_BACKENDS
116113
]
117114

118115

119-
def get_attention_backend_params() -> list[pytest.param]:
120-
"""Generate pytest params for available attention backends."""
121-
available = get_available_attention_backends()
122-
return [pytest.param(backend, id=backend.lower()) for backend in available]
116+
def get_attention_backend_params() -> list[str]:
117+
return get_available_attention_backends()
123118

124119

125120
def get_tp_size_params() -> list[pytest.param]:
126-
"""Generate pytest params for TP sizes based on available GPUs."""
127121
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
128122
return [pytest.param(tp, id=f"tp{tp}") for tp in TP_SIZES if tp <= num_gpus]
129123

@@ -157,8 +151,7 @@ def get_mt_bench_prompts(
157151
return prompt_ids
158152

159153

160-
def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> AcceptanceMetrics:
161-
"""Extract acceptance length metrics from LLM metrics."""
154+
def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> dict:
162155
num_drafts = 0
163156
num_accepted_tokens = 0
164157
acceptance_counts = [0] * num_spec_tokens
@@ -185,12 +178,12 @@ def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> AcceptanceMetri
185178
count / num_drafts if num_drafts > 0 else 0.0 for count in acceptance_counts
186179
]
187180

188-
return AcceptanceMetrics(
189-
acceptance_length=acceptance_length,
190-
acceptance_lengths_per_pos=acceptance_lengths_per_pos,
191-
num_drafts=num_drafts,
192-
num_accepted_tokens=num_accepted_tokens,
193-
)
181+
return {
182+
"acceptance_length": acceptance_length,
183+
"acceptance_lengths_per_pos": acceptance_lengths_per_pos,
184+
"num_drafts": num_drafts,
185+
"num_accepted_tokens": num_accepted_tokens,
186+
}
194187

195188

196189
@large_gpu_mark(min_gb=40)
@@ -208,7 +201,11 @@ def test_eagle3_acceptance_length(
208201
attention_backend: str,
209202
monkeypatch: pytest.MonkeyPatch,
210203
):
211-
"""Test EAGLE3 acceptance length does not regress."""
204+
# Skip if this backend is incompatible with the model
205+
backend_enum = AttentionBackendEnum[attention_backend]
206+
if backend_enum in model_config.excluded_backends:
207+
pytest.skip(f"{attention_backend} is incompatible with {model_config.id}")
208+
212209
with monkeypatch.context() as m:
213210
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
214211
m.setenv("VLLM_ATTENTION_BACKEND", attention_backend)

0 commit comments

Comments (0)