Skip to content

Commit ec8a486

Browse files
authored
Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds * Style
1 parent b7588fd commit ec8a486

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ elseif (MLX_BUILD_METAL)
7777
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
7878
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
7979
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
80+
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
81+
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
8082
else ()
81-
message(FATAL_ERROR "MLX requires macOS >= 13.5 to be built with MLX_BUILD_METAL=ON")
83+
message(FATAL_ERROR "MLX requires macOS >= 13.3 to be built with MLX_BUILD_METAL=ON")
8284
endif()
8385

8486
FetchContent_Declare(

mlx/backend/metal/kernels/scaled_dot_product_attention.metal

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
1313
device float* O_partials [[buffer(5)]],
1414
device float* p_lse [[buffer(6)]],
1515
device float* p_maxes [[buffer(7)]],
16-
threadgroup T* threadgroup_block [[threadgroup(0)]],
1716
uint simd_lane_id [[thread_index_in_simdgroup]],
1817
uint simd_group_id [[simdgroup_index_in_threadgroup]],
1918
uint3 tid [[threadgroup_position_in_grid]]) {
19+
20+
threadgroup T threadgroup_block[32768 / sizeof(T)];
21+
2022
constexpr const size_t DK = 128;
2123
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
2224
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
@@ -356,7 +358,6 @@ template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_si
356358
device float* O_partials [[buffer(5)]], \
357359
device float* p_lse [[buffer(6)]], \
358360
device float* p_maxes [[buffer(7)]], \
359-
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
360361
uint simd_lane_id [[thread_index_in_simdgroup]], \
361362
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
362363
uint3 tid [[threadgroup_position_in_grid]]);

mlx/backend/metal/scaled_dot_product_attention.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,6 @@ void sdpa_metal(
9797
set_array_buffer(compute_encoder, p_lse, 6);
9898
set_array_buffer(compute_encoder, p_rowmaxes, 7);
9999

100-
constexpr const uint tgroupMemorySize = 32768;
101-
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
102100
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
103101

104102
{

0 commit comments

Comments
 (0)