
Commit b669f18

[Test] Gemma4 long-context sliding-window dispatch (#285)

Follow up on the review feedback for #282:

- make `tests/test_gemma4_sliding_window_dispatch.py` build a prompt longer than Gemma4 E2B's 512-token sliding window, so the test actually exercises the long-context path
- record both `sliding_window` and `max_seq_len` from `paged_attention_primitive`, and assert the kernel sees a context longer than the configured window
- replace the hardcoded `paged_attention_primitive` argument index with a lookup from nanobind signature metadata, so the spy follows the canonical native-op signature

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>

1 parent b24fa49

1 file changed: tests/test_gemma4_sliding_window_dispatch.py (131 additions, 24 deletions)
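A note on the third bullet before the diff: nanobind-built functions expose their overload signatures at runtime through `__nb_signature__`, a tuple whose entries carry the signature text as their first element, and that is what the new helper parses. Here is a minimal, self-contained sketch of the lookup using an invented `_FakeNativeOp` stand-in and an invented signature string (the real test parses the signature of the actual `paged_attention_primitive` native op):

    import re

    _NB_PARAM_RE = re.compile(r"([A-Za-z_]\w*)\s*:")

    class _FakeNativeOp:
        # Invented signature text; real nanobind ops expose one entry per overload.
        __nb_signature__ = (
            ("paged_attention_primitive(query: Tensor, key_cache: Tensor, "
             "sliding_window: int, max_seq_len: int) -> Tensor",),
        )

    def param_indices(fn, *names: str) -> dict[str, int]:
        # Take the first overload's signature string and keep only the
        # parameter list between the outermost parentheses.
        signature_text = fn.__nb_signature__[0][0]
        params_text = signature_text.partition("(")[2].rpartition(")")[0]
        param_names = _NB_PARAM_RE.findall(params_text)
        return {name: param_names.index(name) for name in names}

    print(param_indices(_FakeNativeOp, "sliding_window", "max_seq_len"))
    # {'sliding_window': 2, 'max_seq_len': 3}

This is why the spy no longer needs the hardcoded index 11: if the native-op signature gains or reorders parameters, the lookup follows it instead of silently reading the wrong argument.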
@@ -23,8 +23,11 @@
 
 from __future__ import annotations
 
+import math
 import os
+import re
 from collections import Counter
+from dataclasses import dataclass
 
 import pytest
 
@@ -41,23 +44,109 @@
 
 _NO_WINDOW = -1
 
-# Position of ``sliding_window`` in ``paged_attention_primitive``'s
-# positional signature (see ``attention_sdpa.py:489-510``).
-_SLIDING_WINDOW_ARG_INDEX = 11
+# The bug only manifests once the prompt exceeds Gemma4 E2B's 512-token
+# sliding window. Build a prompt with a comfortable margin so the kernel
+# must actually decide whether to enforce the window.
+_LONG_CONTEXT_TOKEN_MARGIN = 128
+_LONG_CONTEXT_MIN_TOKENS = _E2B_SLIDING_WINDOW + _LONG_CONTEXT_TOKEN_MARGIN
+_LONG_CONTEXT_TARGET_TOKENS = _LONG_CONTEXT_MIN_TOKENS + 1
+_FRAGMENT_REPEAT_SAMPLE_COUNT = 2
+_MAX_MODEL_LEN = 1024
+_MAX_TOKENS = 1
+_PROMPT_FRAGMENT = "The capital of France is Paris. "
 
 # Ratio tolerance: layer_types is a config constant, but prefill and
 # decode may dispatch slightly different counts across forwards, so we
 # accept a 1% slack.
 _RATIO_TOLERANCE = 0.01
 
+_NB_PARAM_RE = re.compile(r"([A-Za-z_]\w*)\s*:")
+
+
+@dataclass(frozen=True)
+class _KernelDispatch:
+    sliding_window: int
+    max_seq_len: int
+
+
+def _nanobind_param_indices(fn, *names: str) -> dict[str, int]:
+    """Resolve parameter positions from nanobind's runtime signature metadata."""
+    overloads = getattr(fn, "__nb_signature__", ())
+    if not overloads:
+        raise RuntimeError("paged_attention_primitive is missing __nb_signature__")
+
+    signature_text = overloads[0][0]
+    params_text = signature_text.partition("(")[2].rpartition(")")[0]
+    param_names = _NB_PARAM_RE.findall(params_text)
+
+    indices: dict[str, int] = {}
+    for name in names:
+        if name not in param_names:
+            raise RuntimeError(
+                f"parameter {name!r} not found in nanobind signature: {signature_text}"
+            )
+        indices[name] = param_names.index(name)
+    return indices
+
+
+def _get_call_arg(
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
+    param_indices: dict[str, int],
+    name: str,
+) -> object:
+    """Read a native-op argument by name from positional/keyword call data."""
+    index = param_indices[name]
+    if len(args) > index:
+        return args[index]
+    if name in kwargs:
+        return kwargs[name]
+    raise RuntimeError(f"paged_attention_primitive call missing {name!r}")
+
+
+def _build_long_prompt(tokenizer) -> str:
+    """Return a prompt whose tokenized length exceeds Gemma4's window size."""
+    first_fragment_token_count = len(
+        tokenizer.encode(text=_PROMPT_FRAGMENT, add_special_tokens=False)
+    )
+    if first_fragment_token_count <= 0:
+        raise AssertionError("prompt fragment must tokenize to at least one token")
+
+    repeated_fragment_sample = _PROMPT_FRAGMENT * _FRAGMENT_REPEAT_SAMPLE_COUNT
+    repeat_increment_token_count = (
+        len(tokenizer.encode(text=repeated_fragment_sample, add_special_tokens=False))
+        - first_fragment_token_count
+    )
+    if repeat_increment_token_count <= 0:
+        raise AssertionError(
+            "prompt fragment repeat must increase token count by at least one token"
+        )
+
+    additional_repeat_count = math.ceil(
+        max(0, _LONG_CONTEXT_TARGET_TOKENS - first_fragment_token_count)
+        / repeat_increment_token_count
+    )
+    repeat_count = 1 + additional_repeat_count
+    prompt = _PROMPT_FRAGMENT * repeat_count
+    token_count = len(tokenizer.encode(text=prompt, add_special_tokens=False))
+    if token_count > _LONG_CONTEXT_MIN_TOKENS:
+        return prompt
+    raise AssertionError(
+        "failed to construct a prompt longer than Gemma4's sliding window: "
+        f"repeat_count={repeat_count}, token_count={token_count}, "
+        f"first_fragment_token_count={first_fragment_token_count}, "
+        f"repeat_increment_token_count={repeat_increment_token_count}, "
+        f"target>{_LONG_CONTEXT_MIN_TOKENS}"
+    )
+
 
 @pytest.fixture(scope="module")
-def kernel_sliding_window_log() -> list[int]:
+def kernel_dispatch_log() -> list[_KernelDispatch]:
     """Run one Gemma4 inference with a spy on ``paged_attention_primitive``.
 
-    Returns the list of ``sliding_window`` ints passed to every kernel
-    dispatch during the inference. Skips if the model path env var is
-    unset.
+    Returns the ``sliding_window`` and ``max_seq_len`` seen by every
+    kernel dispatch during one long-context inference. Skips if the
+    model path env var is unset.
     """
     model_path = os.environ.get(MODEL_ENV)
     if not model_path:
@@ -76,24 +165,32 @@ def kernel_sliding_window_log() -> list[int]:
 
     ops = get_ops()
     orig_fn = ops.paged_attention_primitive
-    captured: list[int] = []
+    param_indices = _nanobind_param_indices(
+        orig_fn, "sliding_window", "max_seq_len"
+    )
+    captured: list[_KernelDispatch] = []
 
     def spy(*args, **kwargs):
-        sw = (
-            args[_SLIDING_WINDOW_ARG_INDEX]
-            if len(args) > _SLIDING_WINDOW_ARG_INDEX
-            else kwargs.get("sliding_window")
+        captured.append(
+            _KernelDispatch(
+                sliding_window=int(
+                    _get_call_arg(args, kwargs, param_indices, "sliding_window")
+                ),
+                max_seq_len=int(
+                    _get_call_arg(args, kwargs, param_indices, "max_seq_len")
+                ),
+            )
         )
-        captured.append(sw)
         return orig_fn(*args, **kwargs)
 
     mp.setattr(ops, "paged_attention_primitive", spy)
 
     from vllm import LLM, SamplingParams
 
-    llm = LLM(model=model_path, max_model_len=512, max_num_seqs=1)
-    sp = SamplingParams(temperature=0, max_tokens=5, ignore_eos=True)
-    llm.generate(["The capital of France is"], sp)
+    llm = LLM(model=model_path, max_model_len=_MAX_MODEL_LEN, max_num_seqs=1)
+    prompt = _build_long_prompt(llm.get_tokenizer())
+    sp = SamplingParams(temperature=0, max_tokens=_MAX_TOKENS, ignore_eos=True)
+    llm.generate([prompt], sp)
 
     return captured
 
@@ -103,26 +200,26 @@ class TestGemma4KernelReceivesPerLayerSlidingWindow:
     """Kernel-level assertions on the sliding_window values dispatched."""
 
     def test_only_expected_window_values_appear(
-        self, kernel_sliding_window_log: list[int]
+        self, kernel_dispatch_log: list[_KernelDispatch]
     ) -> None:
         """No stray values leak from wiring errors."""
         # Act
         unexpected = {
-            w
-            for w in kernel_sliding_window_log
-            if w not in (_E2B_SLIDING_WINDOW, _NO_WINDOW)
+            dispatch.sliding_window
+            for dispatch in kernel_dispatch_log
+            if dispatch.sliding_window not in (_E2B_SLIDING_WINDOW, _NO_WINDOW)
         }
         # Assert
         assert not unexpected, (
            f"kernel received unexpected sliding_window values: {unexpected}"
        )
 
     def test_both_sliding_and_full_layers_dispatch(
-        self, kernel_sliding_window_log: list[int]
+        self, kernel_dispatch_log: list[_KernelDispatch]
     ) -> None:
         """``sliding_window=512`` and ``-1`` both appear."""
         # Act
-        counts = Counter(kernel_sliding_window_log)
+        counts = Counter(dispatch.sliding_window for dispatch in kernel_dispatch_log)
         # Assert
         assert counts[_E2B_SLIDING_WINDOW] > 0, (
             "sliding layers never received their window -- enforcement is "
@@ -132,8 +229,18 @@ def test_both_sliding_and_full_layers_dispatch(
             "full layers never received -1 -- they may be incorrectly getting a window"
         )
 
+    def test_kernel_sees_context_longer_than_the_window(
+        self, kernel_dispatch_log: list[_KernelDispatch]
+    ) -> None:
+        """The regression test must actually exercise long-context behavior."""
+        max_seen = max(dispatch.max_seq_len for dispatch in kernel_dispatch_log)
+        assert max_seen > _E2B_SLIDING_WINDOW, (
+            f"long-context path was not exercised: max_seq_len={max_seen}, "
+            f"sliding_window={_E2B_SLIDING_WINDOW}"
+        )
+
     def test_ratio_matches_layer_types_config(
-        self, kernel_sliding_window_log: list[int]
+        self, kernel_dispatch_log: list[_KernelDispatch]
     ) -> None:
         """Sliding/full dispatch ratio matches the 28:7 layer_types split.
@@ -144,7 +251,7 @@ def test_ratio_matches_layer_types_config(
         ``layer_types`` and not stochastic.
         """
         # Act
-        counts = Counter(kernel_sliding_window_log)
+        counts = Counter(dispatch.sliding_window for dispatch in kernel_dispatch_log)
         sliding = counts[_E2B_SLIDING_WINDOW]
         full = counts[_NO_WINDOW]
         total = sliding + full
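As a quick sanity check on `_build_long_prompt`'s repeat-count arithmetic, here is a toy rerun under an invented whitespace tokenizer (the real test measures fragment length with the model's own tokenizer, so the actual repeat count will differ):

    import math

    class _ToyTokenizer:
        # Stand-in for the model tokenizer: one token per whitespace-separated word.
        def encode(self, text, add_special_tokens=False):
            return text.split()

    tok = _ToyTokenizer()
    fragment = "The capital of France is Paris. "
    first = len(tok.encode(text=fragment))                  # 6 toy tokens
    increment = len(tok.encode(text=fragment * 2)) - first  # 6 per extra repeat
    target = 512 + 128 + 1                                  # window + margin + 1
    repeats = 1 + math.ceil(max(0, target - first) / increment)
    prompt = fragment * repeats
    assert len(tok.encode(text=prompt)) > 512 + 128
    print(repeats, len(tok.encode(text=prompt)))            # 107 repeats, 642 tokens

The final re-encode-and-assert step mirrors the test's own guard: the prompt length is measured rather than assumed, since subword tokenizers do not always grow linearly with repeated text.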
