Skip to content

Commit 50fdbe2

Browse files
committed
[levanter] Fix top-p cutoff boundary
1 parent aa42218 commit 50fdbe2

2 files changed

Lines changed: 27 additions & 2 deletions

File tree

lib/levanter/src/levanter/layers/sampler.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,9 +114,12 @@ def _apply_top_p(
114 114
top_p_array = top_ps
115 115
threshold = jnp.clip(jnp.asarray(top_p_array, dtype=jnp.float32), min=0.0, max=1.0)[..., None]
116 116

117-
keep_sorted = cumulative_probs <= threshold
117+
# Keep the smallest prefix whose cumulative mass reaches the threshold.
118+
# The cutoff-crossing token should remain eligible, but we should not
119+
# include the next token when the threshold is met exactly.
120+
keep_sorted = cumulative_probs < threshold
118 121
keep_sorted = jnp.concatenate(
119-
[jnp.ones_like(keep_sorted[..., :1], dtype=bool), keep_sorted[..., 1:]],
122+
[jnp.ones_like(keep_sorted[..., :1], dtype=bool), keep_sorted[..., :-1]],
120 123
axis=-1,
121 124
)
122 125
filtered_sorted_logits = jnp.where(keep_sorted, sorted_logits, -jnp.inf)

lib/levanter/tests/test_sampler.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,3 +23,25 @@ def test_sampler_top_p_keeps_only_the_nucleus_head():
23 23

24 24
assert int(token.array) == 0
25 25
assert float(log_prob.array) == pytest.approx(0.0)
26+
27+
28+
def test_sampler_top_p_keeps_cutoff_crossing_token():
29+
vocab = hax.Axis("vocab", 3)
30+
sampler = Sampler(vocab)
31+
logits = hax.named(jnp.log(jnp.array([0.4, 0.35, 0.25], dtype=jnp.float32)), (vocab,))
32+
33+
masked_logits = sampler._apply_top_p(logits, jnp.array(0.6, dtype=jnp.float32))
34+
35+
assert jnp.isfinite(masked_logits.array[:2]).all()
36+
assert jnp.isneginf(masked_logits.array[2])
37+
38+
39+
def test_sampler_top_p_does_not_overshoot_exact_threshold():
40+
vocab = hax.Axis("vocab", 3)
41+
sampler = Sampler(vocab)
42+
logits = hax.named(jnp.log(jnp.array([0.4, 0.35, 0.25], dtype=jnp.float32)), (vocab,))
43+
44+
masked_logits = sampler._apply_top_p(logits, jnp.array(0.4, dtype=jnp.float32))
45+
46+
assert jnp.isfinite(masked_logits.array[0])
47+
assert jnp.isneginf(masked_logits.array[1:]).all()

0 commit comments

Comments
 (0)