Skip to content

Commit 7061b2f

Browse files
committed
Add cuDNN LSTM for JAX backend
1 parent e65cfb8 commit 7061b2f

File tree

1 file changed

+148
-4
lines changed

1 file changed

+148
-4
lines changed

keras/src/backend/jax/rnn.py

Lines changed: 148 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,156 @@ def _step(states, current_input):
211211
return last_output, outputs, new_states
212212

213213

214-
def cudnn_ok(*args, **kwargs):
215-
return False
214+
def _is_gpu_available():
215+
import jax
216216

217+
return any(d.platform == "gpu" for d in jax.devices())
217218

218-
def lstm(*args, **kwargs):
219-
raise NotImplementedError
219+
220+
def cudnn_ok(
    activation,
    recurrent_activation,
    unroll,
    use_bias=True,
):
    """Whether the cuDNN LSTM kernel is usable for this configuration.

    cuDNN only implements the canonical LSTM (tanh / sigmoid, with
    biases, not unrolled) and requires a GPU device to be present.
    """
    from keras.src import activations
    from keras.src import ops

    # Guard clauses; each mirrors one term of the original conjunction.
    if unroll:
        return False
    if not use_bias:
        return False
    if activation not in (activations.tanh, jnp.tanh, ops.tanh):
        return False
    if recurrent_activation not in (activations.sigmoid, ops.sigmoid):
        return False
    return _is_gpu_available()
237+
238+
239+
def _assert_valid_mask(mask):
240+
max_seq_length = mask.shape[1]
241+
count_of_true = jnp.sum(mask.astype(jnp.int32), axis=1)
242+
indices = jnp.broadcast_to(
243+
jnp.arange(max_seq_length), mask.shape
244+
)
245+
right_padded_mask = indices < count_of_true[:, None]
246+
is_right_padded = jnp.all(mask == right_padded_mask)
247+
has_fully_masked = jnp.any(jnp.all(~mask, axis=1))
248+
249+
if not (is_right_padded & ~has_fully_masked):
250+
raise ValueError(
251+
"You are passing a RNN mask that does not correspond to "
252+
"right-padded sequences, while using cuDNN, which is not "
253+
"supported. With cuDNN, RNN masks can only be used for "
254+
"right-padding, e.g. `[[True, True, False, False]]` would "
255+
"be a valid mask, but any mask that isn't just contiguous "
256+
"`True`'s on the left and contiguous `False`'s on the right "
257+
"would be invalid. You can pass `use_cudnn=False` to your "
258+
"RNN layer to stop using cuDNN (this may be slower)."
259+
)
260+
261+
262+
def lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
):
    """Run a single-layer LSTM through the cuDNN kernel exposed by JAX.

    Args:
        inputs: `(batch, time, input_size)` sequence tensor.
        initial_state_h / initial_state_c: `(batch, hidden)` (or already
            `(1, batch, hidden)`) initial states.
        mask: optional boolean mask; must be right-padded (see
            `_assert_valid_mask`).
        kernel / recurrent_kernel / bias: Keras LSTM weights, gate order
            `[i, f, c, o]`.

    Returns:
        `(last_output, outputs, [h_n, c_n])`, matching the generic
        backend RNN contract.

    Raises:
        NotImplementedError: whenever the cuDNN path cannot be used; the
            caller is expected to fall back to the generic (lax.scan)
            implementation.
    """
    if not cudnn_ok(
        activation,
        recurrent_activation,
        unroll,
        use_bias=bias is not None,
    ):
        raise NotImplementedError

    try:
        from jax.experimental.rnn import lstm as jax_lstm
    except ImportError:
        raise NotImplementedError

    input_size = kernel.shape[0]
    hidden_size = recurrent_kernel.shape[0]
    batch_size = inputs.shape[0]

    # Derive per-row sequence lengths from the mask BEFORE any flipping:
    # `_assert_valid_mask` only accepts right-padded masks, and flipping
    # a right-padded mask turns it into a left-padded one that would be
    # (confusingly) rejected.
    if mask is not None:
        mask = jnp.asarray(mask).astype(jnp.bool_)
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        _assert_valid_mask(mask)
        seq_lengths = jnp.sum(mask.astype(jnp.int32), axis=1)
    else:
        seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32)

    if go_backwards:
        if mask is not None:
            # A plain flip of padded inputs would move the padding to the
            # front, which cuDNN cannot express via `seq_lengths`; defer
            # masked go_backwards to the non-cuDNN fallback.
            raise NotImplementedError
        inputs = jnp.flip(inputs, axis=1)

    # Transpose Keras kernels to cuDNN layout and flatten.
    # Gate order [i, f, c, o] matches cuDNN [i, f, g, o].
    W_ih = jnp.asarray(kernel).T
    W_hh = jnp.asarray(recurrent_kernel).T

    if bias is not None:
        b_ih = jnp.asarray(bias)
    else:
        b_ih = jnp.zeros(4 * hidden_size)
    # Keras has a single bias vector; cuDNN adds b_ih and b_hh, so the
    # recurrent bias is kept at zero.
    b_hh = jnp.zeros_like(b_ih)

    # cuDNN flat weight order: [W_ih, W_hh, b_ih, b_hh]
    weights = jnp.concatenate(
        [W_ih.ravel(), W_hh.ravel(), b_ih.ravel(), b_hh.ravel()]
    )

    # cuDNN expects (num_layers * num_directions, batch, hidden)
    h_0 = jnp.asarray(initial_state_h)
    c_0 = jnp.asarray(initial_state_c)
    if h_0.ndim == 2:
        h_0 = h_0[jnp.newaxis]
        c_0 = c_0[jnp.newaxis]

    try:
        y, h_n, c_n = jax_lstm(
            inputs,
            h_0,
            c_0,
            weights,
            seq_lengths,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            dropout=0.0,
            bidirectional=False,
        )
    except Exception:
        # Any runtime failure in the experimental kernel triggers the
        # generic fallback rather than crashing the layer.
        raise NotImplementedError

    # y: (batch, seq_len, hidden), h_n/c_n: (1, batch, hidden)
    h_n = h_n.squeeze(0)
    c_n = c_n.squeeze(0)

    if mask is not None:
        # With seq_lengths, cuDNN's final state is taken at each row's
        # true last step, unlike y[:, -1] which may read padding.
        last_output = h_n
    else:
        last_output = y[:, -1]

    if not return_sequences:
        outputs = last_output[:, jnp.newaxis, :]
    else:
        outputs = y

    if go_backwards and return_sequences:
        # Restore the original time order of the emitted sequence.
        outputs = jnp.flip(outputs, axis=1)

    return last_output, outputs, [h_n, c_n]
220364

221365

222366
def gru(*args, **kwargs):

0 commit comments

Comments
 (0)