Skip to content

Commit 1322d60

Browse files
committed
refactor
1 parent 3e313ad commit 1322d60

26 files changed

Lines changed: 31 additions & 39 deletions

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
*/
1616
#include <cuda_bf16.h>
1717

18-
#include "prefill_kernel_delta_rule_sm90.cuh"
18+
#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh"
1919

2020
// Extern template declarations prevent implicit instantiation here.
2121
// Explicit instantiations are in separate generated files for parallel compilation.
22-
#include "prefill_kernel_delta_rule_sm90_extern.inc"
22+
#include "flat_prefill_kernel_delta_rule_sm90_extern.inc"
2323

2424
namespace flat {
2525

csrc/flat/prefill/prefill_kernel_delta_rule_sm90_extern.inc renamed to csrc/flat_prefill_kernel_delta_rule_sm90_extern.inc

File renamed without changes.

csrc/gdn_prefill_launcher.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
#include <iostream>
2626
#include <sstream>
2727

28-
#include "flat/prefill/prefill_kernel.hpp"
28+
#include "flashinfer/flat/prefill/prefill_kernel.hpp"
2929

3030
using tvm::ffi::Optional;
3131
using tvm::ffi::TensorView;

csrc/gdn_prefill_sm90_kernel_inst.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
// Include the header which defines the function template
2525
// The header includes all necessary CUTLASS type definitions
26-
#include "prefill_kernel_delta_rule_sm90.cuh"
26+
#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh"
2727

2828
namespace flat {
2929

flashinfer/jit/gdn.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def gen_gdn_prefill_sm90_module() -> JitSpec:
6565
)
6666
write_if_different(dest_path, source)
6767

68-
# Copy source files to gen_directory (like POD module does)
68+
# Copy source files to gen_directory for compilation
69+
# Headers are now in include/flashinfer/flat/ and accessible via standard include paths
6970
for filename in [
7071
"gdn_prefill_launcher.cu",
7172
"flat/prefill/prefill_kernel_delta_rule_sm90.cu",
@@ -75,17 +76,8 @@ def gen_gdn_prefill_sm90_module() -> JitSpec:
7576
source_paths.append(dest_path)
7677
write_if_different(dest_path, src_path.read_text())
7778

78-
# Copy header files so relative includes work
79-
for filename in [
80-
"flat/prefill/prefill_kernel_delta_rule_sm90.cuh",
81-
"flat/prefill/prefill_kernel_delta_rule_sm90_extern.inc",
82-
]:
83-
src_path = jit_env.FLASHINFER_CSRC_DIR / filename
84-
write_if_different(gen_directory / src_path.name, src_path.read_text())
85-
8679
return gen_jit_spec(
8780
uri,
8881
source_paths,
8982
extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
90-
extra_include_paths=[gen_directory, jit_env.FLASHINFER_CSRC_DIR],
9183
)

csrc/flat/ampere/collective/flat_collective_inverse.hpp renamed to include/flashinfer/flat/ampere/collective/flat_collective_inverse.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "cute/tensor.hpp"
1919
#include "cutlass/arch/barrier.h"
2020
#include "cutlass/cutlass.h"
21-
#include "flat/cute_ext.hpp"
21+
#include "flashinfer/flat/cute_ext.hpp"
2222

2323
namespace flat::collective {
2424

csrc/flat/ampere/collective/flat_collective_load.hpp renamed to include/flashinfer/flat/ampere/collective/flat_collective_load.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "cute/tensor.hpp"
1919
#include "cutlass/cutlass.h"
2020
#include "cutlass/pipeline/sm90_pipeline.hpp"
21-
#include "flat/unused.hpp"
21+
#include "flashinfer/flat/unused.hpp"
2222

2323
namespace flat::collective {
2424

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include <stdexcept>
2020
#include <string>
2121

22-
#include "debug.hpp"
22+
#include "flashinfer/flat/debug.hpp"
2323

2424
#define FLAT_UNUSED_PARAMETER(x) (void)x
2525

0 commit comments

Comments
 (0)