@@ -23,8 +23,11 @@
 
 from __future__ import annotations
 
+import math
 import os
+import re
 from collections import Counter
+from dataclasses import dataclass
 
 import pytest
 
@@ -41,23 +44,109 @@
 
 _NO_WINDOW = -1
 
-# Position of ``sliding_window`` in ``paged_attention_primitive``'s
-# positional signature (see ``attention_sdpa.py:489-510``).
-_SLIDING_WINDOW_ARG_INDEX = 11
+# The bug only manifests once the prompt exceeds Gemma4 E2B's 512-token
+# sliding window. Build a prompt with a comfortable margin so the kernel
+# must actually decide whether to enforce the window.
+_LONG_CONTEXT_TOKEN_MARGIN = 128
+_LONG_CONTEXT_MIN_TOKENS = _E2B_SLIDING_WINDOW + _LONG_CONTEXT_TOKEN_MARGIN
+_LONG_CONTEXT_TARGET_TOKENS = _LONG_CONTEXT_MIN_TOKENS + 1
+_FRAGMENT_REPEAT_SAMPLE_COUNT = 2
+_MAX_MODEL_LEN = 1024
+_MAX_TOKENS = 1
+_PROMPT_FRAGMENT = "The capital of France is Paris. "
 
 # Ratio tolerance: layer_types is a config constant, but prefill and
 # decode may dispatch slightly different counts across forwards, so we
 # accept a 1% slack.
 _RATIO_TOLERANCE = 0.01
 
+_NB_PARAM_RE = re.compile(r"([A-Za-z_]\w*)\s*:")
+
+
+@dataclass(frozen=True)
+class _KernelDispatch:
+    sliding_window: int
+    max_seq_len: int
+
+
+def _nanobind_param_indices(fn, *names: str) -> dict[str, int]:
+    """Resolve parameter positions from nanobind's runtime signature metadata."""
+    overloads = getattr(fn, "__nb_signature__", ())
+    if not overloads:
+        raise RuntimeError("paged_attention_primitive is missing __nb_signature__")
+
+    signature_text = overloads[0][0]
+    params_text = signature_text.partition("(")[2].rpartition(")")[0]
+    param_names = _NB_PARAM_RE.findall(params_text)
+
+    indices: dict[str, int] = {}
+    for name in names:
+        if name not in param_names:
+            raise RuntimeError(
+                f"parameter {name!r} not found in nanobind signature: {signature_text}"
+            )
+        indices[name] = param_names.index(name)
+    return indices
+
+
+def _get_call_arg(
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
+    param_indices: dict[str, int],
+    name: str,
+) -> object:
+    """Read a native-op argument by name from positional/keyword call data."""
+    index = param_indices[name]
+    if len(args) > index:
+        return args[index]
+    if name in kwargs:
+        return kwargs[name]
+    raise RuntimeError(f"paged_attention_primitive call missing {name!r}")
+
+
+def _build_long_prompt(tokenizer) -> str:
+    """Return a prompt whose tokenized length exceeds Gemma4's window size."""
+    first_fragment_token_count = len(
+        tokenizer.encode(text=_PROMPT_FRAGMENT, add_special_tokens=False)
+    )
+    if first_fragment_token_count <= 0:
+        raise AssertionError("prompt fragment must tokenize to at least one token")
+
+    repeated_fragment_sample = _PROMPT_FRAGMENT * _FRAGMENT_REPEAT_SAMPLE_COUNT
+    repeat_increment_token_count = (
+        len(tokenizer.encode(text=repeated_fragment_sample, add_special_tokens=False))
+        - first_fragment_token_count
+    )
+    if repeat_increment_token_count <= 0:
+        raise AssertionError(
+            "prompt fragment repeat must increase token count by at least one token"
+        )
+
+    additional_repeat_count = math.ceil(
+        max(0, _LONG_CONTEXT_TARGET_TOKENS - first_fragment_token_count)
+        / repeat_increment_token_count
+    )
+    repeat_count = 1 + additional_repeat_count
+    prompt = _PROMPT_FRAGMENT * repeat_count
+    token_count = len(tokenizer.encode(text=prompt, add_special_tokens=False))
+    if token_count > _LONG_CONTEXT_MIN_TOKENS:
+        return prompt
+    raise AssertionError(
+        "failed to construct a prompt longer than Gemma4's sliding window: "
+        f"repeat_count={repeat_count}, token_count={token_count}, "
+        f"first_fragment_token_count={first_fragment_token_count}, "
+        f"repeat_increment_token_count={repeat_increment_token_count}, "
+        f"target>{_LONG_CONTEXT_MIN_TOKENS}"
+    )
+
 
 @pytest.fixture(scope="module")
-def kernel_sliding_window_log() -> list[int]:
+def kernel_dispatch_log() -> list[_KernelDispatch]:
     """Run one Gemma4 inference with a spy on ``paged_attention_primitive``.
 
-    Returns the list of ``sliding_window`` ints passed to every kernel
-    dispatch during the inference. Skips if the model path env var is
-    unset.
+    Returns the ``sliding_window`` and ``max_seq_len`` seen by every
+    kernel dispatch during one long-context inference. Skips if the
+    model path env var is unset.
     """
     model_path = os.environ.get(MODEL_ENV)
     if not model_path:
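
A quick illustration of what `_nanobind_param_indices` extracts. The signature string below is hypothetical (the real `paged_attention_primitive` takes many more parameters); it only mirrors the shape the helper reads from `overloads[0][0]`:

```python
import re

_NB_PARAM_RE = re.compile(r"([A-Za-z_]\w*)\s*:")

# Hypothetical nanobind signature text -- parameter names are invented here.
sig = "paged_attention_primitive(out: Tensor, query: Tensor, sliding_window: int, max_seq_len: int) -> None"
params_text = sig.partition("(")[2].rpartition(")")[0]
names = _NB_PARAM_RE.findall(params_text)
print(names)                          # ['out', 'query', 'sliding_window', 'max_seq_len']
print(names.index("sliding_window"))  # 2 -- the position _get_call_arg would read
```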
@@ -76,24 +165,32 @@ def kernel_sliding_window_log() -> list[int]:
 
     ops = get_ops()
     orig_fn = ops.paged_attention_primitive
-    captured: list[int] = []
+    param_indices = _nanobind_param_indices(
+        orig_fn, "sliding_window", "max_seq_len"
+    )
+    captured: list[_KernelDispatch] = []
 
     def spy(*args, **kwargs):
-        sw = (
-            args[_SLIDING_WINDOW_ARG_INDEX]
-            if len(args) > _SLIDING_WINDOW_ARG_INDEX
-            else kwargs.get("sliding_window")
+        captured.append(
+            _KernelDispatch(
+                sliding_window=int(
+                    _get_call_arg(args, kwargs, param_indices, "sliding_window")
+                ),
+                max_seq_len=int(
+                    _get_call_arg(args, kwargs, param_indices, "max_seq_len")
+                ),
+            )
         )
-        captured.append(sw)
         return orig_fn(*args, **kwargs)
 
     mp.setattr(ops, "paged_attention_primitive", spy)
 
     from vllm import LLM, SamplingParams
 
-    llm = LLM(model=model_path, max_model_len=512, max_num_seqs=1)
-    sp = SamplingParams(temperature=0, max_tokens=5, ignore_eos=True)
-    llm.generate(["The capital of France is"], sp)
+    llm = LLM(model=model_path, max_model_len=_MAX_MODEL_LEN, max_num_seqs=1)
+    prompt = _build_long_prompt(llm.get_tokenizer())
+    sp = SamplingParams(temperature=0, max_tokens=_MAX_TOKENS, ignore_eos=True)
+    llm.generate([prompt], sp)
 
     return captured
 
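
For a sense of the sizing arithmetic `_build_long_prompt` performs for this fixture, here is a worked example with assumed token counts (the fragment's real tokenized length depends on the tokenizer):

```python
import math

first = 8               # assumed first_fragment_token_count
step = 8                # assumed repeat_increment_token_count
target = 512 + 128 + 1  # _LONG_CONTEXT_TARGET_TOKENS with a 512-token window

additional = math.ceil(max(0, target - first) / step)  # ceil(633 / 8) = 80
repeat_count = 1 + additional                          # 81
token_count = first + additional * step                # 648 > 640, clears the window
```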
@@ -103,26 +200,26 @@ class TestGemma4KernelReceivesPerLayerSlidingWindow:
     """Kernel-level assertions on the sliding_window values dispatched."""
 
     def test_only_expected_window_values_appear(
-        self, kernel_sliding_window_log: list[int]
+        self, kernel_dispatch_log: list[_KernelDispatch]
     ) -> None:
         """No stray values leak from wiring errors."""
         # Act
         unexpected = {
-            w
-            for w in kernel_sliding_window_log
-            if w not in (_E2B_SLIDING_WINDOW, _NO_WINDOW)
+            dispatch.sliding_window
+            for dispatch in kernel_dispatch_log
+            if dispatch.sliding_window not in (_E2B_SLIDING_WINDOW, _NO_WINDOW)
         }
         # Assert
         assert not unexpected, (
             f"kernel received unexpected sliding_window values: {unexpected}"
         )
 
     def test_both_sliding_and_full_layers_dispatch(
-        self, kernel_sliding_window_log: list[int]
+        self, kernel_dispatch_log: list[_KernelDispatch]
     ) -> None:
         """``sliding_window=512`` and ``-1`` both appear."""
         # Act
-        counts = Counter(kernel_sliding_window_log)
+        counts = Counter(dispatch.sliding_window for dispatch in kernel_dispatch_log)
         # Assert
         assert counts[_E2B_SLIDING_WINDOW] > 0, (
             "sliding layers never received their window -- enforcement is "
@@ -132,8 +229,18 @@ def test_both_sliding_and_full_layers_dispatch(
             "full layers never received -1 -- they may be incorrectly getting a window"
         )
 
+    def test_kernel_sees_context_longer_than_the_window(
+        self, kernel_dispatch_log: list[_KernelDispatch]
+    ) -> None:
+        """The regression test must actually exercise long-context behavior."""
+        max_seen = max(dispatch.max_seq_len for dispatch in kernel_dispatch_log)
+        assert max_seen > _E2B_SLIDING_WINDOW, (
+            f"long-context path was not exercised: max_seq_len={max_seen}, "
+            f"sliding_window={_E2B_SLIDING_WINDOW}"
+        )
+
     def test_ratio_matches_layer_types_config(
-        self, kernel_sliding_window_log: list[int]
+        self, kernel_dispatch_log: list[_KernelDispatch]
     ) -> None:
         """Sliding/full dispatch ratio matches the 28:7 layer_types split.
 
@@ -144,7 +251,7 @@ def test_ratio_matches_layer_types_config(
         ``layer_types`` and not stochastic.
         """
         # Act
-        counts = Counter(kernel_sliding_window_log)
+        counts = Counter(dispatch.sliding_window for dispatch in kernel_dispatch_log)
         sliding = counts[_E2B_SLIDING_WINDOW]
         full = counts[_NO_WINDOW]
         total = sliding + full