Commit e56c621

JamesBrianD and claude committed
fix: add device_get in _gather_next_token_ids to return Python list
eec2620 introduced _gather_next_token_ids, which gathers sharded JAX arrays to replicated sharding but did not transfer the result to the host (CPU). This left next_token_ids as an on-device JAX array, causing downstream "unhashable type" errors when token ids were used in set lookups (check_finished). Add device_get + tolist at the source.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2ef0833 commit e56c621
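
The failure mode described in the commit message can be reproduced in isolation. The following is a minimal sketch, not code from the repository: `stop_token_ids` and `contains_stop_token` are hypothetical stand-ins for whatever `check_finished` actually does, assumed here to be a plain set-membership test.

```python
import jax
import jax.numpy as jnp

# Hypothetical stop-token set; stands in for the set that check_finished consults.
stop_token_ids = {2, 7}

# An on-device JAX array, like the value returned before this commit.
next_token_ids = jnp.array([5, 7, 9], dtype=jnp.int32)


def contains_stop_token(token_ids, stop_ids):
    """Return True if any token id is a member of stop_ids."""
    return any(tok in stop_ids for tok in token_ids)


# Iterating a JAX array yields 0-d JAX arrays, which are not hashable,
# so the set lookup raises TypeError.
try:
    contains_stop_token(next_token_ids, stop_token_ids)
    raised = False
except TypeError:
    raised = True

# Copying to the host and calling tolist() yields plain Python ints,
# which hash normally.
host_ids = jax.device_get(next_token_ids).tolist()
found = contains_stop_token(host_ids, stop_token_ids)
```

Converting once at the source, as the commit does, means callers do not each need to remember the conversion.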

1 file changed

Lines changed: 5 additions & 3 deletions

python/sgl_jax/srt/managers/scheduler.py

@@ -573,16 +573,18 @@ def _gather_logits_output(self, logits_output: LogitsProcessorOutput) -> LogitsP

         return logits_output

-    def _gather_next_token_ids(self, next_token_ids: jax.Array) -> jax.Array:
-        """Gather sharded next_token_ids to replicated sharding."""
+    def _gather_next_token_ids(self, next_token_ids: jax.Array) -> list[int]:
+        """Gather sharded next_token_ids to replicated sharding and convert to Python list."""
         from jax.sharding import NamedSharding, PartitionSpec

         if next_token_ids is None:
             return None

         replicated_sharding = NamedSharding(self.mesh, PartitionSpec())
         gather_fn = jax.jit(lambda x: x, out_shardings=replicated_sharding)
-        return gather_fn(next_token_ids)
+        gathered = gather_fn(next_token_ids)
+        # Convert to Python list of ints for downstream compatibility
+        return jax.device_get(gathered).tolist()

     def _select_round_robin_dp(self) -> int:
         dp_rank = self.dp_round_robin_counter % self.dp_size
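
For readers who want to try the gather-then-convert pattern outside the scheduler, here is a standalone sketch under stated assumptions. The free function `gather_to_host_list`, the one-dimensional mesh, and the axis name `"data"` are illustrative choices, not part of the repository; the real method uses `self.mesh`, whose layout is not shown in this diff.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def gather_to_host_list(mesh, token_ids):
    """Replicate a (possibly sharded) array across the mesh, then return a Python list."""
    if token_ids is None:
        return None
    # An empty PartitionSpec means "fully replicated" on every device of the mesh.
    replicated = NamedSharding(mesh, PartitionSpec())
    # An identity function jitted with out_shardings forces the resharding.
    gather_fn = jax.jit(lambda x: x, out_shardings=replicated)
    gathered = gather_fn(token_ids)
    # device_get copies to host memory as a NumPy array; tolist() yields Python ints.
    return jax.device_get(gathered).tolist()


# Assumed setup: a 1-D mesh over whatever devices are available (one CPU is enough).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard an example array along the "data" axis; the length is a multiple of the device count.
n = len(devices)
sharded = jax.device_put(
    jnp.arange(4 * n, dtype=jnp.int32),
    NamedSharding(mesh, PartitionSpec("data")),
)

result = gather_to_host_list(mesh, sharded)
```

As in the diff, a `None` input still returns `None`, so callers relying on the `list[int]` annotation may want to handle that case explicitly.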
