|
29 | 29 | namespace flat { |
30 | 30 |
|
31 | 31 | // Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai |
32 | | -template void launch_delta_rule_prefill_kernel_gbai< |
33 | | - /*IsGVA=*/{{ is_gva }}, |
34 | | - /*NeedsBeta=*/{{ needs_beta }}, |
35 | | - /*NeedsAlpha=*/{{ needs_alpha }}, |
36 | | - /*InitStateFromInput=*/{{ init_state }}, |
37 | | - cutlass::arch::Sm90, |
38 | | - {{ dtype }}, {{ dtype }}, float>( |
39 | | - cudaStream_t stream, {{ dtype }}* output, float* output_state, |
40 | | - {{ dtype }} const* q, {{ dtype }} const* k, {{ dtype }} const* v, |
41 | | - float const* input_state, float const* alpha, float const* beta, |
42 | | - int64_t const* cu_seqlens, int32_t num_seqs, int32_t num_q_heads, |
43 | | - int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, |
44 | | - int32_t head_size, int64_t total_seqlen, float scale, int32_t sm_count); |
| 32 | +// Parameter types must exactly match the extern template declaration in prefill_kernel_delta_rule_sm90_extern.inc |
| 33 | +template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>( |
| 34 | + cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*, |
| 35 | + float const*, float const*, float const*, int64_t const*, int32_t, int32_t, |
| 36 | + int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t); |
45 | 37 |
|
46 | 38 | } // namespace flat |
0 commit comments