Skip to content

Commit dcf3e01

Browse files
committed
Add tests for JAX optimized GRU backend
1 parent 72d2b1a commit dcf3e01

File tree

1 file changed

+349
-0
lines changed

1 file changed

+349
-0
lines changed

keras/src/backend/jax/rnn_test.py

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
import numpy as np
2+
import pytest
3+
4+
from keras.src import backend
5+
from keras.src import testing
6+
7+
8+
def _np_sigmoid(x):
9+
return 1.0 / (1.0 + np.exp(-x))
10+
11+
12+
def _np_tanh(x):
13+
return np.tanh(x)
14+
15+
16+
def _gru_reference(
17+
inputs,
18+
initial_state,
19+
kernel,
20+
recurrent_kernel,
21+
bias,
22+
go_backwards=False,
23+
return_sequences=False,
24+
):
25+
"""Pure NumPy GRU reference implementation (reset_after=True)."""
26+
batch, timesteps, _ = inputs.shape
27+
hidden_size = recurrent_kernel.shape[0]
28+
29+
if bias is not None:
30+
input_bias = bias[0]
31+
recurrent_bias = bias[1]
32+
else:
33+
input_bias = np.zeros(3 * hidden_size)
34+
recurrent_bias = np.zeros(3 * hidden_size)
35+
36+
h = initial_state.copy()
37+
all_outputs = []
38+
39+
indices = range(timesteps)
40+
if go_backwards:
41+
indices = reversed(indices)
42+
43+
for t in indices:
44+
x_t = inputs[:, t, :]
45+
x_all = x_t @ kernel + input_bias
46+
h_all = h @ recurrent_kernel + recurrent_bias
47+
48+
x_z, x_r, x_h = np.split(x_all, 3, axis=-1)
49+
h_z, h_r, h_h = np.split(h_all, 3, axis=-1)
50+
51+
z = _np_sigmoid(x_z + h_z)
52+
r = _np_sigmoid(x_r + h_r)
53+
hh = _np_tanh(x_h + r * h_h)
54+
h = z * h + (1 - z) * hh
55+
all_outputs.append(h.copy())
56+
57+
if go_backwards:
58+
all_outputs = list(reversed(all_outputs))
59+
60+
outputs = np.stack(all_outputs, axis=1)
61+
last_output = h
62+
if not return_sequences:
63+
outputs = last_output[:, np.newaxis, :]
64+
return last_output, outputs, [h]
65+
66+
67+
@pytest.mark.skipif(
    backend.backend() != "jax",
    reason="JAX-specific optimized GRU tests.",
)
class JaxGRUTest(testing.TestCase):
    """Tests for the optimized JAX GRU backend (`keras.src.backend.jax.rnn.gru`).

    Each numerical test compares the backend implementation against the
    NumPy `_gru_reference` above (reset_after=True gate ordering), and the
    `test_fallback_*` cases pin down which configurations the optimized
    path refuses with `NotImplementedError`.
    """

    def _get_activations(self):
        """Return JAX-compatible activation functions."""
        # Imported lazily so the module can be collected on non-JAX
        # backends (the class-level skipif then prevents execution).
        from jax import numpy as jnp
        from jax.nn import sigmoid

        return jnp.tanh, sigmoid

    def _get_test_weights(self, input_size, hidden_size, use_bias=True):
        """Build small, deterministic GRU weights.

        Uses a fixed seed and a 0.1 scale to keep activations in a
        well-conditioned range for float32 comparisons.

        Returns:
            `(kernel, recurrent_kernel, bias)`; `bias` is `None` when
            `use_bias` is False.
        """
        rng = np.random.RandomState(42)
        kernel = rng.randn(input_size, 3 * hidden_size).astype("float32") * 0.1
        recurrent_kernel = (
            rng.randn(hidden_size, 3 * hidden_size).astype("float32") * 0.1
        )
        if use_bias:
            # Row 0 is the input bias, row 1 the recurrent bias
            # (reset_after=True layout).
            bias = rng.randn(2, 3 * hidden_size).astype("float32") * 0.1
        else:
            bias = None
        return kernel, recurrent_kernel, bias

    def test_forward(self):
        """Last output of a plain forward pass matches the NumPy reference."""
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, bias = self._get_test_weights(
            input_size, hidden_size
        )

        rng = np.random.RandomState(0)
        inputs = rng.randn(batch, seq_len, input_size).astype("float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")

        last_output, outputs, states = gru(
            inputs, h_0, None, kernel, recurrent_kernel, bias,
            tanh, sigmoid,
        )

        ref_last, _, _ = _gru_reference(
            inputs, h_0, kernel, recurrent_kernel, bias,
        )

        self.assertAllClose(last_output, ref_last, atol=1e-5)
        # Without return_sequences the time axis collapses to length 1.
        self.assertEqual(outputs.shape, (batch, 1, hidden_size))

    def test_return_sequences(self):
        """Full per-step outputs match the reference when return_sequences=True."""
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, bias = self._get_test_weights(
            input_size, hidden_size
        )

        rng = np.random.RandomState(0)
        inputs = rng.randn(batch, seq_len, input_size).astype("float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")

        last_output, outputs, states = gru(
            inputs, h_0, None, kernel, recurrent_kernel, bias,
            tanh, sigmoid, return_sequences=True,
        )

        ref_last, ref_out, _ = _gru_reference(
            inputs, h_0, kernel, recurrent_kernel, bias,
            return_sequences=True,
        )

        self.assertAllClose(last_output, ref_last, atol=1e-5)
        self.assertAllClose(outputs, ref_out, atol=1e-5)
        self.assertEqual(outputs.shape, (batch, seq_len, hidden_size))

    def test_go_backwards(self):
        """Reversed-time processing matches the reference's last output."""
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, bias = self._get_test_weights(
            input_size, hidden_size
        )

        rng = np.random.RandomState(0)
        inputs = rng.randn(batch, seq_len, input_size).astype("float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")

        last_output, outputs, _ = gru(
            inputs, h_0, None, kernel, recurrent_kernel, bias,
            tanh, sigmoid, go_backwards=True,
        )

        ref_last, _, _ = _gru_reference(
            inputs, h_0, kernel, recurrent_kernel, bias,
            go_backwards=True,
        )

        self.assertAllClose(last_output, ref_last, atol=1e-5)

    def test_go_backwards_return_sequences(self):
        """go_backwards combined with return_sequences matches the reference."""
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, bias = self._get_test_weights(
            input_size, hidden_size
        )

        rng = np.random.RandomState(0)
        inputs = rng.randn(batch, seq_len, input_size).astype("float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")

        last_output, outputs, _ = gru(
            inputs, h_0, None, kernel, recurrent_kernel, bias,
            tanh, sigmoid, go_backwards=True, return_sequences=True,
        )

        ref_last, ref_out, _ = _gru_reference(
            inputs, h_0, kernel, recurrent_kernel, bias,
            go_backwards=True, return_sequences=True,
        )

        self.assertAllClose(last_output, ref_last, atol=1e-5)
        self.assertAllClose(outputs, ref_out, atol=1e-5)

    def test_nonzero_initial_state(self):
        """A random (non-zero) initial state is propagated correctly."""
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, bias = self._get_test_weights(
            input_size, hidden_size
        )

        rng = np.random.RandomState(0)
        inputs = rng.randn(batch, seq_len, input_size).astype("float32")
        h_0 = rng.randn(batch, hidden_size).astype("float32") * 0.5

        last_output, _, _ = gru(
            inputs, h_0, None, kernel, recurrent_kernel, bias,
            tanh, sigmoid,
        )

        ref_last, _, _ = _gru_reference(
            inputs, h_0, kernel, recurrent_kernel, bias,
        )

        self.assertAllClose(last_output, ref_last, atol=1e-5)

    def test_no_bias(self):
        """bias=None (use_bias=False) matches the bias-free reference."""
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, _ = self._get_test_weights(
            input_size, hidden_size, use_bias=False
        )

        rng = np.random.RandomState(0)
        inputs = rng.randn(batch, seq_len, input_size).astype("float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")

        last_output, outputs, _ = gru(
            inputs, h_0, None, kernel, recurrent_kernel, None,
            tanh, sigmoid, return_sequences=True,
        )

        ref_last, ref_out, _ = _gru_reference(
            inputs, h_0, kernel, recurrent_kernel, None,
            return_sequences=True,
        )

        self.assertAllClose(last_output, ref_last, atol=1e-5)
        self.assertAllClose(outputs, ref_out, atol=1e-5)

    def test_fallback_reset_after_false(self):
        """reset_after=False is unsupported by the optimized path.

        NOTE(review): the test name says "fallback" but the contract being
        asserted is a raised NotImplementedError — presumably the caller
        is expected to catch it and fall back; confirm against the
        backend's `gru` docstring.
        """
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, bias = self._get_test_weights(
            input_size, hidden_size
        )

        inputs = np.zeros((batch, seq_len, input_size), dtype="float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")

        with self.assertRaises(NotImplementedError):
            gru(
                inputs, h_0, None, kernel, recurrent_kernel, bias,
                tanh, sigmoid, reset_after=False,
            )

    def test_fallback_unroll(self):
        """unroll=True is rejected by the optimized path."""
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, bias = self._get_test_weights(
            input_size, hidden_size
        )

        inputs = np.zeros((batch, seq_len, input_size), dtype="float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")

        with self.assertRaises(NotImplementedError):
            gru(
                inputs, h_0, None, kernel, recurrent_kernel, bias,
                tanh, sigmoid, unroll=True,
            )

    def test_fallback_mask(self):
        """Passing a mask (even all-True) is rejected by the optimized path."""
        from keras.src.backend.jax.rnn import gru

        tanh, sigmoid = self._get_activations()
        batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
        kernel, recurrent_kernel, bias = self._get_test_weights(
            input_size, hidden_size
        )

        inputs = np.zeros((batch, seq_len, input_size), dtype="float32")
        h_0 = np.zeros((batch, hidden_size), dtype="float32")
        mask = np.ones((batch, seq_len), dtype="bool")

        with self.assertRaises(NotImplementedError):
            gru(
                inputs, h_0, mask, kernel, recurrent_kernel, bias,
                tanh, sigmoid,
            )

    def test_matches_layer_output(self):
        """Verify the optimized path matches the layer's output."""
        from keras.src import initializers
        from keras.src import layers

        sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
        layer = layers.GRU(
            3,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
        )
        layer_output = layer(sequence)

        # Golden values; constant initializers make all units identical
        # (the last entry of the final row appears to differ only in
        # printed precision).
        self.assertAllClose(
            np.array(
                [
                    [0.5217289, 0.5217289, 0.5217289],
                    [0.6371659, 0.6371659, 0.6371659],
                    [0.39384964, 0.39384964, 0.3938496],
                ]
            ),
            layer_output,
            atol=1e-5,
        )

    def test_matches_layer_go_backwards(self):
        """End-to-end GRU layer with go_backwards=True matches golden values."""
        from keras.src import initializers
        from keras.src import layers

        sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
        layer = layers.GRU(
            3,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
            go_backwards=True,
        )
        layer_output = layer(sequence)

        self.assertAllClose(
            np.array(
                [
                    [0.24406259, 0.24406259, 0.24406259],
                    [0.611516, 0.611516, 0.611516],
                    [0.3928808, 0.3928808, 0.3928808],
                ]
            ),
            layer_output,
            atol=1e-5,
        )

0 commit comments

Comments
 (0)