File tree — 3 files changed: +5, −26 lines changed.
lines changed Original file line number Diff line number Diff line change @@ -462,13 +462,12 @@ void ScaledDotProductAttention::eval_gpu(
462462bool ScaledDotProductAttentionVJP::use_fallback (
463463 const array& q,
464464 Stream s,
465- bool has_mask,
466- bool has_sinks,
465+ bool /* has_mask */ ,
466+ bool /* has_sinks */ ,
467467 int /* n_kv_heads */ ) {
468- // Force unfused attention when masks/sinks present
469- if (has_mask || has_sinks) {
470- return true ;
471- }
468+ // Note: cuDNN SDPA backward correctly handles masks/sinks,
469+ // so we don't need to force fallback based on their presence.
470+
472471 // The frontend adds a padding mask when sequence length is not a multiple of
473472 // tile size.
474473 if (q.shape (2 ) % 128 != 0 ) {
Original file line number Diff line number Diff line change @@ -777,11 +777,6 @@ void ScaledDotProductAttention::eval_gpu(
777777 output_logsumexp_,
778778 lse_out);
779779
780- // Cache logsumexp for VJP access (handles both cases: in outputs[1] or
781- // separate array)
782- if (output_logsumexp_ && lse_out != nullptr ) {
783- set_cached_logsumexp (*lse_out);
784- }
785780 }
786781
787782 d.add_temporaries (std::move (copies), s.index );
Original file line number Diff line number Diff line change @@ -257,21 +257,6 @@ class ScaledDotProductAttention : public Custom {
257257 bool do_causal_;
258258 bool has_sinks_;
259259 bool output_logsumexp_;
260- // Cache logsumexp for VJP backward pass
261- // This enables Flash Attention VJP to access logsumexp even when
262- // the forward pass returns only the attention output to the user.
263- // Size is small: batch * heads * seq * 1 * sizeof(float) = ~512KB per layer
264- mutable std::optional<array> cached_logsumexp_;
265-
266- public:
267- // Getter for VJP to access cached logsumexp
268- const std::optional<array>& get_cached_logsumexp () const {
269- return cached_logsumexp_;
270- }
271- // Setter called during eval_gpu
272- void set_cached_logsumexp (array logsumexp) const {
273- cached_logsumexp_ = std::move (logsumexp);
274- }
275260};
276261
277262class ScaledDotProductAttentionVJP : public Custom {
0 commit comments.