Skip to content
28 changes: 28 additions & 0 deletions keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,34 @@ def test_polar(self):


class NNOpsCorrectnessTest(testing.TestCase):
@pytest.mark.skipif(backend.backend() != "jax", reason="JAX only")
def test_dot_product_attention_inside_scan(self):

import jax

try:
if jax.devices()[0].platform != "tpu":
self.skipTest("TPU-specific test")
except:
self.skipTest("TPU-specific test")
Comment on lines +1334 to +1335
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a bare except: clause is generally discouraged as it can catch unexpected errors like SystemExit or KeyboardInterrupt, making it harder to debug issues. It's better to be more specific and catch Exception instead.

Suggested change
except:
self.skipTest("TPU-specific test")
except Exception:
self.skipTest("TPU-specific test")


import jax.numpy as jnp

def attention_scan_body(carry, x):
query, key, value = x
# Use a mask to trigger the issue
mask = jnp.ones((1, 4, 8), dtype="bool")
out = knn.dot_product_attention(query, key, value, mask=mask)
return carry, out

query = jnp.ones((2, 1, 4, 8))
key = jnp.ones((2, 1, 4, 8))
value = jnp.ones((2, 1, 4, 8))

# Scan over the first dimension
_, out = jax.lax.scan(attention_scan_body, None, (query, key, value))
self.assertEqual(out.shape, (2, 1, 4, 8))

def test_relu(self):
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])
Expand Down
Loading