Skip to content

Commit 22a8319

Browse files
committed
Add tests for JAX LSTM backend
1 parent 9e19f82 commit 22a8319

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed

keras/src/backend/jax/rnn_test.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import numpy as np
2+
import pytest
3+
4+
from keras.src import backend
5+
from keras.src import testing
6+
7+
8+
@pytest.mark.skipif(
    backend.backend() != "jax",
    reason="JAX-specific LSTM tests.",
)
class JaxLSTMTest(testing.TestCase):
    """Tests for the JAX backend LSTM helpers in `keras.src.backend.jax.rnn`.

    Covers the `cudnn_ok` eligibility predicate, `_assert_valid_mask`
    validation, the error paths of the `lstm` entry point, and
    end-to-end `layers.LSTM` numerics (which fall back to the generic
    RNN implementation when cuDNN is unavailable).
    """

    def _make_lstm_args(self):
        """Build deterministic LSTM call arguments.

        Returns `(inputs, h_0, c_0, kernel, recurrent_kernel, bias)`
        for batch=2, seq_len=5, input_size=4, hidden_size=3; the gate
        dimension is `4 * hidden_size` as required by LSTM kernels.
        """
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        rng = np.random.RandomState(42)
        inputs = rng.randn(batch, seq_len, input_size).astype("float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")
        c_0 = np.zeros((batch, hidden_size), dtype="float32")
        kernel = rng.randn(input_size, 4 * hidden_size).astype("float32")
        recurrent_kernel = rng.randn(hidden_size, 4 * hidden_size).astype(
            "float32"
        )
        bias = rng.randn(4 * hidden_size).astype("float32")
        return inputs, h_0, c_0, kernel, recurrent_kernel, bias

    def test_cudnn_ok_standard(self):
        """Every supported spelling of tanh/sigmoid yields a boolean."""
        from jax import numpy as jnp

        from keras.src import activations
        from keras.src import ops
        from keras.src.backend.jax.rnn import cudnn_ok

        # These only return True when GPU is available, so on CPU
        # we just verify they return a bool and don't crash, for each
        # equivalent way of spelling the standard activations.
        for act, rec_act in (
            (activations.tanh, activations.sigmoid),
            (jnp.tanh, activations.sigmoid),
            (ops.tanh, ops.sigmoid),
        ):
            result = cudnn_ok(act, rec_act, False)
            self.assertIsInstance(result, (bool, np.bool_))

    def test_cudnn_ok_rejects_unroll(self):
        """Unrolled LSTMs are never eligible for the cuDNN path."""
        from keras.src import activations
        from keras.src.backend.jax.rnn import cudnn_ok

        self.assertFalse(cudnn_ok(activations.tanh, activations.sigmoid, True))

    def test_cudnn_ok_rejects_no_bias(self):
        """The cuDNN path requires a bias term."""
        from keras.src import activations
        from keras.src.backend.jax.rnn import cudnn_ok

        self.assertFalse(
            cudnn_ok(
                activations.tanh, activations.sigmoid, False, use_bias=False
            )
        )

    def test_cudnn_ok_rejects_wrong_activation(self):
        """Non-tanh/sigmoid activations disqualify the cuDNN path."""
        from keras.src import activations
        from keras.src.backend.jax.rnn import cudnn_ok

        self.assertFalse(cudnn_ok(activations.relu, activations.sigmoid, False))
        self.assertFalse(cudnn_ok(activations.tanh, activations.tanh, False))

    def test_assert_valid_mask_right_padded(self):
        """A right-padded mask (contiguous True prefix) is accepted."""
        from jax import numpy as jnp

        from keras.src.backend.jax.rnn import _assert_valid_mask

        mask = jnp.array(
            [[True, True, True, False], [True, True, False, False]]
        )
        # Should not raise.
        _assert_valid_mask(mask)

    def test_assert_valid_mask_all_true(self):
        """An all-True mask (no padding at all) is accepted."""
        from jax import numpy as jnp

        from keras.src.backend.jax.rnn import _assert_valid_mask

        mask = jnp.ones((2, 5), dtype=jnp.bool_)
        _assert_valid_mask(mask)

    def test_assert_valid_mask_not_right_padded(self):
        """A mask with a False gap inside a sequence is rejected."""
        from jax import numpy as jnp

        from keras.src.backend.jax.rnn import _assert_valid_mask

        mask = jnp.array(
            [[True, False, True, False], [True, True, False, False]]
        )
        with self.assertRaises(ValueError):
            _assert_valid_mask(mask)

    def test_assert_valid_mask_fully_masked(self):
        """A sequence with every step masked out is rejected."""
        from jax import numpy as jnp

        from keras.src.backend.jax.rnn import _assert_valid_mask

        mask = jnp.array([[False, False, False], [True, True, False]])
        with self.assertRaises(ValueError):
            _assert_valid_mask(mask)

    def test_lstm_raises_on_cpu(self):
        """On CPU, lstm() should raise NotImplementedError."""
        import jax

        from keras.src import activations
        from keras.src.backend.jax.rnn import lstm

        # On a GPU/TPU host cudnn_ok can return True, in which case
        # lstm() would run instead of raising — so this assertion only
        # holds on CPU. Skip elsewhere rather than fail spuriously.
        if jax.default_backend() != "cpu":
            self.skipTest("lstm() only raises when cuDNN is unavailable.")

        (
            inputs,
            h_0,
            c_0,
            kernel,
            recurrent_kernel,
            bias,
        ) = self._make_lstm_args()

        # On CPU, cudnn_ok returns False, so this should raise.
        with self.assertRaises(NotImplementedError):
            lstm(
                inputs,
                h_0,
                c_0,
                None,
                kernel,
                recurrent_kernel,
                bias,
                activations.tanh,
                activations.sigmoid,
            )

    def test_lstm_raises_unroll(self):
        """lstm(unroll=True) raises regardless of available devices."""
        from keras.src import activations
        from keras.src.backend.jax.rnn import lstm

        (
            inputs,
            h_0,
            c_0,
            kernel,
            recurrent_kernel,
            bias,
        ) = self._make_lstm_args()

        # unroll=True is rejected by cudnn_ok (see
        # test_cudnn_ok_rejects_unroll), so the cuDNN path is never
        # taken and lstm() must raise.
        with self.assertRaises(NotImplementedError):
            lstm(
                inputs,
                h_0,
                c_0,
                None,
                kernel,
                recurrent_kernel,
                bias,
                activations.tanh,
                activations.sigmoid,
                unroll=True,
            )

    def test_layer_correctness(self):
        """Verify LSTM layer produces correct output (falls back on CPU)."""
        from keras.src import initializers
        from keras.src import layers

        # Constant initializers make the expected values reproducible.
        sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
        layer = layers.LSTM(
            3,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
        )
        output = layer(sequence)
        self.assertAllClose(
            np.array(
                [
                    [0.6288687, 0.6288687, 0.6288687],
                    [0.86899155, 0.86899155, 0.86899155],
                    [0.9460773, 0.9460773, 0.9460773],
                ]
            ),
            output,
            atol=1e-5,
        )

    def test_layer_go_backwards(self):
        """go_backwards=True processes the sequence in reverse order."""
        from keras.src import initializers
        from keras.src import layers

        sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
        layer = layers.LSTM(
            3,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
            go_backwards=True,
        )
        output = layer(sequence)
        self.assertAllClose(
            np.array(
                [
                    [0.35622165, 0.35622165, 0.35622165],
                    [0.74789524, 0.74789524, 0.74789524],
                    [0.8872726, 0.8872726, 0.8872726],
                ]
            ),
            output,
            atol=1e-5,
        )

    def test_layer_return_state(self):
        """return_state=True yields (output, h, c) with h == output."""
        from keras.src import initializers
        from keras.src import layers

        sequence = np.arange(24).reshape((2, 4, 3)).astype("float32")
        layer = layers.LSTM(
            2,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
            return_state=True,
        )
        output, state_h, state_c = layer(sequence)
        # For a non-return_sequences LSTM the final hidden state is the
        # layer output; the cell state just needs the right shape.
        self.assertAllClose(output, state_h, atol=1e-5)
        self.assertEqual(state_c.shape, (2, 2))

0 commit comments

Comments
 (0)