Skip to content

Commit d346333

Browse files
authored
[CUDA] Add build flag onnxruntime_USE_FPA_INTB_GEMM (#25802)
### Description Add a build flag to enable/disable mixed gemm cutlass kernel. To disable the kernel, you can append the following at the end of build command line: `--cmake_extra_defines onnxruntime_USE_FPA_INTB_GEMM=OFF` ### Motivation and Context FpA IntB Gemm need a lot of time to compile. With such option, developer can speed up the build especially on build machine with limited memory.
1 parent 9a2ea43 commit d346333

26 files changed

+96
-29
lines changed

cmake/CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)
9898

9999
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
100100
option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
101-
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
101+
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
102+
cmake_dependent_option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" ON "onnxruntime_USE_CUDA" OFF)
102103

103104
option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
104105
option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
@@ -696,6 +697,7 @@ if (onnxruntime_USE_CUDA)
696697
set(onnxruntime_USE_FLASH_ATTENTION OFF)
697698
set(onnxruntime_USE_LEAN_ATTENTION OFF)
698699
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
700+
set(onnxruntime_USE_FPA_INTB_GEMM OFF)
699701
endif()
700702

701703
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
@@ -708,6 +710,11 @@ if (onnxruntime_USE_CUDA)
708710
set(onnxruntime_USE_FLASH_ATTENTION OFF)
709711
endif()
710712

713+
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
714+
message( STATUS "FpA IntB Gemm unsupported for CUDA compiler version < 12.0")
715+
set(onnxruntime_USE_FPA_INTB_GEMM OFF)
716+
endif()
717+
711718
if (WIN32)
712719
message( STATUS "Lean Attention unsupported in Windows")
713720
set(onnxruntime_USE_LEAN_ATTENTION OFF)
@@ -736,6 +743,11 @@ if (onnxruntime_USE_CUDA)
736743
message( STATUS "Enable memory efficient attention for CUDA EP")
737744
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
738745
endif()
746+
747+
if (onnxruntime_USE_FPA_INTB_GEMM)
748+
message( STATUS "Enable FpA IntB Gemm for CUDA EP")
749+
list(APPEND ORT_PROVIDER_FLAGS -DUSE_FPA_INTB_GEMM=1)
750+
endif()
739751
endif()
740752

741753
if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA))

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
2424
} // namespace cutlass_kernels
2525
} // namespace kernels
2626
} // namespace onnxruntime::llm
27+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
2424
} // namespace cutlass_kernels
2525
} // namespace kernels
2626
} // namespace onnxruntime::llm
27+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t,
2424
} // namespace cutlass_kernels
2525
} // namespace kernels
2626
} // namespace onnxruntime::llm
27+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightO
2323
} // namespace cutlass_kernels
2424
} // namespace kernels
2525
} // namespace onnxruntime::llm
26+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t,
2424
} // namespace cutlass_kernels
2525
} // namespace kernels
2626
} // namespace onnxruntime::llm
27+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scaleonly.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t, cutlass::Weight
2323
} // namespace cutlass_kernels
2424
} // namespace kernels
2525
} // namespace onnxruntime::llm
26+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<half, uint8_t, cutlass::WeightOnlyQuantO
2323
} // namespace cutlass_kernels
2424
} // namespace kernels
2525
} // namespace onnxruntime::llm
26+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scaleonly.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<half, uint8_t, cutlass::WeightOnlyQuantO
2323
} // namespace cutlass_kernels
2424
} // namespace kernels
2525
} // namespace onnxruntime::llm
26+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
#if USE_FPA_INTB_GEMM
23
#ifndef EXCLUDE_SM_90
34
#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
45

@@ -515,3 +516,4 @@ __nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::
515516
} // namespace kernels
516517
} // namespace onnxruntime::llm
517518
#endif // EXCLUDE_SM_90
519+
#endif // USE_FPA_INTB_GEMM

0 commit comments

Comments
 (0)