Skip to content

Commit a9bdd67

Browse files
authored
Add CUDA sdpa vector (#2468)
1 parent f2adb56 commit a9bdd67

File tree

3 files changed

+782
-12
lines changed

3 files changed

+782
-12
lines changed

mlx/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ target_sources(
3939
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
4040
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
4141
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
42+
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
4243
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
4344
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
4445
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu

mlx/backend/cuda/primitives.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,6 @@
66

77
namespace mlx::core {
88

9-
bool fast::ScaledDotProductAttention::use_fallback(
10-
const array& q,
11-
const array& k,
12-
const array& v,
13-
bool has_mask,
14-
bool has_arr_mask,
15-
bool do_causal,
16-
Stream s) {
17-
return true;
18-
}
19-
209
#define NO_GPU_MULTI(func) \
2110
void func::eval_gpu( \
2211
const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
5342
NO_GPU_MULTI(Eigh)
5443

5544
namespace fast {
56-
NO_GPU(ScaledDotProductAttention)
5745
NO_GPU_MULTI(CustomKernel)
5846
} // namespace fast
5947

0 commit comments

Comments
 (0)