-
Notifications
You must be signed in to change notification settings - Fork 931
Expand file tree
/
Copy path gdn_prefill_sm120_kernel_inst.jinja
More file actions
39 lines (34 loc) · 1.91 KB
/
gdn_prefill_sm120_kernel_inst.jinja
File metadata and controls
39 lines (34 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Auto-generated file for separate compilation of GDN prefill kernel variants.
// Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }},
// needs_alpha={{ needs_alpha }}, init_state={{ init_state }},
// enable_checkpointing={{ enable_checkpointing }}
// CUDA type definitions for half and nv_bfloat16
#include <cuda_bf16.h>
#include <cuda_fp16.h>
// Include the header which defines the function template
// The header includes all necessary CUTLASS type definitions
#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh"
namespace flat {
// Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai.
//
// This file is rendered once per kernel variant by the code generator, so each
// rendered copy emits exactly one specialization in its own translation unit.
// This lets the GDN prefill variants be compiled separately. The function
// template itself is defined in prefill_kernel_delta_rule_sm120.cuh (included above).
//
// Template arguments, in order:
//   1-5. is_gva, needs_beta, needs_alpha, init_state, enable_checkpointing
//        - Variant flags substituted by the generator.
//        - NOTE(review): presumably rendered as `true`/`false`; confirm against
//          the generator.
//   6.   cutlass::arch::Sm120
//        - Architecture tag; this file only instantiates the SM120 path.
//   7-8. dtype, dtype
//        - Both element-type slots receive the same rendered dtype.
//   9.   float
//        - NOTE(review): presumably the accumulator/state type; confirm at the
//          declaration.
//
// Parameter types must exactly match the extern template declaration in
// prefill_kernel_delta_rule_sm120_extern.inc.
//   - That `extern template` declaration suppresses implicit instantiation in
//     the translation units that include it. Those callers therefore rely on
//     this file to provide the definition.
//   - If the lists drift apart, expect either a compile error here (no matching
//     template) or an undefined-symbol error at link time.
//   - The runtime parameter list below is copied type-for-type from that
//     declaration. Parameter names and meanings are documented there, not here.
template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, {{ enable_checkpointing }}, cutlass::arch::Sm120, {{ dtype }}, {{ dtype }}, float>(
    cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*,
    float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t,
    int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t, float*, int64_t const*,
    int32_t);
}  // namespace flat