
Commit 19fb69e

Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to the memory-efficient GPU shader at a prescribed sequence length. Otherwise, utilizes aggregate MLX primitives for best latency.
1 parent 9231617 commit 19fb69e
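
For reference, a minimal usage sketch of the new keyword argument (the tensor shapes and the 512 threshold below are illustrative assumptions, not values taken from this commit):

    import mlx.core as mx

    # Assumed 4-D layout (batch, heads, seq_len, head_dim), matching the
    # ndim check in mlx/fast.cpp.
    q = mx.random.normal(shape=(1, 32, 1024, 128), dtype=mx.float16)
    k = mx.random.normal(shape=(1, 32, 1024, 128), dtype=mx.float16)
    v = mx.random.normal(shape=(1, 32, 1024, 128), dtype=mx.float16)
    scale = 1.0 / (128 ** 0.5)

    # Opt in to the memory-efficient GPU shader for query sequence lengths
    # >= 512; shorter sequences keep using the MLX-primitives fallback.
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=scale, memory_efficient_threshold=512
    )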

File tree

4 files changed (+13 lines, -4 lines)

mlx/fast.cpp

Lines changed: 10 additions & 3 deletions
@@ -465,6 +465,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask,
+    const std::optional<int>& memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
@@ -535,6 +536,11 @@ array scaled_dot_product_attention(
  * * dtype is not fp32 or fp16
  */

+  int threshold = 1e6;
+  if (memory_efficient_threshold.has_value()) {
+    threshold = std::max(1, memory_efficient_threshold.value());
+  }
+
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
@@ -581,9 +587,10 @@ array scaled_dot_product_attention(
   bool implementation_supports_use_case =
       supports_sdpa || supports_full_self_attention;

-  // disabling full self attention until perf is tuned;
-  // likewise for sdpa
-  implementation_supports_use_case &= false;
+  // sdpa gpu shader is disabled except for memory efficient opt-in
+  const int seq_for_threshold = queries.shape(2);
+  bool use_memory_efficient_impl = seq_for_threshold >= threshold;
+  implementation_supports_use_case &= use_memory_efficient_impl;

   if (implementation_supports_use_case) {
     auto out_shape =
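
In short, the new dispatch reduces to the check below, shown here as a Python sketch of the C++ logic (the helper name is illustrative; the real logic lives inline in scaled_dot_product_attention):

    def _use_memory_efficient_impl(queries_shape, memory_efficient_threshold=None):
        # The default threshold of 1e6 keeps the GPU shader effectively
        # opted out unless the caller passes a smaller value.
        threshold = 1_000_000
        if memory_efficient_threshold is not None:
            threshold = max(1, memory_efficient_threshold)
        # Queries are (batch, n_heads, seq_len, head_dim); compare seq_len.
        return queries_shape[2] >= threshold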

mlx/fast.h

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask = std::nullopt,
+    const std::optional<int>& memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});

 std::tuple<array, array, array> affine_quantize(

python/src/fast.cpp

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ void init_fast(nb::module_& parent_module) {
       nb::kw_only(),
       "scale"_a,
       "mask"_a = nb::none(),
+      "memory_efficient_threshold"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
           "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),

python/tests/test_fast_sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def test_fast_sdpa(self):

         reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
         o_mlx = mx.fast.scaled_dot_product_attention(
-            q_mlx, k_mlx, v_mlx, scale=scale
+            q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2
         )

         self.assertListEqual(list(reference.shape), list(o_mlx.shape))
