Commit e56c621
fix: add device_get in _gather_next_token_ids to return Python list
eec2620 introduced _gather_next_token_ids which gathers sharded JAX
arrays to replicated sharding, but did not convert the result to CPU.
This left next_token_ids as a JAX on-device array, causing downstream
unhashable type errors when token ids were used in set lookups
(check_finished). Add device_get + tolist at the source.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>1 parent 2ef0833 commit e56c621
1 file changed
Lines changed: 5 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
573 | 573 | | |
574 | 574 | | |
575 | 575 | | |
576 | | - | |
577 | | - | |
| 576 | + | |
| 577 | + | |
578 | 578 | | |
579 | 579 | | |
580 | 580 | | |
581 | 581 | | |
582 | 582 | | |
583 | 583 | | |
584 | 584 | | |
585 | | - | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
586 | 588 | | |
587 | 589 | | |
588 | 590 | | |
| |||
0 commit comments