Skip to content

Commit 9e19f82

Browse files
committed
Add cuDNN LSTM for JAX backend
1 parent 69f9311 commit 9e19f82

File tree

1 file changed

+145
-4
lines changed

1 file changed

+145
-4
lines changed

keras/src/backend/jax/rnn.py

Lines changed: 145 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,153 @@ def _step(states, current_input):
211211
return last_output, outputs, new_states
212212

213213

214-
def cudnn_ok(*args, **kwargs):
215-
return False
214+
def _is_gpu_available():
215+
import jax
216216

217+
return any(d.platform == "gpu" for d in jax.devices())
217218

218-
def lstm(*args, **kwargs):
219-
raise NotImplementedError
219+
220+
def cudnn_ok(
    activation,
    recurrent_activation,
    unroll,
    use_bias=True,
):
    """Return whether this LSTM configuration can use the cuDNN fast path.

    The fused kernel only implements the canonical LSTM: tanh cell
    activation, sigmoid recurrent activation, bias enabled, no unrolling —
    and it requires a visible GPU device.
    """
    from keras.src import activations
    from keras.src import ops

    if unroll or not use_bias:
        return False
    accepted_tanh = (activations.tanh, jnp.tanh, ops.tanh)
    accepted_sigmoid = (activations.sigmoid, ops.sigmoid)
    if activation not in accepted_tanh:
        return False
    if recurrent_activation not in accepted_sigmoid:
        return False
    return _is_gpu_available()
236+
237+
238+
def _assert_valid_mask(mask):
239+
max_seq_length = mask.shape[1]
240+
count_of_true = jnp.sum(mask.astype(jnp.int32), axis=1)
241+
indices = jnp.broadcast_to(jnp.arange(max_seq_length), mask.shape)
242+
right_padded_mask = indices < count_of_true[:, None]
243+
is_right_padded = jnp.all(mask == right_padded_mask)
244+
has_fully_masked = jnp.any(jnp.all(~mask, axis=1))
245+
246+
if not (is_right_padded & ~has_fully_masked):
247+
raise ValueError(
248+
"You are passing a RNN mask that does not correspond to "
249+
"right-padded sequences, while using cuDNN, which is not "
250+
"supported. With cuDNN, RNN masks can only be used for "
251+
"right-padding, e.g. `[[True, True, False, False]]` would "
252+
"be a valid mask, but any mask that isn't just contiguous "
253+
"`True`'s on the left and contiguous `False`'s on the right "
254+
"would be invalid. You can pass `use_cudnn=False` to your "
255+
"RNN layer to stop using cuDNN (this may be slower)."
256+
)
257+
258+
259+
def lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
):
    """Run a single-layer LSTM through cuDNN via `jax.experimental.rnn`.

    Args:
        inputs: input tensor, `(batch, time, features)`.
        initial_state_h: initial hidden state, `(batch, hidden)` (a leading
            layer axis of size 1 is added if missing).
        initial_state_c: initial cell state, same shape rules as `h`.
        mask: optional boolean mask, `(batch, time)` or `(batch, time, 1)`;
            must be right-padded (see `_assert_valid_mask`).
        kernel: input projection, `(features, 4 * hidden)`.
        recurrent_kernel: recurrent projection, `(hidden, 4 * hidden)`.
        bias: optional bias, `(4 * hidden,)`.
        activation / recurrent_activation: must be tanh / sigmoid for cuDNN.
        return_sequences: whether to return the full output sequence.
        go_backwards: process the sequence in reverse.
        unroll: unsupported on the cuDNN path.

    Returns:
        `(last_output, outputs, [h_n, c_n])`.

    Raises:
        NotImplementedError: whenever the configuration cannot be handled by
            cuDNN, so the caller can fall back to the generic JAX
            implementation.
    """
    if not cudnn_ok(
        activation,
        recurrent_activation,
        unroll,
        use_bias=bias is not None,
    ):
        raise NotImplementedError

    try:
        from jax.experimental.rnn import lstm as jax_lstm
    except ImportError:
        # jax.experimental.rnn is not present in every JAX build.
        raise NotImplementedError

    input_size = kernel.shape[0]
    hidden_size = recurrent_kernel.shape[0]
    batch_size = inputs.shape[0]

    # Transpose Keras kernels to cuDNN layout and flatten.
    # Gate order [i, f, c, o] matches cuDNN [i, f, g, o].
    # NOTE(review): cuDNN's flat weight blob uses a specific per-gate
    # packing — confirm this layout against jax.experimental.rnn's weight
    # helpers (init_lstm_weight / unpack_lstm_weights).
    W_ih = jnp.asarray(kernel).T
    W_hh = jnp.asarray(recurrent_kernel).T

    if bias is not None:
        # Keras uses a single bias; cuDNN expects separate input/recurrent
        # biases, so the recurrent one is left at zero.
        b_ih = jnp.asarray(bias)
    else:
        b_ih = jnp.zeros(4 * hidden_size)
    b_hh = jnp.zeros_like(b_ih)

    # cuDNN flat weight order: [W_ih, W_hh, b_ih, b_hh]
    weights = jnp.concatenate(
        [W_ih.ravel(), W_hh.ravel(), b_ih.ravel(), b_hh.ravel()]
    )

    # cuDNN expects (num_layers * num_directions, batch, hidden)
    h_0 = jnp.asarray(initial_state_h)
    c_0 = jnp.asarray(initial_state_c)
    if h_0.ndim == 2:
        h_0 = h_0[jnp.newaxis]
        c_0 = c_0[jnp.newaxis]

    if go_backwards:
        if mask is not None:
            # Bug fix: flipping a right-padded mask yields a *left*-padded
            # one, which _assert_valid_mask below would reject with a
            # message blaming the (actually valid) user-supplied mask.
            # cuDNN cannot express left padding, so defer to the generic
            # implementation instead of erroring out.
            raise NotImplementedError
        inputs = jnp.flip(inputs, axis=1)

    if mask is not None:
        mask = jnp.asarray(mask).astype(jnp.bool_)
        if mask.ndim == 3:
            # Accept a (batch, time, 1) mask by dropping the feature axis.
            mask = mask[:, :, 0]
        _assert_valid_mask(mask)
        seq_lengths = jnp.sum(mask.astype(jnp.int32), axis=1)
    else:
        seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32)

    try:
        y, h_n, c_n = jax_lstm(
            inputs,
            h_0,
            c_0,
            weights,
            seq_lengths,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            dropout=0.0,
            bidirectional=False,
        )
    except Exception:
        # Any cuDNN/runtime failure falls back to the generic path.
        raise NotImplementedError

    # y: (batch, seq_len, hidden), h_n/c_n: (1, batch, hidden)
    h_n = h_n.squeeze(0)
    c_n = c_n.squeeze(0)

    if mask is not None:
        # With padding, the last *valid* step is captured in the final state
        # (cuDNN stops each row at its seq_length), not in y[:, -1].
        last_output = h_n
    else:
        last_output = y[:, -1]

    if not return_sequences:
        outputs = last_output[:, jnp.newaxis, :]
    else:
        outputs = y

    if go_backwards and return_sequences:
        # Restore original time order after having fed flipped inputs.
        outputs = jnp.flip(outputs, axis=1)

    return last_output, outputs, [h_n, c_n]
220361

221362

222363
def gru(*args, **kwargs):

0 commit comments

Comments
 (0)