Skip to content
Closed
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,10 @@ def dot_product_attention(
)
# Transpose output back to Keras layout
return jnp.transpose(output, axes=(0, 2, 1, 3))
except Exception:
except jax.errors.ConcretizationTypeError:
# Mask is traced
# Fall back to native attention
flash_attention = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flash_attention = False

This is already line 1364, remove.

logging.exception(
"Failed to apply Splash kernel for flash attention. "
"Falling back to JAX native dot_product_attention."
Expand Down
Loading