Skip to content

Commit c74181f

Browse files
committed
fix
1 parent 37c254c commit c74181f

3 files changed

Lines changed: 2 additions & 5 deletions

File tree

csrc/flat_prefill_kernel_delta_rule_sm90_extern.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ namespace flat {
4747
#define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, ctype) \
4848
extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, cutlass::arch::Sm90, ctype, ctype, float>( \
4949
cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \
50-
float const*, float const*, float const*, int64_t const*, int32_t, int32_t, \
50+
float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \
5151
int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);
5252

5353
// Extern template declarations for half

csrc/gdn_prefill_sm90_kernel_inst.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace flat {
3131
// Parameter types must exactly match the extern template declaration in prefill_kernel_delta_rule_sm90_extern.inc
3232
template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>(
3333
cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*,
34-
float const*, float const*, float const*, int64_t const*, int32_t, int32_t,
34+
float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t,
3535
int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);
3636

3737
} // namespace flat

flashinfer/aot.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,6 @@ def gen_all_modules(
443443
add_misc: bool,
444444
add_xqa: bool,
445445
) -> List[JitSpec]:
446-
# TEMPORARY: Only compile gdn_prefill_sm90 for testing
447-
return [gen_gdn_prefill_sm90_module()]
448-
449446
jit_specs: List[JitSpec] = []
450447
jit_specs.append(gen_spdlog_module())
451448
has_sm90 = sm_capabilities.get("sm90", False)

0 commit comments

Comments
 (0)