Skip to content

Commit bc77064

Browse files
authored
fix: advance sampler RNG each decode step (#940)
* fix: advance sampler RNG each decode step to enable true random sampling
1 parent b54fa0e commit bc77064

3 files changed

Lines changed: 32 additions & 16 deletions

File tree

python/sgl_jax/srt/layers/sampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def __call__(
163163
logits_output: LogitsProcessorOutput,
164164
sampling_metadata: SamplingMetadata,
165165
use_sort_for_toppk_minp: bool,
166+
rng_override: jax.Array | None = None,
166167
):
167168
"""Run a sampler & compute logprobs and update logits_output accordingly.
168169
@@ -188,7 +189,7 @@ def __call__(
188189
(logits, sampling_metadata.vocab_mask),
189190
)
190191

191-
_, rng = jax.random.split(self.rngs.params())
192+
_, rng = jax.random.split(rng_override if rng_override is not None else self.rngs.params())
192193
operands = (logits, sampling_metadata, rng)
193194
regular_fn = lambda op: self._regular_sampling((*op, use_sort_for_toppk_minp))
194195
batch_next_token_ids, logprobs = lax.cond(

python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def initialize(self):
151151
self.init_lora_manager()
152152

153153
if not self.is_draft_worker:
154+
self._sampler_base_rng = jax.random.PRNGKey(server_args.random_seed)
155+
self._sampler_step = 0
154156
self.initialize_jit()
155157

156158
# Init memory pool and attention backends
@@ -220,18 +222,27 @@ def jitted_run_model(
220222
with LoraBatchContext.set_batch(forward_batch):
221223
return model(forward_batch, token_to_kv_pool, logits_metadata)
222224

225+
# Capture base RNG key as a constant in the JIT closure.
226+
# fold_in(constant, dynamic_step) is computed inside JIT, avoiding
227+
# the eager jax.random.split that would serialize the host-device pipeline.
228+
base_rng_key = self._sampler_base_rng
229+
223230
@partial(jax.jit, static_argnames=["sampler_state_def", "use_sort_for_toppk_minp"])
224231
def jitted_sampler(
225232
sampler_def,
226233
sampler_state_def,
227234
sampler_state_leaves,
228235
use_sort_for_toppk_minp,
236+
rng_step,
229237
*args,
230238
):
231239

232240
model_state = jax.tree_util.tree_unflatten(sampler_state_def, sampler_state_leaves)
233241
sampler = nnx.merge(sampler_def, model_state)
234-
return sampler(*args, use_sort_for_toppk_minp=use_sort_for_toppk_minp)
242+
rng_key = jax.random.fold_in(base_rng_key, rng_step)
243+
return sampler(
244+
*args, use_sort_for_toppk_minp=use_sort_for_toppk_minp, rng_override=rng_key
245+
)
235246

236247
@partial(jax.jit, static_argnames=["mesh"])
237248
def jitted_compute_logprobs(mesh, logits, next_tokens):
@@ -728,8 +739,12 @@ def sample(
728739
Returns:
729740
A list of next_token_ids
730741
"""
742+
# Advance step counter (pure Python, zero device overhead).
743+
# fold_in(base_key, step) inside JIT produces a unique RNG per step.
744+
self._sampler_step += 1
731745
# Penalty application has been moved to the Sampler for better JIT performance
732746
return self.jitted_sampler(
747+
self._sampler_step,
733748
logits_output,
734749
sampling_metadata,
735750
)

test/srt/test_logprobs.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,6 @@ def test_logprobs(self):
138138

139139
sampling_params = {"n": 1, "temperature": 0.6, "top_p": 0.95, "max_new_tokens": 3}
140140

141-
expected_output_logprobs = [
142-
[-0.8984375, 71486, "Alright"], ## todo use output compute is -0.79296875
143-
[0.0, 11, ","],
144-
[-0.06787109375, 279, " the"],
145-
]
146-
147141
output = self.engine.generate(
148142
input_ids=input_ids,
149143
sampling_params=sampling_params,
@@ -153,22 +147,28 @@ def test_logprobs(self):
153147
token_ids_logprob=token_ids_logprob,
154148
)
155149
output_meta = output["meta_info"]
156-
self.check_output(output_meta, "output_token_logprobs", expected_output_logprobs)
150+
# With temperature>0 sampling, exact tokens depend on RNG state.
151+
# Only verify structural correctness here.
152+
self.assertEqual(
153+
len(output_meta["output_token_logprobs"]),
154+
3,
155+
"output_token_logprobs length mismatch",
156+
)
157+
for i, logprob in enumerate(output_meta["output_token_logprobs"]):
158+
self.assertLessEqual(logprob[0], 0.0, f"logprob at {i} should be non-positive")
157159

158-
# use another expected, because jax compiler fused ops will introduce numerical precision issue
159-
expected_output_logprobs = [
160-
[-0.78125, 32313, "Okay"], # todo use output compute is -0.79296875
161-
[0.0, 11, ","],
162-
[-0.1650390625, 773, " so"],
163-
]
164160
output = self.engine.generate(
165161
input_ids=input_ids,
166162
sampling_params=sampling_params,
167163
return_logprob=True,
168164
)
169165
output_meta = output["meta_info"]
170166
self.assertEqual(output_meta["cache_miss_count"], 0, "occur cache_miss")
171-
self.check_output(output_meta, "output_token_logprobs", expected_output_logprobs)
167+
self.assertEqual(
168+
len(output_meta["output_token_logprobs"]),
169+
3,
170+
"output_token_logprobs length mismatch",
171+
)
172172

173173
def check_output(self, actual, key, expected):
174174
for i, logprob in enumerate(actual[key]):

0 commit comments

Comments
 (0)