Skip to content

Commit ee5c348

Browse files
yueqiwi and isaykatsman authored
Readability improvements on variable names, type hints, doc strings, etc (#1)
* readability improvements on variable names and type hints * add .DS_Store to gitignore --------- Co-authored-by: Isay Katsman <isaykatsman@google.com>
1 parent 3ee2abe commit ee5c348

12 files changed

Lines changed: 407 additions & 348 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
**/dist/
1515
**/*.egg-info/
1616

17+
# Ignore DS_Store files
1718
**/.DS_Store

benchmarks/baselines_jax.py

Lines changed: 114 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from collections.abc import Callable
1518
import functools
1619
import jax
1720
from jax import jit
@@ -25,26 +28,22 @@
2528
# =============================================================================
2629

2730

28-
def _gather_beams(x, beam_indices):
31+
def _gather_beams(x: jnp.ndarray, beam_indices: jnp.ndarray) -> jnp.ndarray:
2932
"""Efficiently gathers beam data across a batch during the selection step.
3033
31-
This utility uses one-hot contraction for TPU efficiency to select the top-M
32-
sequences from the candidate pool while preserving batch and history
33-
dimensions.
34-
It is designed to be highly efficient on hardware accelerators by avoiding
35-
explicit Python loops.
34+
Uses one-hot contraction for TPU efficiency to select the top-M sequences
35+
from the candidate pool while preserving batch and history dimensions.
3636
3737
Args:
38-
x (jnp.ndarray): The source tensor to gather from, usually of shape
39-
(batch_size, old_beam_size, ...).
40-
beam_indices (jnp.ndarray): The indices of the beams to select, of shape
41-
(batch_size, new_beam_size).
38+
x: The source tensor to gather from.
39+
Shape: (batch_size, old_beam_size, ...).
40+
beam_indices: The indices of the beams to select.
41+
Shape: (batch_size, new_beam_size).
4242
4343
Returns:
44-
jnp.ndarray: The gathered tensor of shape (batch_size, new_beam_size,
45-
...).
44+
The gathered tensor.
45+
Shape: (batch_size, new_beam_size, ...).
4646
"""
47-
# Extract dimensions from tensor shapes
4847
_, old_beam_size = x.shape[:2]
4948

5049
# Create one-hot mask for the selection
@@ -72,35 +71,37 @@ def _gather_beams(x, beam_indices):
7271
),
7372
)
7473
def generic_beam_search_jax(
75-
model,
76-
key,
77-
mask_fn,
78-
batch_size,
79-
beam_size,
80-
tokens_per_beam,
81-
max_sample_len,
82-
start_token=0,
83-
):
74+
model: Callable,
75+
key: jax.random.PRNGKey,
76+
mask_fn: Callable,
77+
batch_size: int,
78+
beam_size: int,
79+
tokens_per_beam: int,
80+
max_sample_len: int,
81+
start_token: int = 0,
82+
) -> jnp.ndarray:
8483
"""A framework-agnostic harness for constrained beam search benchmarks.
8584
86-
This function executes a standard autoregressive decoding loop but applies
87-
a provided `mask_fn` at every step, including the initial root step. This
88-
allows for a fair comparison between different constraint enforcement
89-
algorithms (Trie, PPV, Hash) using identical selection and history management.
85+
Executes a standard autoregressive decoding loop but applies a provided
86+
`mask_fn` at every step, including the initial root step. This allows for a
87+
fair comparison between different constraint enforcement algorithms (Trie,
88+
PPV, Hash) using identical selection and history management.
9089
9190
Args:
92-
model (Callable): A callable (or object) that takes (input_ids, key) and
93-
returns (logits, new_key).
94-
key (jax.random.PRNGKey): PRNG key for mock logit generation.
95-
mask_fn (Callable): The baseline masking function to evaluate.
96-
batch_size (int): Number of parallel sequences.
97-
beam_size (int): Number of beams to maintain per sequence.
98-
tokens_per_beam (int): Number of candidate tokens to consider per beam.
99-
max_sample_len (int): The total decoding length (L).
100-
start_token (int): Token ID used to initiate decoding (BOS/PAD).
91+
model: A callable that takes (input_ids, key) and returns
92+
(logits, new_key).
93+
key: PRNG key for mock logit generation.
94+
mask_fn: The baseline masking function to evaluate. Must accept
95+
(logprobs, history, step) and return masked logprobs.
96+
batch_size: Number of parallel sequences (B).
97+
beam_size: Number of beams to maintain per sequence (M).
98+
tokens_per_beam: Number of candidate tokens to consider per beam.
99+
max_sample_len: The total decoding length (L).
100+
start_token: Token ID used to initiate decoding (BOS/PAD).
101101
102102
Returns:
103-
jnp.ndarray: The decoded sequences of shape (batch_size, beam_size, L).
103+
The decoded sequences.
104+
Shape: (batch_size, beam_size, max_sample_len).
104105
"""
105106
# --- 1. INITIAL STEP (Root) ---
106107
# Create start tokens to prime the model
@@ -175,8 +176,16 @@ def generic_beam_search_jax(
175176
# =============================================================================
176177

177178

178-
def build_trie(sids):
179-
"""Constructs a standard nested dictionary prefix tree on the CPU."""
179+
def build_trie(sids: np.ndarray) -> dict:
180+
"""Constructs a standard nested dictionary prefix tree on the CPU.
181+
182+
Args:
183+
sids: Array of Semantic IDs.
184+
Shape: (N, L).
185+
186+
Returns:
187+
The root node of the trie (nested dictionaries).
188+
"""
180189
trie = {}
181190
for sid in sids:
182191
node = trie
@@ -188,16 +197,26 @@ def build_trie(sids):
188197
return trie
189198

190199

191-
def make_trie_mask_fn(trie_root, vocab_size):
200+
def make_trie_mask_fn(
201+
trie_root: dict, vocab_size: int
202+
) -> Callable:
192203
"""Creates a masking function that calls back to a CPU-based Trie.
193204
194205
This baseline simulates the "pointer-chasing" behavior common in many
195206
production systems. It uses `jax.pure_callback` to pause accelerator
196207
execution and retrieve a validity mask from CPU memory at every step.
208+
209+
Args:
210+
trie_root: The root node of the CPU trie.
211+
vocab_size: The token vocabulary size (V).
212+
213+
Returns:
214+
A JAX-compatible masking function accepting
215+
(logprobs, token_buffer, step).
197216
"""
198217
all_tokens_set = set(range(vocab_size))
199218

200-
def python_callback(beams, step):
219+
def python_callback(beams: np.ndarray, step: int) -> np.ndarray:
201220
"""Internal Python logic to traverse the dictionary trie."""
202221
n = beams.shape[0]
203222
masks = np.zeros((n, vocab_size), dtype=bool)
@@ -244,8 +263,17 @@ def mask_fn(logprobs, token_buffer, step):
244263
# =============================================================================
245264

246265

247-
def build_hash_bitmap(sids):
248-
"""Creates a static hash bitmap (Bloom filter style) for valid prefixes."""
266+
def build_hash_bitmap(sids: np.ndarray) -> jnp.ndarray:
267+
"""Creates a static hash bitmap (Bloom filter style) for valid prefixes.
268+
269+
Args:
270+
sids: Array of Semantic IDs.
271+
Shape: (N, L).
272+
273+
Returns:
274+
The bitmap array on device.
275+
Shape: (SIZE // 8,) where SIZE = 2^30.
276+
"""
249277
BITMAP_BITS = 30
250278
SIZE = 1 << BITMAP_BITS
251279
MULTIPLIER = 0x9E371
@@ -266,11 +294,19 @@ def build_hash_bitmap(sids):
266294
return jnp.array(bitmap)
267295

268296

269-
def make_hash_bitmap_fn(bitmap):
297+
def make_hash_bitmap_fn(bitmap: jnp.ndarray) -> Callable:
270298
"""Creates a hash-based masking function optimized for JAX.
271299
272300
This baseline is accelerator-native but suffers from potential false
273301
positives and lacks the O(1) candidate selection of the STATIC kernel.
302+
303+
Args:
304+
bitmap: The precomputed hash bitmap.
305+
Shape: (SIZE // 8,).
306+
307+
Returns:
308+
A JAX-compatible masking function accepting
309+
(logprobs, token_buffer, step).
274310
"""
275311
BITMAP_BITS = 30
276312
SIZE = 1 << BITMAP_BITS
@@ -312,12 +348,32 @@ def hash_prefix(seq):
312348

313349

314350
@functools.partial(jit, static_argnames=["M", "step"])
315-
def ppv_batch_logic(flat_logprobs, history, step, sorted_sids, M):
351+
def ppv_batch_logic(
352+
flat_logprobs: jnp.ndarray,
353+
history: jnp.ndarray,
354+
step: int,
355+
sorted_sids: jnp.ndarray,
356+
M: int,
357+
) -> jnp.ndarray:
316358
"""Implements the PPV (Parallel Prefix Verification) algorithm.
317359
318360
PPV performs binary search across a sorted list of Semantic IDs to validate
319361
individual candidate extensions. This baseline corresponds to the method
320362
described in Ye et al. [30].
363+
364+
Args:
365+
flat_logprobs: Log-probabilities for the current step.
366+
Shape: (batch_size, vocab_size).
367+
history: Token history for each beam.
368+
Shape: (batch_size, max_sample_len).
369+
step: The current decoding step index.
370+
sorted_sids: The sorted constraint set.
371+
Shape: (N, L).
372+
M: Number of top-k candidates to verify.
373+
374+
Returns:
375+
Boolean validity mask.
376+
Shape: (batch_size, vocab_size).
321377
"""
322378
batch_size = flat_logprobs.shape[0]
323379
N = sorted_sids.shape[0]
@@ -391,8 +447,20 @@ def body_p2(state):
391447
return result_mask.at[jnp.arange(batch_size)[:, None], vt].max(final_valid)
392448

393449

394-
def make_ppv_mask_fn(sorted_sids, top_k):
395-
"""Creates a PPV-based masking function."""
450+
def make_ppv_mask_fn(
451+
sorted_sids: jnp.ndarray, top_k: int
452+
) -> Callable:
453+
"""Creates a PPV-based masking function.
454+
455+
Args:
456+
sorted_sids: The sorted constraint set on device.
457+
Shape: (N, L).
458+
top_k: Number of top candidates to verify per beam.
459+
460+
Returns:
461+
A JAX-compatible masking function accepting
462+
(logprobs, token_buffer, step).
463+
"""
396464

397465
def mask_fn(flat_logprobs, token_buffer, step):
398466
batch_size = flat_logprobs.shape[0]

benchmarks/run_branch_benchmark_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def run_real_csr_benchmark_oss(num_sequences=1_000_000, batch_beam=2, l_sid=8):
5959

6060
# Build STATIC Index (We only need the CSR components for the masking kernel)
6161
packed_csr_np, indptr_np, _, _, _, _ = build_static_index(
62-
sids_np, vocab_size, d=1
62+
sids_np, vocab_size, dense_lookup_layers=1
6363
)
6464

6565
# Move Index to Accelerator Memory (HBM)

benchmarks/run_branch_benchmark_pt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def run_real_csr_benchmark_gpu(num_sequences=1_000_000, batch_beam=2, l_sid=8):
5959

6060
# Build STATIC Index (We only need the CSR components for the masking kernel)
6161
packed_csr_np, indptr_np, _, _, _, _ = build_static_index(
62-
sids_np, vocab_size, d=1
62+
sids_np, vocab_size, dense_lookup_layers=1
6363
)
6464

6565
# Move Index to Device

benchmarks/run_comparative_benchmark_jax.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,19 @@ def run_benchmarks():
8989
if method == "STATIC":
9090
if "static" not in cache["structs"]:
9191
# Build STATIC Index (d=2 dense specialization)
92-
p_csr, indptr, lmb, s_m, d_m, d_s = build_static_index(
93-
cache["sids_np"], v, d=2
92+
p_csr, indptr, layer_max_branches, s_m, d_m, d_s = build_static_index(
93+
cache["sids_np"], v, dense_lookup_layers=2
9494
)
9595
cache["structs"]["static"] = (
9696
jnp.array(p_csr),
9797
jnp.array(indptr),
98-
lmb,
98+
layer_max_branches,
9999
jnp.array(s_m),
100100
jnp.array(d_m),
101101
jnp.array(d_s),
102102
)
103103

104-
packed_csr, indptr, lmb, start_mask, dense_mask, dense_states = cache[
104+
packed_csr, indptr, layer_max_branches, start_mask, dense_mask, dense_states = cache[
105105
"structs"
106106
]["static"]
107107

@@ -115,7 +115,7 @@ def run_benchmarks():
115115
0, # start_token
116116
SID_LEN,
117117
v,
118-
lmb,
118+
layer_max_branches,
119119
packed_csr,
120120
indptr,
121121
start_mask,
@@ -135,7 +135,7 @@ def run_benchmarks():
135135
0, # start_token
136136
SID_LEN,
137137
v,
138-
lmb,
138+
layer_max_branches,
139139
packed_csr,
140140
indptr,
141141
start_mask,

0 commit comments

Comments (0)