Skip to content

Commit d5058d9

Browse files
JamesBrianDclaude
andcommitted
fix: guard device_get with hasattr check for already-converted lists
next_token_ids may already be a Python list (e.g. from overlap path or spec decoding). Check for .device attribute before calling jax.device_get(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c5f6f4e commit d5058d9

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def process_batch_result_prefill(
7777
logits_output = self._gather_logits_output(logits_output)
7878

7979
# Move next_token_ids to cpu as Python ints
80-
next_token_ids = jax.device_get(next_token_ids).tolist()
80+
if hasattr(next_token_ids, "device"):
81+
next_token_ids = jax.device_get(next_token_ids).tolist()
8182

8283
# Move next_token_ids and logprobs to cpu
8384
if batch.return_output_logprob_only and logits_output.next_token_logprobs is not None:
@@ -312,7 +313,8 @@ def process_batch_result_decode(
312313
logits_output = self._gather_logits_output(logits_output)
313314

314315
# Move next_token_ids to cpu as Python ints
315-
next_token_ids = jax.device_get(next_token_ids).tolist()
316+
if hasattr(next_token_ids, "device"):
317+
next_token_ids = jax.device_get(next_token_ids).tolist()
316318

317319
# spec decoding handles output logprobs inside verify process.
318320
if batch.return_logprob or batch.return_output_logprob_only:

0 commit comments

Comments
 (0)