Commit 387359e
Enable get_last_token for batch>1 prefill, single-call paged_fill_cache
- model.py: When batch_size>1, extract each user's 32-token last-token
tile via ttnn.slice and concat to [1,1,B*32,H] before norm+lm_head.
Removes the batch>1 get_last_token=-1 override in ttnn_prefill_forward.
- attention/prefill.py: Replace per-user paged_fill_cache loop with
single-call reshape approach (flatten batch into seq dim, heads into
last dim, flatten page table). Matches llama_70b_galaxy pattern.
- text_demo.py: Remove batch>1 get_last_token override, use
seq_len_per_user=32 when get_last_token is active.
batch128 users_per_row_per_iter=2 TTFT: 218ms -> 99ms (get_last_token fix)
Compile time: 5.24s -> 3.16s (single-call paged_fill_cache)
batch128 users_per_row_per_iter=1: unchanged at 91ms1 parent 2c88901 commit 387359e
File tree
3 files changed
+45
-27
lines changed- models/demos/gpt_oss
- demo
- tt
- attention
3 files changed
+45
-27
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
599 | 599 | | |
600 | 600 | | |
601 | 601 | | |
602 | | - | |
603 | | - | |
604 | | - | |
605 | | - | |
606 | | - | |
607 | | - | |
608 | | - | |
609 | | - | |
610 | | - | |
611 | | - | |
612 | | - | |
| 602 | + | |
| 603 | + | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
| 609 | + | |
613 | 610 | | |
614 | 611 | | |
615 | 612 | | |
| |||
669 | 666 | | |
670 | 667 | | |
671 | 668 | | |
672 | | - | |
| 669 | + | |
673 | 670 | | |
674 | 671 | | |
675 | 672 | | |
| |||
753 | 750 | | |
754 | 751 | | |
755 | 752 | | |
756 | | - | |
| 753 | + | |
757 | 754 | | |
758 | 755 | | |
759 | 756 | | |
| |||
804 | 801 | | |
805 | 802 | | |
806 | 803 | | |
807 | | - | |
| 804 | + | |
| 805 | + | |
| 806 | + | |
| 807 | + | |
808 | 808 | | |
809 | 809 | | |
810 | 810 | | |
| |||
818 | 818 | | |
819 | 819 | | |
820 | 820 | | |
| 821 | + | |
821 | 822 | | |
822 | 823 | | |
| 824 | + | |
823 | 825 | | |
824 | 826 | | |
825 | 827 | | |
826 | 828 | | |
827 | | - | |
| 829 | + | |
828 | 830 | | |
829 | 831 | | |
830 | 832 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
99 | 99 | | |
100 | 100 | | |
101 | 101 | | |
102 | | - | |
103 | | - | |
104 | | - | |
105 | | - | |
106 | | - | |
107 | | - | |
108 | | - | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
109 | 111 | | |
110 | 112 | | |
111 | 113 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
278 | 278 | | |
279 | 279 | | |
280 | 280 | | |
281 | | - | |
282 | 281 | | |
283 | 282 | | |
284 | | - | |
285 | | - | |
286 | | - | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
287 | 301 | | |
288 | 302 | | |
289 | 303 | | |
| |||
366 | 380 | | |
367 | 381 | | |
368 | 382 | | |
369 | | - | |
| 383 | + | |
370 | 384 | | |
371 | 385 | | |
372 | 386 | | |
| |||
0 commit comments