@@ -9,7 +9,6 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
-#include "mlx/transforms_impl.h"
 
 namespace mlx::core::fast {
 
@@ -112,7 +111,8 @@ array rms_norm(
 
   auto passed_weight =
       (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
-  if (s.device == Device::gpu) {
+
+  if (!RMSNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -256,7 +256,7 @@ array layer_norm(
   auto passed_bias =
       (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
 
-  if (s.device == Device::gpu) {
+  if (!LayerNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -470,7 +470,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu) {
+  if (!RoPE::use_fallback(stream)) {
     return array(
         x.shape(),
         x.dtype(),
@@ -727,31 +727,6 @@ array scaled_dot_product_attention(
   };
 
   auto stream = to_stream(s);
-  const int value_head_dim = v.shape(-1);
-  const int query_head_dim = q.shape(-1);
-  const int query_sequence_length = q.shape(2);
-  const int key_sequence_length = k.shape(2);
-
-  const bool sdpa_vector_supported_head_dim =
-      query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
-  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
-
-  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-      (query_sequence_length <= key_sequence_length && do_causal);
-
-  const bool supports_sdpa_full = sdpa_full_supported_mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
-
-  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
-
-  const bool implementation_supports_use_case =
-      supports_sdpa_full || supports_sdpa_vector;
-
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type
@@ -770,7 +745,8 @@ array scaled_dot_product_attention(
     mask_shape.back() = keys.shape(-2);
     inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
-  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
+  if (!ScaledDotProductAttention::use_fallback(
+          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
@@ -779,7 +755,7 @@ array scaled_dot_product_attention(
             stream, fallback, scale, do_causal),
         std::move(inputs));
   }
-  return fallback(inputs)[0];
+  return fallback(std::move(inputs))[0];
 }
 
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
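// A minimal sketch (not part of the diff above) of what the simplest of the
// new helpers might look like, assuming RMSNorm::use_fallback only needs to
// gate on the stream's device. The body is inferred from the replaced
// `s.device == Device::gpu` check; the exact signature and location are
// assumptions, not taken from this commit.
bool RMSNorm::use_fallback(Stream s) {
  // The fused kernel was previously taken only on the GPU, so any other
  // device falls back to the composed-ops implementation.
  return s.device != Device::gpu;
}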