Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,7 @@ def dot_product_attention(
)
# Transpose output back to Keras layout
return jnp.transpose(output, axes=(0, 2, 1, 3))
except Exception:
except (jax.errors.ConcretizationTypeError, Exception):
logging.exception(
"Failed to apply Splash kernel for flash attention. "
"Falling back to JAX native dot_product_attention."
Expand Down
31 changes: 31 additions & 0 deletions keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,37 @@ def test_polar(self):


class NNOpsCorrectnessTest(testing.TestCase):
def test_dot_product_attention_inside_scan(self):
if backend.backend() != "jax":
self.skipTest("JAX-specific test")

import jax

try:
if jax.devices()[0].platform != "tpu":
self.skipTest("TPU-specific test")
except:
self.skipTest("TPU-specific test")
Comment on lines +1334 to +1335
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a bare `except:` clause is generally discouraged because it also catches exceptions such as `SystemExit` and `KeyboardInterrupt`, which makes issues harder to debug. It's better to be more specific and catch `Exception` instead.

Suggested change
- except:
-     self.skipTest("TPU-specific test")
+ except Exception:
+     self.skipTest("TPU-specific test")


import jax.numpy as jnp

from keras.src.backend.jax import nn as jax_nn

def attention_scan_body(carry, x):
query, key, value = x
# Use a mask to trigger the issue
mask = jnp.ones((1, 4, 8), dtype="bool")
out = jax_nn.dot_product_attention(query, key, value, mask=mask)
return carry, out

query = jnp.ones((2, 1, 4, 8))
key = jnp.ones((2, 1, 4, 8))
value = jnp.ones((2, 1, 4, 8))

# Scan over the first dimension
_, out = jax.lax.scan(attention_scan_body, None, (query, key, value))
self.assertEqual(out.shape, (2, 1, 4, 8))

def test_relu(self):
    # relu should clamp the negative entry to zero and leave the
    # non-negative entries unchanged.
    inputs = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
    expected = [0, 0, 1, 2, 3]
    self.assertAllClose(knn.relu(inputs), expected)
Expand Down
Loading