Skip to content

Commit 04345ce

Browse files
committed
minor fix
1 parent f33f6b8 commit 04345ce

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,10 +1011,10 @@
10111011
"# 1. Compile kernels (once)\n",
10121012
"gelu_module = gen_act_and_mul_module('gelu_tanh').build_and_load()\n",
10131013
"rope_module = gen_rope_module().build_and_load()\n",
1014+
"decode_local_module = gen_single_decode_module(..., use_sliding_window=True, ...).build_and_load()\n",
1015+
"decode_global_module = gen_single_decode_module(..., use_sliding_window=False, ...).build_and_load()\n",
10141016
"prefill_local = gen_single_prefill_module('fa2', ..., use_sliding_window=True).build_and_load()\n",
10151017
"prefill_global = gen_single_prefill_module('fa2', ..., use_sliding_window=False).build_and_load()\n",
1016-
"decode_local_module = gen_decode_jit_spec('bfloat16', 256, use_sliding_window=True).build_and_load()\n",
1017-
"decode_global_module = gen_decode_jit_spec('bfloat16', 256, use_sliding_window=False).build_and_load()\n",
10181018
"\n",
10191019
"# 2. Prefill: FlashInfer causal attention over all prompt tokens → KV-cache\n",
10201020
"h_last, kv_caches = prefill(prompt_ids)\n",

0 commit comments

Comments
 (0)