Skip to content

Commit 8fe8ff1

Browse files
committed
Add tests for JAX LSTM backend
1 parent 9e19f82 commit 8fe8ff1

File tree

1 file changed

+212
-0
lines changed

1 file changed

+212
-0
lines changed

keras/src/backend/jax/rnn_test.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import numpy as np
2+
import pytest
3+
4+
from keras.src import backend
5+
from keras.src import testing
6+
7+
8+
@pytest.mark.skipif(
    backend.backend() != "jax",
    reason="JAX-specific LSTM tests.",
)
class JaxLSTMTest(testing.TestCase):
    """Tests for the JAX backend LSTM support.

    Covers `cudnn_ok` eligibility checks, `_assert_valid_mask` validation,
    the low-level `lstm` op error paths, and end-to-end `layers.LSTM`
    numerical behavior (which falls back to the generic implementation when
    cuDNN is unavailable).
    """

    def _lstm_test_data(
        self, batch=2, seq_len=5, input_size=4, hidden_size=3
    ):
        """Return deterministic arrays for calling the low-level `lstm` op.

        Returns a tuple `(inputs, h_0, c_0, kernel, recurrent_kernel, bias)`
        with the shapes the op expects: inputs are
        `(batch, seq_len, input_size)`, states are `(batch, hidden_size)`,
        and weights use the fused 4-gate layout (`4 * hidden_size`).
        """
        # Fixed seed so every test sees identical data.
        rng = np.random.RandomState(42)
        inputs = rng.randn(batch, seq_len, input_size).astype("float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")
        c_0 = np.zeros((batch, hidden_size), dtype="float32")
        kernel = rng.randn(input_size, 4 * hidden_size).astype("float32")
        recurrent_kernel = rng.randn(hidden_size, 4 * hidden_size).astype(
            "float32"
        )
        bias = rng.randn(4 * hidden_size).astype("float32")
        return inputs, h_0, c_0, kernel, recurrent_kernel, bias

    def test_cudnn_ok_standard(self):
        """`cudnn_ok` accepts the standard tanh/sigmoid combination."""
        from jax import numpy as jnp

        from keras.src import activations
        from keras.src import ops
        from keras.src.backend.jax.rnn import cudnn_ok

        # These only return True when GPU is available, so on CPU
        # we just verify they return a bool and don't crash.
        result = cudnn_ok(activations.tanh, activations.sigmoid, False)
        self.assertIsInstance(result, (bool, np.bool_))

        result = cudnn_ok(jnp.tanh, activations.sigmoid, False)
        self.assertIsInstance(result, (bool, np.bool_))

        result = cudnn_ok(ops.tanh, ops.sigmoid, False)
        self.assertIsInstance(result, (bool, np.bool_))

    def test_cudnn_ok_rejects_unroll(self):
        """Unrolled LSTMs are never eligible for cuDNN."""
        from keras.src import activations
        from keras.src.backend.jax.rnn import cudnn_ok

        self.assertFalse(cudnn_ok(activations.tanh, activations.sigmoid, True))

    def test_cudnn_ok_rejects_no_bias(self):
        """LSTMs without a bias term are never eligible for cuDNN."""
        from keras.src import activations
        from keras.src.backend.jax.rnn import cudnn_ok

        self.assertFalse(
            cudnn_ok(
                activations.tanh, activations.sigmoid, False, use_bias=False
            )
        )

    def test_cudnn_ok_rejects_wrong_activation(self):
        """Only the tanh/sigmoid activation pair is eligible for cuDNN."""
        from keras.src import activations
        from keras.src.backend.jax.rnn import cudnn_ok

        self.assertFalse(
            cudnn_ok(activations.relu, activations.sigmoid, False)
        )
        self.assertFalse(
            cudnn_ok(activations.tanh, activations.tanh, False)
        )

    def test_assert_valid_mask_right_padded(self):
        """A right-padded mask passes validation."""
        from jax import numpy as jnp

        from keras.src.backend.jax.rnn import _assert_valid_mask

        mask = jnp.array(
            [[True, True, True, False], [True, True, False, False]]
        )
        # Should not raise.
        _assert_valid_mask(mask)

    def test_assert_valid_mask_all_true(self):
        """An all-True mask (no padding) passes validation."""
        from jax import numpy as jnp

        from keras.src.backend.jax.rnn import _assert_valid_mask

        mask = jnp.ones((2, 5), dtype=jnp.bool_)
        _assert_valid_mask(mask)

    def test_assert_valid_mask_not_right_padded(self):
        """A mask with a gap (True after False) is rejected."""
        from jax import numpy as jnp

        from keras.src.backend.jax.rnn import _assert_valid_mask

        mask = jnp.array(
            [[True, False, True, False], [True, True, False, False]]
        )
        with self.assertRaises(ValueError):
            _assert_valid_mask(mask)

    def test_assert_valid_mask_fully_masked(self):
        """A fully-masked sequence (all False row) is rejected."""
        from jax import numpy as jnp

        from keras.src.backend.jax.rnn import _assert_valid_mask

        mask = jnp.array(
            [[False, False, False], [True, True, False]]
        )
        with self.assertRaises(ValueError):
            _assert_valid_mask(mask)

    def test_lstm_raises_on_cpu(self):
        """Without a GPU, lstm() should raise NotImplementedError."""
        import jax

        from keras.src import activations
        from keras.src.backend.jax.rnn import lstm

        # cudnn_ok only returns True when a GPU is available, so the
        # NotImplementedError is only guaranteed on non-GPU runners.
        if any(d.platform == "gpu" for d in jax.devices()):
            self.skipTest("cuDNN LSTM may be supported on GPU.")

        (
            inputs,
            h_0,
            c_0,
            kernel,
            recurrent_kernel,
            bias,
        ) = self._lstm_test_data()

        # On CPU, cudnn_ok returns False, so this should raise.
        with self.assertRaises(NotImplementedError):
            lstm(
                inputs, h_0, c_0, None,
                kernel, recurrent_kernel, bias,
                activations.tanh, activations.sigmoid,
            )

    def test_lstm_raises_unroll(self):
        """lstm() with unroll=True raises regardless of device."""
        from keras.src import activations
        from keras.src.backend.jax.rnn import lstm

        (
            inputs,
            h_0,
            c_0,
            kernel,
            recurrent_kernel,
            bias,
        ) = self._lstm_test_data()

        # Unrolling is never cuDNN-compatible (see
        # test_cudnn_ok_rejects_unroll), so no GPU guard is needed here.
        with self.assertRaises(NotImplementedError):
            lstm(
                inputs, h_0, c_0, None,
                kernel, recurrent_kernel, bias,
                activations.tanh, activations.sigmoid,
                unroll=True,
            )

    def test_layer_correctness(self):
        """Verify LSTM layer produces correct output (falls back on CPU)."""
        from keras.src import initializers
        from keras.src import layers

        sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
        layer = layers.LSTM(
            3,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
        )
        output = layer(sequence)
        self.assertAllClose(
            np.array(
                [
                    [0.6288687, 0.6288687, 0.6288687],
                    [0.86899155, 0.86899155, 0.86899155],
                    [0.9460773, 0.9460773, 0.9460773],
                ]
            ),
            output,
            atol=1e-5,
        )

    def test_layer_go_backwards(self):
        """Verify LSTM layer output when processing the sequence reversed."""
        from keras.src import initializers
        from keras.src import layers

        sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
        layer = layers.LSTM(
            3,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
            go_backwards=True,
        )
        output = layer(sequence)
        self.assertAllClose(
            np.array(
                [
                    [0.35622165, 0.35622165, 0.35622165],
                    [0.74789524, 0.74789524, 0.74789524],
                    [0.8872726, 0.8872726, 0.8872726],
                ]
            ),
            output,
            atol=1e-5,
        )

    def test_layer_return_state(self):
        """With return_state=True, output equals the final hidden state."""
        from keras.src import initializers
        from keras.src import layers

        sequence = np.arange(24).reshape((2, 4, 3)).astype("float32")
        layer = layers.LSTM(
            2,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
            return_state=True,
        )
        output, state_h, state_c = layer(sequence)
        # The sequence output (last step) must match the returned h state;
        # the cell state just needs the right (batch, units) shape.
        self.assertAllClose(output, state_h, atol=1e-5)
        self.assertEqual(state_c.shape, (2, 2))

0 commit comments

Comments
 (0)