-
Notifications
You must be signed in to change notification settings - Fork 931
Expand file tree
/
Copy pathflat_prefill_kernel_delta_rule_sm120_extern.inc
More file actions
81 lines (70 loc) · 3.53 KB
/
flat_prefill_kernel_delta_rule_sm120_extern.inc
File metadata and controls
81 lines (70 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
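// Extern template declarations (explicit instantiation declarations) that prevent the
// dispatcher from implicitly instantiating these kernels. The matching explicit
// instantiation definitions live in separate generated translation units so they can be
// compiled in parallel.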
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cutlass/arch/arch.h"
namespace flat {
// clang-format off

// Requirements on the including translation unit:
//   * The primary template `launch_delta_rule_prefill_kernel_gbai` must already be
//     declared in namespace `flat`; an `extern template` declaration names an existing
//     template and cannot introduce one.
//   * The parameter list spelled out in FLAT_DR_SM120_DECLARE_EXTERN must match the
//     primary template's signature exactly. It must also match the explicit instantiation
//     definitions in the generated files, or the declarations refer to a different
//     specialization.
//
// Every helper macro below carries the file-specific prefix `FLAT_DR_SM120_` and is
// #undef'd at the end of this file. Generically named macros that an includer or a
// sibling *_extern.inc may define are therefore neither redefined nor clobbered.

// FLAT_DR_SM120_FOR_EACH_BOOL_5(M, ...) expands
//   M(f1, f2, f3, f4, f5, ...)
// once for each of the 2^5 = 32 assignments of {false, true} to f1..f5, forwarding the
// trailing arguments unchanged. The product is built one flag at a time. FLAGk fixes
// flags 1..k-1 (received as parameters) and enumerates flag k, so the expansion order is
// a binary count from (false, false, false, false, false) to (true, true, true, true, true).
#define FLAT_DR_SM120_FLAG5(M, f1, f2, f3, f4, ...)           \
  M(f1, f2, f3, f4, false, __VA_ARGS__)                       \
  M(f1, f2, f3, f4, true, __VA_ARGS__)
#define FLAT_DR_SM120_FLAG4(M, f1, f2, f3, ...)               \
  FLAT_DR_SM120_FLAG5(M, f1, f2, f3, false, __VA_ARGS__)      \
  FLAT_DR_SM120_FLAG5(M, f1, f2, f3, true, __VA_ARGS__)
#define FLAT_DR_SM120_FLAG3(M, f1, f2, ...)                   \
  FLAT_DR_SM120_FLAG4(M, f1, f2, false, __VA_ARGS__)          \
  FLAT_DR_SM120_FLAG4(M, f1, f2, true, __VA_ARGS__)
#define FLAT_DR_SM120_FLAG2(M, f1, ...)                       \
  FLAT_DR_SM120_FLAG3(M, f1, false, __VA_ARGS__)              \
  FLAT_DR_SM120_FLAG3(M, f1, true, __VA_ARGS__)
#define FLAT_DR_SM120_FOR_EACH_BOOL_5(M, ...)                 \
  FLAT_DR_SM120_FLAG2(M, false, __VA_ARGS__)                  \
  FLAT_DR_SM120_FLAG2(M, true, __VA_ARGS__)

// Declares (without instantiating) the Sm120 specialization of
// launch_delta_rule_prefill_kernel_gbai for one combination of the five boolean template
// flags and element type `ctype`. `ctype` fills both element-type template slots, and the
// last template argument is fixed to float.
#define FLAT_DR_SM120_DECLARE_EXTERN(is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, ctype) \
  extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, cutlass::arch::Sm120, ctype, ctype, float>( \
      cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \
      float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \
      int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t, float*, int64_t const*, \
      int32_t);

// Extern template declarations for half (all 32 flag combinations)
FLAT_DR_SM120_FOR_EACH_BOOL_5(FLAT_DR_SM120_DECLARE_EXTERN, half)
// Extern template declarations for nv_bfloat16 (all 32 flag combinations)
FLAT_DR_SM120_FOR_EACH_BOOL_5(FLAT_DR_SM120_DECLARE_EXTERN, nv_bfloat16)

#undef FLAT_DR_SM120_DECLARE_EXTERN
#undef FLAT_DR_SM120_FOR_EACH_BOOL_5
#undef FLAT_DR_SM120_FLAG2
#undef FLAT_DR_SM120_FLAG3
#undef FLAT_DR_SM120_FLAG4
#undef FLAT_DR_SM120_FLAG5
// clang-format on
}  // namespace flat