
Commit 1b021f6

Fast primitives decide when to use the fallback (#2216)
1 parent 95b7551 commit 1b021f6

File tree

7 files changed: +115 −45 lines

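In summary, each fast primitive now exposes a static use_fallback(...) predicate and the op builders in mlx/fast.cpp consult it instead of hard-coding stream.device checks; each backend decides for itself when the fast kernel applies. A minimal standalone sketch of that dispatch pattern (MyFastPrimitive and my_fast_op are hypothetical stand-ins, not MLX code):

// Sketch only: the dispatch pattern this commit introduces, reduced to a
// self-contained example. MyFastPrimitive / my_fast_op are hypothetical.
#include <functional>
#include <vector>

enum class Device { cpu, gpu };
struct Stream { Device device; };

struct MyFastPrimitive {
  // Metal: fall back only for CPU streams; backends without a fast kernel
  // (no_gpu, and CUDA in this commit) simply return true.
  static bool use_fallback(Stream s) { return s.device == Device::cpu; }
};

using Array = std::vector<float>;
using Fallback = std::function<std::vector<Array>(std::vector<Array>)>;

Array my_fast_op(const Array& x, Stream s, const Fallback& fallback) {
  if (!MyFastPrimitive::use_fallback(s)) {
    // Fast path: in MLX this is where the output array is built with the
    // custom primitive (which still carries the fallback closure).
    return x;
  }
  // Fallback path: evaluate the composed-op implementation directly.
  return fallback({x})[0];
}

int main() {
  Fallback fallback = [](std::vector<Array> inputs) { return inputs; };
  Array x{1.0f, 2.0f, 3.0f};
  my_fast_op(x, Stream{Device::gpu}, fallback); // takes the fast path
  my_fast_op(x, Stream{Device::cpu}, fallback); // uses the fallback
  return 0;
}

The per-primitive predicates keep the decision backend-specific, so the op-level code no longer needs to know which kernels exist on which device.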

mlx/backend/cuda/primitives.cu

Lines changed: 20 additions & 3 deletions
@@ -43,12 +43,29 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 #define NO_GPU_MULTI(func) \
   void func::eval_gpu( \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }
 
+#define NO_GPU_USE_FALLBACK(func) \
+  bool func::use_fallback(Stream s) { \
+    return true; \
+  } \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no CUDA implementation."); \
@@ -144,11 +161,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
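For reference, expanding the new macro by hand shows what each of these one-liners generates; NO_GPU_USE_FALLBACK(RoPE) produces roughly the following (illustrative hand-expansion of the macros above, not generated code):

// Approximate expansion of NO_GPU_USE_FALLBACK(RoPE) in this file:
// a use_fallback that always says "use the fallback", plus the usual
// throwing eval_gpu from NO_GPU_MULTI.
bool RoPE::use_fallback(Stream s) {
  return true;
}
void RoPE::eval_gpu(
    const std::vector<array>& inputs, std::vector<array>& outputs) {
  throw std::runtime_error("RoPE has no CUDA implementation.");
}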

mlx/backend/metal/normalization.cpp

Lines changed: 8 additions & 0 deletions
@@ -10,6 +10,10 @@
 
 namespace mlx::core::fast {
 
+bool RMSNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
   }
 }
 
+bool LayerNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void LayerNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

mlx/backend/metal/rope.cpp

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,10 @@ namespace mlx::core::fast {
 
 constexpr int n_per_thread = 4;
 
+bool RoPE::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RoPE::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

mlx/backend/metal/scaled_dot_product_attention.cpp

Lines changed: 41 additions & 1 deletion
@@ -4,10 +4,10 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
-
 #include "mlx/backend/metal/kernels/steel/attn/params.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core::fast {
@@ -339,6 +339,46 @@ void sdpa_vector_2pass(
 
 } // namespace
 
+bool ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
+  if (s.device == Device::cpu) {
+    return true;
+  }
+
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);
+
+  const bool sdpa_vector_supported_head_dim =
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
+       query_head_dim == 256);
+  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
+
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
+
+  const bool supports_sdpa_full =
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim;
+
+  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_head_dim;
+
+  return !(supports_sdpa_full || supports_sdpa_vector);
+}
+
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
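Ignoring the grad-tracing and CPU-stream early returns, the decision above is a pure shape/mask predicate. A standalone restatement with plain integers (illustrative only; sdpa_needs_fallback is not an MLX function) makes it easy to check a configuration by hand:

// Restatement of the Metal use_fallback() decision above with plain scalars
// in place of mlx arrays. For illustration only, not MLX source.
#include <cstdio>

bool sdpa_needs_fallback(
    int query_head_dim,
    int value_head_dim,
    int query_sequence_length,
    int key_sequence_length,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal) {
  const bool vector_head_dim_ok = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
       query_head_dim == 256);
  const bool full_head_dim_ok = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
  const bool full_mask_ok = !has_mask || has_arr_mask ||
      (query_sequence_length <= key_sequence_length && do_causal);

  const bool supports_full = full_mask_ok && full_head_dim_ok;
  const bool supports_vector = query_sequence_length <= 8 &&
      query_sequence_length <= key_sequence_length && vector_head_dim_ok;
  return !(supports_full || supports_vector);
}

int main() {
  // Decode step: q_len 1, head_dim 128 -> vector kernel, no fallback (prints 0).
  std::printf("%d\n", sdpa_needs_fallback(128, 128, 1, 1024, true, false, true));
  // Unsupported head_dim 72 -> fallback (prints 1).
  std::printf("%d\n", sdpa_needs_fallback(72, 72, 512, 512, true, false, true));
  return 0;
}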

mlx/backend/no_gpu/primitives.cpp

Lines changed: 20 additions & 3 deletions
@@ -10,13 +10,30 @@
     throw std::runtime_error(#func " has no GPU implementation."); \
   }
 
+#define NO_GPU_USE_FALLBACK(func) \
+  bool func::use_fallback(Stream s) { \
+    return true; \
+  } \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no GPU implementation."); \
   }
 
 namespace mlx::core {
 
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 NO_GPU(Abs)
 NO_GPU(Add)
 NO_GPU(AddMM)
@@ -130,11 +147,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU(View)
 
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)

mlx/fast.cpp

Lines changed: 7 additions & 31 deletions
@@ -9,7 +9,6 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
-#include "mlx/transforms_impl.h"
 
 namespace mlx::core::fast {
 
@@ -112,7 +111,8 @@ array rms_norm(
 
   auto passed_weight =
       (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
-  if (s.device == Device::gpu) {
+
+  if (!RMSNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -256,7 +256,7 @@ array layer_norm(
   auto passed_bias =
       (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
 
-  if (s.device == Device::gpu) {
+  if (!LayerNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -470,7 +470,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu) {
+  if (!RoPE::use_fallback(stream)) {
     return array(
         x.shape(),
         x.dtype(),
@@ -727,31 +727,6 @@ array scaled_dot_product_attention(
   };
 
   auto stream = to_stream(s);
-  const int value_head_dim = v.shape(-1);
-  const int query_head_dim = q.shape(-1);
-  const int query_sequence_length = q.shape(2);
-  const int key_sequence_length = k.shape(2);
-
-  const bool sdpa_vector_supported_head_dim =
-      query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
-  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
-
-  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-      (query_sequence_length <= key_sequence_length && do_causal);
-
-  const bool supports_sdpa_full = sdpa_full_supported_mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
-
-  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
-
-  const bool implementation_supports_use_case =
-      supports_sdpa_full || supports_sdpa_vector;
-
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type
@@ -770,7 +745,8 @@ array scaled_dot_product_attention(
     mask_shape.back() = keys.shape(-2);
     inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
-  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
+  if (!ScaledDotProductAttention::use_fallback(
+          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
@@ -779,7 +755,7 @@ array scaled_dot_product_attention(
             stream, fallback, scale, do_causal),
         std::move(inputs));
   }
-  return fallback(inputs)[0];
+  return fallback(std::move(inputs))[0];
 }
 
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {

mlx/fast_primitives.h

Lines changed: 15 additions & 7 deletions
@@ -43,6 +43,8 @@ class RMSNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
+  static bool use_fallback(Stream stream);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -65,7 +67,6 @@ class RMSNorm : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
 
@@ -91,7 +92,6 @@ class RMSNormVJP : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
 
@@ -103,6 +103,8 @@ class LayerNorm : public Custom {
       float eps)
      : Custom(stream, fallback), eps_(eps) {}
 
+  static bool use_fallback(Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -124,7 +126,6 @@ class LayerNorm : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
 
@@ -150,7 +151,6 @@ class LayerNormVJP : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
 
@@ -171,6 +171,8 @@ class RoPE : public Custom {
         scale_(scale),
         forward_(forward) {}
 
+  static bool use_fallback(Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -193,7 +195,6 @@ class RoPE : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int dims_;
   bool traditional_;
   float base_;
@@ -210,6 +211,15 @@ class ScaledDotProductAttention : public Custom {
       const bool do_causal)
      : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
 
+  static bool use_fallback(
+      const array& q,
+      const array& k,
+      const array& v,
+      bool has_mask,
+      bool has_arr_mask,
+      bool do_causal,
+      Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -230,7 +240,6 @@ class ScaledDotProductAttention : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
   bool do_causal_;
 };
@@ -263,7 +272,6 @@ class AffineQuantize : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int group_size_;
   int bits_;
   bool dequantize_;