|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | + |
| 4 | +from keras.src import backend |
| 5 | +from keras.src import testing |
| 6 | + |
| 7 | + |
| 8 | +def _np_sigmoid(x): |
| 9 | + return 1.0 / (1.0 + np.exp(-x)) |
| 10 | + |
| 11 | + |
| 12 | +def _np_tanh(x): |
| 13 | + return np.tanh(x) |
| 14 | + |
| 15 | + |
| 16 | +def _gru_reference( |
| 17 | + inputs, |
| 18 | + initial_state, |
| 19 | + kernel, |
| 20 | + recurrent_kernel, |
| 21 | + bias, |
| 22 | + go_backwards=False, |
| 23 | + return_sequences=False, |
| 24 | +): |
| 25 | + """Pure NumPy GRU reference implementation (reset_after=True).""" |
| 26 | + batch, timesteps, _ = inputs.shape |
| 27 | + hidden_size = recurrent_kernel.shape[0] |
| 28 | + |
| 29 | + if bias is not None: |
| 30 | + input_bias = bias[0] |
| 31 | + recurrent_bias = bias[1] |
| 32 | + else: |
| 33 | + input_bias = np.zeros(3 * hidden_size) |
| 34 | + recurrent_bias = np.zeros(3 * hidden_size) |
| 35 | + |
| 36 | + h = initial_state.copy() |
| 37 | + all_outputs = [] |
| 38 | + |
| 39 | + indices = range(timesteps) |
| 40 | + if go_backwards: |
| 41 | + indices = reversed(indices) |
| 42 | + |
| 43 | + for t in indices: |
| 44 | + x_t = inputs[:, t, :] |
| 45 | + x_all = x_t @ kernel + input_bias |
| 46 | + h_all = h @ recurrent_kernel + recurrent_bias |
| 47 | + |
| 48 | + x_z, x_r, x_h = np.split(x_all, 3, axis=-1) |
| 49 | + h_z, h_r, h_h = np.split(h_all, 3, axis=-1) |
| 50 | + |
| 51 | + z = _np_sigmoid(x_z + h_z) |
| 52 | + r = _np_sigmoid(x_r + h_r) |
| 53 | + hh = _np_tanh(x_h + r * h_h) |
| 54 | + h = z * h + (1 - z) * hh |
| 55 | + all_outputs.append(h.copy()) |
| 56 | + |
| 57 | + if go_backwards: |
| 58 | + all_outputs = list(reversed(all_outputs)) |
| 59 | + |
| 60 | + outputs = np.stack(all_outputs, axis=1) |
| 61 | + last_output = h |
| 62 | + if not return_sequences: |
| 63 | + outputs = last_output[:, np.newaxis, :] |
| 64 | + return last_output, outputs, [h] |
| 65 | + |
| 66 | + |
| 67 | +@pytest.mark.skipif( |
| 68 | + backend.backend() != "jax", |
| 69 | + reason="JAX-specific optimized GRU tests.", |
| 70 | +) |
| 71 | +class JaxGRUTest(testing.TestCase): |
| 72 | + def _get_activations(self): |
| 73 | + """Return JAX-compatible activation functions.""" |
| 74 | + from jax import numpy as jnp |
| 75 | + from jax.nn import sigmoid |
| 76 | + |
| 77 | + return jnp.tanh, sigmoid |
| 78 | + |
| 79 | + def _get_test_weights(self, input_size, hidden_size, use_bias=True): |
| 80 | + rng = np.random.RandomState(42) |
| 81 | + kernel = rng.randn(input_size, 3 * hidden_size).astype("float32") * 0.1 |
| 82 | + recurrent_kernel = ( |
| 83 | + rng.randn(hidden_size, 3 * hidden_size).astype("float32") * 0.1 |
| 84 | + ) |
| 85 | + if use_bias: |
| 86 | + bias = rng.randn(2, 3 * hidden_size).astype("float32") * 0.1 |
| 87 | + else: |
| 88 | + bias = None |
| 89 | + return kernel, recurrent_kernel, bias |
| 90 | + |
| 91 | + def test_forward(self): |
| 92 | + from keras.src.backend.jax.rnn import gru |
| 93 | + |
| 94 | + tanh, sigmoid = self._get_activations() |
| 95 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 96 | + kernel, recurrent_kernel, bias = self._get_test_weights( |
| 97 | + input_size, hidden_size |
| 98 | + ) |
| 99 | + |
| 100 | + rng = np.random.RandomState(0) |
| 101 | + inputs = rng.randn(batch, seq_len, input_size).astype("float32") |
| 102 | + h_0 = np.zeros((batch, hidden_size), dtype="float32") |
| 103 | + |
| 104 | + last_output, outputs, states = gru( |
| 105 | + inputs, h_0, None, kernel, recurrent_kernel, bias, |
| 106 | + tanh, sigmoid, |
| 107 | + ) |
| 108 | + |
| 109 | + ref_last, _, _ = _gru_reference( |
| 110 | + inputs, h_0, kernel, recurrent_kernel, bias, |
| 111 | + ) |
| 112 | + |
| 113 | + self.assertAllClose(last_output, ref_last, atol=1e-5) |
| 114 | + self.assertEqual(outputs.shape, (batch, 1, hidden_size)) |
| 115 | + |
| 116 | + def test_return_sequences(self): |
| 117 | + from keras.src.backend.jax.rnn import gru |
| 118 | + |
| 119 | + tanh, sigmoid = self._get_activations() |
| 120 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 121 | + kernel, recurrent_kernel, bias = self._get_test_weights( |
| 122 | + input_size, hidden_size |
| 123 | + ) |
| 124 | + |
| 125 | + rng = np.random.RandomState(0) |
| 126 | + inputs = rng.randn(batch, seq_len, input_size).astype("float32") |
| 127 | + h_0 = np.zeros((batch, hidden_size), dtype="float32") |
| 128 | + |
| 129 | + last_output, outputs, states = gru( |
| 130 | + inputs, h_0, None, kernel, recurrent_kernel, bias, |
| 131 | + tanh, sigmoid, return_sequences=True, |
| 132 | + ) |
| 133 | + |
| 134 | + ref_last, ref_out, _ = _gru_reference( |
| 135 | + inputs, h_0, kernel, recurrent_kernel, bias, |
| 136 | + return_sequences=True, |
| 137 | + ) |
| 138 | + |
| 139 | + self.assertAllClose(last_output, ref_last, atol=1e-5) |
| 140 | + self.assertAllClose(outputs, ref_out, atol=1e-5) |
| 141 | + self.assertEqual(outputs.shape, (batch, seq_len, hidden_size)) |
| 142 | + |
| 143 | + def test_go_backwards(self): |
| 144 | + from keras.src.backend.jax.rnn import gru |
| 145 | + |
| 146 | + tanh, sigmoid = self._get_activations() |
| 147 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 148 | + kernel, recurrent_kernel, bias = self._get_test_weights( |
| 149 | + input_size, hidden_size |
| 150 | + ) |
| 151 | + |
| 152 | + rng = np.random.RandomState(0) |
| 153 | + inputs = rng.randn(batch, seq_len, input_size).astype("float32") |
| 154 | + h_0 = np.zeros((batch, hidden_size), dtype="float32") |
| 155 | + |
| 156 | + last_output, outputs, _ = gru( |
| 157 | + inputs, h_0, None, kernel, recurrent_kernel, bias, |
| 158 | + tanh, sigmoid, go_backwards=True, |
| 159 | + ) |
| 160 | + |
| 161 | + ref_last, _, _ = _gru_reference( |
| 162 | + inputs, h_0, kernel, recurrent_kernel, bias, |
| 163 | + go_backwards=True, |
| 164 | + ) |
| 165 | + |
| 166 | + self.assertAllClose(last_output, ref_last, atol=1e-5) |
| 167 | + |
| 168 | + def test_go_backwards_return_sequences(self): |
| 169 | + from keras.src.backend.jax.rnn import gru |
| 170 | + |
| 171 | + tanh, sigmoid = self._get_activations() |
| 172 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 173 | + kernel, recurrent_kernel, bias = self._get_test_weights( |
| 174 | + input_size, hidden_size |
| 175 | + ) |
| 176 | + |
| 177 | + rng = np.random.RandomState(0) |
| 178 | + inputs = rng.randn(batch, seq_len, input_size).astype("float32") |
| 179 | + h_0 = np.zeros((batch, hidden_size), dtype="float32") |
| 180 | + |
| 181 | + last_output, outputs, _ = gru( |
| 182 | + inputs, h_0, None, kernel, recurrent_kernel, bias, |
| 183 | + tanh, sigmoid, go_backwards=True, return_sequences=True, |
| 184 | + ) |
| 185 | + |
| 186 | + ref_last, ref_out, _ = _gru_reference( |
| 187 | + inputs, h_0, kernel, recurrent_kernel, bias, |
| 188 | + go_backwards=True, return_sequences=True, |
| 189 | + ) |
| 190 | + |
| 191 | + self.assertAllClose(last_output, ref_last, atol=1e-5) |
| 192 | + self.assertAllClose(outputs, ref_out, atol=1e-5) |
| 193 | + |
| 194 | + def test_nonzero_initial_state(self): |
| 195 | + from keras.src.backend.jax.rnn import gru |
| 196 | + |
| 197 | + tanh, sigmoid = self._get_activations() |
| 198 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 199 | + kernel, recurrent_kernel, bias = self._get_test_weights( |
| 200 | + input_size, hidden_size |
| 201 | + ) |
| 202 | + |
| 203 | + rng = np.random.RandomState(0) |
| 204 | + inputs = rng.randn(batch, seq_len, input_size).astype("float32") |
| 205 | + h_0 = rng.randn(batch, hidden_size).astype("float32") * 0.5 |
| 206 | + |
| 207 | + last_output, _, _ = gru( |
| 208 | + inputs, h_0, None, kernel, recurrent_kernel, bias, |
| 209 | + tanh, sigmoid, |
| 210 | + ) |
| 211 | + |
| 212 | + ref_last, _, _ = _gru_reference( |
| 213 | + inputs, h_0, kernel, recurrent_kernel, bias, |
| 214 | + ) |
| 215 | + |
| 216 | + self.assertAllClose(last_output, ref_last, atol=1e-5) |
| 217 | + |
| 218 | + def test_no_bias(self): |
| 219 | + from keras.src.backend.jax.rnn import gru |
| 220 | + |
| 221 | + tanh, sigmoid = self._get_activations() |
| 222 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 223 | + kernel, recurrent_kernel, _ = self._get_test_weights( |
| 224 | + input_size, hidden_size, use_bias=False |
| 225 | + ) |
| 226 | + |
| 227 | + rng = np.random.RandomState(0) |
| 228 | + inputs = rng.randn(batch, seq_len, input_size).astype("float32") |
| 229 | + h_0 = np.zeros((batch, hidden_size), dtype="float32") |
| 230 | + |
| 231 | + last_output, outputs, _ = gru( |
| 232 | + inputs, h_0, None, kernel, recurrent_kernel, None, |
| 233 | + tanh, sigmoid, return_sequences=True, |
| 234 | + ) |
| 235 | + |
| 236 | + ref_last, ref_out, _ = _gru_reference( |
| 237 | + inputs, h_0, kernel, recurrent_kernel, None, |
| 238 | + return_sequences=True, |
| 239 | + ) |
| 240 | + |
| 241 | + self.assertAllClose(last_output, ref_last, atol=1e-5) |
| 242 | + self.assertAllClose(outputs, ref_out, atol=1e-5) |
| 243 | + |
| 244 | + def test_fallback_reset_after_false(self): |
| 245 | + from keras.src.backend.jax.rnn import gru |
| 246 | + |
| 247 | + tanh, sigmoid = self._get_activations() |
| 248 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 249 | + kernel, recurrent_kernel, bias = self._get_test_weights( |
| 250 | + input_size, hidden_size |
| 251 | + ) |
| 252 | + |
| 253 | + inputs = np.zeros((batch, seq_len, input_size), dtype="float32") |
| 254 | + h_0 = np.zeros((batch, hidden_size), dtype="float32") |
| 255 | + |
| 256 | + with self.assertRaises(NotImplementedError): |
| 257 | + gru( |
| 258 | + inputs, h_0, None, kernel, recurrent_kernel, bias, |
| 259 | + tanh, sigmoid, reset_after=False, |
| 260 | + ) |
| 261 | + |
| 262 | + def test_fallback_unroll(self): |
| 263 | + from keras.src.backend.jax.rnn import gru |
| 264 | + |
| 265 | + tanh, sigmoid = self._get_activations() |
| 266 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 267 | + kernel, recurrent_kernel, bias = self._get_test_weights( |
| 268 | + input_size, hidden_size |
| 269 | + ) |
| 270 | + |
| 271 | + inputs = np.zeros((batch, seq_len, input_size), dtype="float32") |
| 272 | + h_0 = np.zeros((batch, hidden_size), dtype="float32") |
| 273 | + |
| 274 | + with self.assertRaises(NotImplementedError): |
| 275 | + gru( |
| 276 | + inputs, h_0, None, kernel, recurrent_kernel, bias, |
| 277 | + tanh, sigmoid, unroll=True, |
| 278 | + ) |
| 279 | + |
| 280 | + def test_fallback_mask(self): |
| 281 | + from keras.src.backend.jax.rnn import gru |
| 282 | + |
| 283 | + tanh, sigmoid = self._get_activations() |
| 284 | + batch, seq_len, input_size, hidden_size = 2, 5, 4, 3 |
| 285 | + kernel, recurrent_kernel, bias = self._get_test_weights( |
| 286 | + input_size, hidden_size |
| 287 | + ) |
| 288 | + |
| 289 | + inputs = np.zeros((batch, seq_len, input_size), dtype="float32") |
| 290 | + h_0 = np.zeros((batch, hidden_size), dtype="float32") |
| 291 | + mask = np.ones((batch, seq_len), dtype="bool") |
| 292 | + |
| 293 | + with self.assertRaises(NotImplementedError): |
| 294 | + gru( |
| 295 | + inputs, h_0, mask, kernel, recurrent_kernel, bias, |
| 296 | + tanh, sigmoid, |
| 297 | + ) |
| 298 | + |
| 299 | + def test_matches_layer_output(self): |
| 300 | + """Verify the optimized path matches the layer's output.""" |
| 301 | + from keras.src import initializers |
| 302 | + from keras.src import layers |
| 303 | + |
| 304 | + sequence = np.arange(72).reshape((3, 6, 4)).astype("float32") |
| 305 | + layer = layers.GRU( |
| 306 | + 3, |
| 307 | + kernel_initializer=initializers.Constant(0.01), |
| 308 | + recurrent_initializer=initializers.Constant(0.02), |
| 309 | + bias_initializer=initializers.Constant(0.03), |
| 310 | + ) |
| 311 | + layer_output = layer(sequence) |
| 312 | + |
| 313 | + self.assertAllClose( |
| 314 | + np.array( |
| 315 | + [ |
| 316 | + [0.5217289, 0.5217289, 0.5217289], |
| 317 | + [0.6371659, 0.6371659, 0.6371659], |
| 318 | + [0.39384964, 0.39384964, 0.3938496], |
| 319 | + ] |
| 320 | + ), |
| 321 | + layer_output, |
| 322 | + atol=1e-5, |
| 323 | + ) |
| 324 | + |
| 325 | + def test_matches_layer_go_backwards(self): |
| 326 | + from keras.src import initializers |
| 327 | + from keras.src import layers |
| 328 | + |
| 329 | + sequence = np.arange(72).reshape((3, 6, 4)).astype("float32") |
| 330 | + layer = layers.GRU( |
| 331 | + 3, |
| 332 | + kernel_initializer=initializers.Constant(0.01), |
| 333 | + recurrent_initializer=initializers.Constant(0.02), |
| 334 | + bias_initializer=initializers.Constant(0.03), |
| 335 | + go_backwards=True, |
| 336 | + ) |
| 337 | + layer_output = layer(sequence) |
| 338 | + |
| 339 | + self.assertAllClose( |
| 340 | + np.array( |
| 341 | + [ |
| 342 | + [0.24406259, 0.24406259, 0.24406259], |
| 343 | + [0.611516, 0.611516, 0.611516], |
| 344 | + [0.3928808, 0.3928808, 0.3928808], |
| 345 | + ] |
| 346 | + ), |
| 347 | + layer_output, |
| 348 | + atol=1e-5, |
| 349 | + ) |
0 commit comments