Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,25 +1340,32 @@ def dot_product_attention(
if custom_mask is None and is_causal:
custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))

try:
output = wrap_flash_attention(
query_tpu_layout,
key_tpu_layout,
value_tpu_layout,
decoder_segment_ids=decoder_segment_ids,
custom_mask=custom_mask,
attn_logits_soft_cap=attn_logits_soft_cap,
head_shards=head_shards,
q_seq_shards=q_seq_shards,
)
# Transpose output back to Keras layout
return jnp.transpose(output, axes=(0, 2, 1, 3))
except Exception:
logging.exception(
"Failed to apply Splash kernel for flash attention. "
"Falling back to JAX native dot_product_attention."
)
# Splash attention kernel requires concrete mask values for hashing.
# If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
if isinstance(mask, jax.core.Tracer) or isinstance(
custom_mask, jax.core.Tracer
):
flash_attention = False
else:
try:
output = wrap_flash_attention(
query_tpu_layout,
key_tpu_layout,
value_tpu_layout,
decoder_segment_ids=decoder_segment_ids,
custom_mask=custom_mask,
attn_logits_soft_cap=attn_logits_soft_cap,
head_shards=head_shards,
q_seq_shards=q_seq_shards,
)
# Transpose output back to Keras layout
return jnp.transpose(output, axes=(0, 2, 1, 3))
except Exception:
logging.exception(
"Failed to apply Splash kernel for flash attention. "
"Falling back to JAX native dot_product_attention."
)
flash_attention = False

# JAX native dot_product_attention for GPU or fallback for TPU
if hasattr(jax.nn, "dot_product_attention"):
Expand Down
34 changes: 34 additions & 0 deletions keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,40 @@ def test_polar(self):


class NNOpsCorrectnessTest(testing.TestCase):
@pytest.mark.skipif(backend.backend() != "jax", reason="JAX only")
def test_dot_product_attention_inside_scan(self):
import jax

try:
if jax.devices()[0].platform != "tpu":
self.skipTest("TPU-specific test")
except:
self.skipTest("TPU-specific test")
Comment on lines +1334 to +1335
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a bare except: clause is generally discouraged as it can catch unexpected errors like SystemExit or KeyboardInterrupt, making it harder to debug issues. It's better to be more specific and catch Exception instead.

Suggested change
except:
self.skipTest("TPU-specific test")
except Exception:
self.skipTest("TPU-specific test")


import jax.numpy as jnp

def attention_scan_body(carry, x):
query, key, value = x
# dot_product_attention expects 4D inputs (B, H, S, D)
query = jnp.expand_dims(query, axis=0)
key = jnp.expand_dims(key, axis=0)
value = jnp.expand_dims(value, axis=0)

# Use a mask to trigger the issue
mask = jnp.ones((1, 4, 8), dtype="bool")
out = knn.dot_product_attention(query, key, value, mask=mask)

out = jnp.squeeze(out, axis=0)
return carry, out

query = jnp.ones((2, 1, 4, 8))
key = jnp.ones((2, 1, 4, 8))
value = jnp.ones((2, 1, 4, 8))

# Scan over the first dimension
_, out = jax.lax.scan(attention_scan_body, None, (query, key, value))
self.assertEqual(out.shape, (2, 1, 4, 8))

def test_relu(self):
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])
Expand Down
Loading