4 changes: 4 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -1276,6 +1276,10 @@ def dot_product_attention(
     # use flash attention
     _can_use_flash_attention(query, key, value, bias, raise_error=True)

+    # On TPU, traced masks cause ConcretizationTypeError with flash attention.
+    if is_tpu and mask is not None and isinstance(mask, jax.core.Tracer):
+        flash_attention = False
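The failure mode this diff guards against can be reproduced with a minimal sketch (illustrative only, not the Keras code): any code path that branches on a traced mask value inside `jax.jit` forces concretization and raises `jax.errors.ConcretizationTypeError` (or a subclass of it).

```python
import jax
import jax.numpy as jnp


@jax.jit
def inspects_mask(mask):
    # Branching on a traced value requires a concrete bool,
    # which is not available during tracing.
    if mask.any():
        return jnp.float32(1.0)
    return jnp.float32(0.0)


try:
    inspects_mask(jnp.array([True, False]))
    caught = False
except jax.errors.ConcretizationTypeError:
    caught = True
```

Here `caught` ends up `True`: the error surfaces at trace time, before any kernel runs, which is why a flash-attention path that inspects concrete mask values cannot be used with a traced mask.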
Collaborator

Because on JAX we compile by default, this means that flash_attention is now always disabled. The whole flash attention implementation was for TPU specifically. So I don't think this is the right fix.

Am I misunderstanding something?
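The reviewer's point can be checked directly (a minimal sketch, not the Keras code): under `jax.jit`, every array argument is an abstract `jax.core.Tracer`, so the proposed `isinstance` check fires on every compiled call, not just inside control-flow operations.

```python
import jax
import jax.numpy as jnp


def mask_is_tracer(mask):
    # Under jax.jit, array arguments are abstract Tracers,
    # so this returns True for any compiled call.
    return isinstance(mask, jax.core.Tracer)


eager = mask_is_tracer(jnp.ones(3))                   # concrete array
jitted = bool(jax.jit(mask_is_tracer)(jnp.ones(3)))   # traced argument
```

`eager` is `False` while `jitted` is `True`, which is why the check would disable flash attention whenever the layer is compiled, i.e. in the default Keras JAX path.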

Contributor Author

@hertschuh You're absolutely right! Thank you for catching that.

My initial fix using isinstance(mask, jax.core.Tracer) was too broad and would have disabled flash attention even during normal JIT compilation, not just inside control flow operations.

I've revised the fix to instead:

  1. Remove the Tracer check
  2. Improve the exception handling to specifically catch jax.errors.ConcretizationTypeError

The fix now correctly handles traced masks without disabling flash attention for normal use cases.
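The revised approach can be sketched as a try/except fallback (hypothetical helper names, not the actual Keras code): attempt the flash path and fall back to the reference implementation only when tracing makes the flash path unusable.

```python
import jax
import jax.numpy as jnp


def attention_with_fallback(flash_fn, reference_fn, *args):
    # Hypothetical helper: prefer the flash path, but fall back to the
    # reference implementation if tracing forces concretization.
    try:
        return flash_fn(*args)
    except jax.errors.ConcretizationTypeError:
        return reference_fn(*args)


@jax.jit
def flash_path(mask):
    # Stand-in for a flash kernel that needs concrete mask values;
    # raises ConcretizationTypeError when `mask` is traced.
    if mask.any():
        return jnp.float32(1.0)
    return jnp.float32(0.0)


def reference_path(mask):
    # Trace-safe equivalent using jnp.where instead of Python control flow.
    return jnp.where(mask.any(), 1.0, 0.0)


out = attention_with_fallback(flash_path, reference_path, jnp.array([True, False]))
```

Unlike the Tracer check, this only disables flash attention on calls that actually fail, so plain JIT compilation keeps the fast path.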


     # TPU-specific flash attention path
     if is_tpu and flash_attention:
         # Get sharding parameters from distribution context