
Commit 9008d2e

Apply cr review
1 parent 0062590 commit 9008d2e

3 files changed: 5 additions & 26 deletions


mlx/backend/cuda/scaled_dot_product_attention.cpp

Lines changed: 5 additions & 6 deletions
@@ -462,13 +462,12 @@ void ScaledDotProductAttention::eval_gpu(
 bool ScaledDotProductAttentionVJP::use_fallback(
     const array& q,
     Stream s,
-    bool has_mask,
-    bool has_sinks,
+    bool /* has_mask */,
+    bool /* has_sinks */,
     int /* n_kv_heads */) {
-  // Force unfused attention when masks/sinks present
-  if (has_mask || has_sinks) {
-    return true;
-  }
+  // Note: cuDNN SDPA backward correctly handles masks/sinks,
+  // so we don't need to force fallback based on their presence.
+
   // The frontend adds a padding mask when sequence length is not a multiple of
   // tile size.
   if (q.shape(2) % 128 != 0) {
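
After this change, the only fallback condition left in the visible hunk is the tile-alignment check on the sequence length. The following is a minimal self-contained sketch of that remaining decision logic; the function name, the plain-int parameter standing in for q.shape(2), and the final return are assumptions for illustration, not repository code.

#include <cstdio>

// Illustration only: stand-in for the fallback predicate's logic after this
// commit. The real function takes mlx arrays and a Stream; here the sequence
// length stands in for q.shape(2).
bool use_fallback_sketch(int seq_len) {
  // cuDNN SDPA backward handles masks and sinks, so their presence no longer
  // forces the unfused path.

  // The frontend adds a padding mask when the sequence length is not a
  // multiple of the tile size, so fall back to unfused attention then.
  if (seq_len % 128 != 0) {
    return true;
  }
  return false; // assumed: otherwise take the fused cuDNN backward path
}

int main() {
  // 256 is tile-aligned (fused path), 300 is not (fallback).
  std::printf("%d %d\n", use_fallback_sketch(256), use_fallback_sketch(300));
}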

mlx/backend/metal/scaled_dot_product_attention.cpp

Lines changed: 0 additions & 5 deletions
@@ -777,11 +777,6 @@ void ScaledDotProductAttention::eval_gpu(
         output_logsumexp_,
         lse_out);
 
-    // Cache logsumexp for VJP access (handles both cases: in outputs[1] or
-    // separate array)
-    if (output_logsumexp_ && lse_out != nullptr) {
-      set_cached_logsumexp(*lse_out);
-    }
   }
 
   d.add_temporaries(std::move(copies), s.index);
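
The removed block is the eval-time half of the logsumexp caching machinery: with it gone, the forward kernel still writes logsumexp to lse_out, but nothing retains it on the primitive. The declaration side (the mutable member and its accessors) is deleted from mlx/fast_primitives.h below; a small sketch of that pattern follows that hunk.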

mlx/fast_primitives.h

Lines changed: 0 additions & 15 deletions
@@ -257,21 +257,6 @@ class ScaledDotProductAttention : public Custom {
   bool do_causal_;
   bool has_sinks_;
   bool output_logsumexp_;
-  // Cache logsumexp for VJP backward pass
-  // This enables Flash Attention VJP to access logsumexp even when
-  // the forward pass returns only the attention output to the user.
-  // Size is small: batch * heads * seq * 1 * sizeof(float) = ~512KB per layer
-  mutable std::optional<array> cached_logsumexp_;
-
- public:
-  // Getter for VJP to access cached logsumexp
-  const std::optional<array>& get_cached_logsumexp() const {
-    return cached_logsumexp_;
-  }
-  // Setter called during eval_gpu
-  void set_cached_logsumexp(array logsumexp) const {
-    cached_logsumexp_ = std::move(logsumexp);
-  }
 };
 
 class ScaledDotProductAttentionVJP : public Custom {
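
For context, the members removed above implemented a "mutable cache on a const primitive" pattern: the forward eval stashed logsumexp on the primitive so the VJP could read it even when it was not a user-visible output. Below is a minimal self-contained sketch of that pattern; the stand-in Array struct, the class name, and main() are illustrative only, not repository code.

#include <iostream>
#include <optional>
#include <utility>

// Stand-in for mlx::core::array, for illustration only.
struct Array {
  float value;
};

class SdpaLikePrimitive {
 public:
  // Setter called from the (const) eval path; legal because the member is
  // declared mutable.
  void set_cached_logsumexp(Array lse) const {
    cached_logsumexp_ = std::move(lse);
  }

  // Getter the backward pass would use to retrieve the cached value.
  const std::optional<Array>& get_cached_logsumexp() const {
    return cached_logsumexp_;
  }

 private:
  // This mirrors the member removed by the commit.
  mutable std::optional<Array> cached_logsumexp_;
};

int main() {
  SdpaLikePrimitive p;
  p.set_cached_logsumexp(Array{0.5f});  // forward eval stashes logsumexp
  if (p.get_cached_logsumexp()) {       // backward pass reads it back
    std::cout << p.get_cached_logsumexp()->value << "\n";
  }
  return 0;
}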
