1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
17+ from collections .abc import Callable
1518import functools
1619import jax
1720from jax import jit
2528# =============================================================================
2629
2730
28- def _gather_beams (x , beam_indices ) :
31+ def _gather_beams (x : jnp . ndarray , beam_indices : jnp . ndarray ) -> jnp . ndarray :
2932 """Efficiently gathers beam data across a batch during the selection step.
3033
31- This utility uses one-hot contraction for TPU efficiency to select the top-M
32- sequences from the candidate pool while preserving batch and history
33- dimensions.
34- It is designed to be highly efficient on hardware accelerators by avoiding
35- explicit Python loops.
34+ Uses one-hot contraction for TPU efficiency to select the top-M sequences
35+ from the candidate pool while preserving batch and history dimensions.
3636
3737 Args:
38- x (jnp.ndarray) : The source tensor to gather from, usually of shape
39- (batch_size, old_beam_size, ...).
40- beam_indices (jnp.ndarray) : The indices of the beams to select, of shape
41- (batch_size, new_beam_size).
38+ x: The source tensor to gather from.
39+ Shape: (batch_size, old_beam_size, ...).
40+ beam_indices: The indices of the beams to select.
41+ Shape: (batch_size, new_beam_size).
4242
4343 Returns:
44- jnp.ndarray: The gathered tensor of shape (batch_size, new_beam_size,
45- ...).
44+ The gathered tensor.
45+ Shape: (batch_size, new_beam_size, ...).
4646 """
47- # Extract dimensions from tensor shapes
4847 _ , old_beam_size = x .shape [:2 ]
4948
5049 # Create one-hot mask for the selection
@@ -72,35 +71,37 @@ def _gather_beams(x, beam_indices):
7271 ),
7372)
7473def generic_beam_search_jax (
75- model ,
76- key ,
77- mask_fn ,
78- batch_size ,
79- beam_size ,
80- tokens_per_beam ,
81- max_sample_len ,
82- start_token = 0 ,
83- ):
74+ model : Callable ,
75+ key : jax . random . PRNGKey ,
76+ mask_fn : Callable ,
77+ batch_size : int ,
78+ beam_size : int ,
79+ tokens_per_beam : int ,
80+ max_sample_len : int ,
81+ start_token : int = 0 ,
82+ ) -> jnp . ndarray :
8483 """A framework-agnostic harness for constrained beam search benchmarks.
8584
86- This function executes a standard autoregressive decoding loop but applies
87- a provided `mask_fn` at every step, including the initial root step. This
88- allows for a fair comparison between different constraint enforcement
89- algorithms (Trie, PPV, Hash) using identical selection and history management.
85+ Executes a standard autoregressive decoding loop but applies a provided
86+ `mask_fn` at every step, including the initial root step. This allows for a
87+ fair comparison between different constraint enforcement algorithms (Trie,
88+ PPV, Hash) using identical selection and history management.
9089
9190 Args:
92- model (Callable): A callable (or object) that takes (input_ids, key) and
93- returns (logits, new_key).
94- key (jax.random.PRNGKey): PRNG key for mock logit generation.
95- mask_fn (Callable): The baseline masking function to evaluate.
96- batch_size (int): Number of parallel sequences.
97- beam_size (int): Number of beams to maintain per sequence.
98- tokens_per_beam (int): Number of candidate tokens to consider per beam.
99- max_sample_len (int): The total decoding length (L).
100- start_token (int): Token ID used to initiate decoding (BOS/PAD).
91+ model: A callable that takes (input_ids, key) and returns
92+ (logits, new_key).
93+ key: PRNG key for mock logit generation.
94+ mask_fn: The baseline masking function to evaluate. Must accept
95+ (logprobs, history, step) and return masked logprobs.
96+ batch_size: Number of parallel sequences (B).
97+ beam_size: Number of beams to maintain per sequence (M).
98+ tokens_per_beam: Number of candidate tokens to consider per beam.
99+ max_sample_len: The total decoding length (L).
100+ start_token: Token ID used to initiate decoding (BOS/PAD).
101101
102102 Returns:
103- jnp.ndarray: The decoded sequences of shape (batch_size, beam_size, L).
103+ The decoded sequences.
104+ Shape: (batch_size, beam_size, max_sample_len).
104105 """
105106 # --- 1. INITIAL STEP (Root) ---
106107 # Create start tokens to prime the model
@@ -175,8 +176,16 @@ def generic_beam_search_jax(
175176# =============================================================================
176177
177178
178- def build_trie (sids ):
179- """Constructs a standard nested dictionary prefix tree on the CPU."""
179+ def build_trie (sids : np .ndarray ) -> dict :
180+ """Constructs a standard nested dictionary prefix tree on the CPU.
181+
182+ Args:
183+ sids: Array of Semantic IDs.
184+ Shape: (N, L).
185+
186+ Returns:
187+ The root node of the trie (nested dictionaries).
188+ """
180189 trie = {}
181190 for sid in sids :
182191 node = trie
@@ -188,16 +197,26 @@ def build_trie(sids):
188197 return trie
189198
190199
191- def make_trie_mask_fn (trie_root , vocab_size ):
200+ def make_trie_mask_fn (
201+ trie_root : dict , vocab_size : int
202+ ) -> Callable :
192203 """Creates a masking function that calls back to a CPU-based Trie.
193204
194205 This baseline simulates the "pointer-chasing" behavior common in many
195206 production systems. It uses `jax.pure_callback` to pause accelerator
196207 execution and retrieve a validity mask from CPU memory at every step.
208+
209+ Args:
210+ trie_root: The root node of the CPU trie.
211+ vocab_size: The token vocabulary size (V).
212+
213+ Returns:
214+ A JAX-compatible masking function accepting
215+ (logprobs, token_buffer, step).
197216 """
198217 all_tokens_set = set (range (vocab_size ))
199218
200- def python_callback (beams , step ) :
219+ def python_callback (beams : np . ndarray , step : int ) -> np . ndarray :
201220 """Internal Python logic to traverse the dictionary trie."""
202221 n = beams .shape [0 ]
203222 masks = np .zeros ((n , vocab_size ), dtype = bool )
@@ -244,8 +263,17 @@ def mask_fn(logprobs, token_buffer, step):
244263# =============================================================================
245264
246265
247- def build_hash_bitmap (sids ):
248- """Creates a static hash bitmap (Bloom filter style) for valid prefixes."""
266+ def build_hash_bitmap (sids : np .ndarray ) -> jnp .ndarray :
267+ """Creates a static hash bitmap (Bloom filter style) for valid prefixes.
268+
269+ Args:
270+ sids: Array of Semantic IDs.
271+ Shape: (N, L).
272+
273+ Returns:
274+ The bitmap array on device.
275+ Shape: (SIZE // 8,) where SIZE = 1 << 30 (BITMAP_BITS = 30).
276+ """
249277 BITMAP_BITS = 30
250278 SIZE = 1 << BITMAP_BITS
251279 MULTIPLIER = 0x9E371
@@ -266,11 +294,19 @@ def build_hash_bitmap(sids):
266294 return jnp .array (bitmap )
267295
268296
269- def make_hash_bitmap_fn (bitmap ) :
297+ def make_hash_bitmap_fn (bitmap : jnp . ndarray ) -> Callable :
270298 """Creates a hash-based masking function optimized for JAX.
271299
272300 This baseline is accelerator-native but suffers from potential false
273301 positives and lacks the O(1) candidate selection of the STATIC kernel.
302+
303+ Args:
304+ bitmap: The precomputed hash bitmap.
305+ Shape: (SIZE // 8,).
306+
307+ Returns:
308+ A JAX-compatible masking function accepting
309+ (logprobs, token_buffer, step).
274310 """
275311 BITMAP_BITS = 30
276312 SIZE = 1 << BITMAP_BITS
@@ -312,12 +348,32 @@ def hash_prefix(seq):
312348
313349
314350@functools .partial (jit , static_argnames = ["M" , "step" ])
315- def ppv_batch_logic (flat_logprobs , history , step , sorted_sids , M ):
351+ def ppv_batch_logic (
352+ flat_logprobs : jnp .ndarray ,
353+ history : jnp .ndarray ,
354+ step : int ,
355+ sorted_sids : jnp .ndarray ,
356+ M : int ,
357+ ) -> jnp .ndarray :
316358 """Implements the PPV (Parallel Prefix Verification) algorithm.
317359
318360 PPV performs binary search across a sorted list of Semantic IDs to validate
319361 individual candidate extensions. This baseline corresponds to the method
320362 described in Ye et al. [30].
363+
364+ Args:
365+ flat_logprobs: Log-probabilities for the current step.
366+ Shape: (batch_size, vocab_size).
367+ history: Token history for each beam.
368+ Shape: (batch_size, max_sample_len).
369+ step: The current decoding step index.
370+ sorted_sids: The sorted constraint set.
371+ Shape: (N, L).
372+ M: Number of top-k candidates to verify.
373+
374+ Returns:
375+ Boolean validity mask.
376+ Shape: (batch_size, vocab_size).
321377 """
322378 batch_size = flat_logprobs .shape [0 ]
323379 N = sorted_sids .shape [0 ]
@@ -391,8 +447,20 @@ def body_p2(state):
391447 return result_mask .at [jnp .arange (batch_size )[:, None ], vt ].max (final_valid )
392448
393449
394- def make_ppv_mask_fn (sorted_sids , top_k ):
395- """Creates a PPV-based masking function."""
450+ def make_ppv_mask_fn (
451+ sorted_sids : jnp .ndarray , top_k : int
452+ ) -> Callable :
453+ """Creates a PPV-based masking function.
454+
455+ Args:
456+ sorted_sids: The sorted constraint set on device.
457+ Shape: (N, L).
458+ top_k: Number of top candidates to verify per beam.
459+
460+ Returns:
461+ A JAX-compatible masking function accepting
462+ (logprobs, token_buffer, step).
463+ """
396464
397465 def mask_fn (flat_logprobs , token_buffer , step ):
398466 batch_size = flat_logprobs .shape [0 ]
0 commit comments