@@ -211,12 +211,153 @@ def _step(states, current_input):
211211 return last_output , outputs , new_states
212212
213213
214- def cudnn_ok ( * args , ** kwargs ):
215- return False
214+ def _is_gpu_available ( ):
215+ import jax
216216
217+ return any (d .platform == "gpu" for d in jax .devices ())
217218
218- def lstm (* args , ** kwargs ):
219- raise NotImplementedError
219+
def cudnn_ok(
    activation,
    recurrent_activation,
    unroll,
    use_bias=True,
):
    """Return True if this LSTM configuration can use the cuDNN kernel.

    cuDNN only implements the canonical LSTM: tanh activation, sigmoid
    recurrent activation, no time-step unrolling, a bias term, and a
    GPU to run on.
    """
    from keras.src import activations
    from keras.src import ops

    # jnp has a tanh alias but no sigmoid, hence the asymmetric sets.
    activation_ok = activation in (activations.tanh, jnp.tanh, ops.tanh)
    recurrent_ok = recurrent_activation in (activations.sigmoid, ops.sigmoid)
    return (
        activation_ok
        and recurrent_ok
        and not unroll
        and use_bias
        and _is_gpu_available()
    )
236+
237+
238+ def _assert_valid_mask (mask ):
239+ max_seq_length = mask .shape [1 ]
240+ count_of_true = jnp .sum (mask .astype (jnp .int32 ), axis = 1 )
241+ indices = jnp .broadcast_to (jnp .arange (max_seq_length ), mask .shape )
242+ right_padded_mask = indices < count_of_true [:, None ]
243+ is_right_padded = jnp .all (mask == right_padded_mask )
244+ has_fully_masked = jnp .any (jnp .all (~ mask , axis = 1 ))
245+
246+ if not (is_right_padded & ~ has_fully_masked ):
247+ raise ValueError (
248+ "You are passing a RNN mask that does not correspond to "
249+ "right-padded sequences, while using cuDNN, which is not "
250+ "supported. With cuDNN, RNN masks can only be used for "
251+ "right-padding, e.g. `[[True, True, False, False]]` would "
252+ "be a valid mask, but any mask that isn't just contiguous "
253+ "`True`'s on the left and contiguous `False`'s on the right "
254+ "would be invalid. You can pass `use_cudnn=False` to your "
255+ "RNN layer to stop using cuDNN (this may be slower)."
256+ )
257+
258+
def lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
):
    """Run a single-layer LSTM through the cuDNN-fused JAX kernel.

    Raises `NotImplementedError` whenever the fused path cannot be
    taken (unsupported activations/config, no GPU, missing
    `jax.experimental.rnn`, or a runtime failure of the kernel), so the
    caller can fall back to the generic RNN implementation.

    Args:
        inputs: Input tensor; indexing below assumes (batch, time, features).
        initial_state_h: Initial hidden state, (batch, hidden) or
            (1, batch, hidden).
        initial_state_c: Initial cell state, same shapes as above.
        mask: Optional boolean mask, (batch, time) or (batch, time, 1);
            must describe right-padded sequences.
        kernel: Input kernel, (input_size, 4 * hidden).
        recurrent_kernel: Recurrent kernel, (hidden, 4 * hidden).
        bias: Optional bias of shape (4 * hidden,), or None.
        activation: Cell activation; must be tanh for cuDNN.
        recurrent_activation: Gate activation; must be sigmoid for cuDNN.
        return_sequences: If True, return the full output sequence.
        go_backwards: If True, process the sequence in reverse.
        unroll: Must be False for the cuDNN path.

    Returns:
        Tuple `(last_output, outputs, [h_n, c_n])`.
    """
    # Bail out early if this configuration cannot use cuDNN at all.
    if not cudnn_ok(
        activation,
        recurrent_activation,
        unroll,
        use_bias=bias is not None,
    ):
        raise NotImplementedError

    try:
        from jax.experimental.rnn import lstm as jax_lstm
    except ImportError:
        # Older JAX without the experimental RNN module: fall back.
        raise NotImplementedError

    input_size = kernel.shape[0]
    hidden_size = recurrent_kernel.shape[0]
    batch_size = inputs.shape[0]

    # Transpose Keras kernels to cuDNN layout and flatten.
    # Gate order [i, f, c, o] matches cuDNN [i, f, g, o].
    W_ih = jnp.asarray(kernel).T
    W_hh = jnp.asarray(recurrent_kernel).T

    if bias is not None:
        # NOTE(review): the full Keras bias is assigned to b_ih and b_hh
        # is zeroed; assumes cuDNN's effective bias is b_ih + b_hh —
        # confirm against the jax.experimental.rnn weight layout.
        b_ih = jnp.asarray(bias)
    else:
        b_ih = jnp.zeros(4 * hidden_size)
    b_hh = jnp.zeros_like(b_ih)

    # cuDNN flat weight order: [W_ih, W_hh, b_ih, b_hh]
    weights = jnp.concatenate(
        [W_ih.ravel(), W_hh.ravel(), b_ih.ravel(), b_hh.ravel()]
    )

    # cuDNN expects (num_layers * num_directions, batch, hidden)
    h_0 = jnp.asarray(initial_state_h)
    c_0 = jnp.asarray(initial_state_c)
    if h_0.ndim == 2:
        h_0 = h_0[jnp.newaxis]
        c_0 = c_0[jnp.newaxis]

    if go_backwards:
        inputs = jnp.flip(inputs, axis=1)
        if mask is not None:
            # NOTE(review): flipping a right-padded mask yields a
            # left-padded one, which _assert_valid_mask below rejects —
            # go_backwards + mask likely always raises here. Confirm
            # whether a sequence-length-aware reverse was intended.
            mask = jnp.flip(mask, axis=1)

    if mask is not None:
        mask = jnp.asarray(mask).astype(jnp.bool_)
        if mask.ndim == 3:
            # Collapse a (batch, time, 1) mask to (batch, time).
            mask = mask[:, :, 0]
        _assert_valid_mask(mask)
        # Per-sample valid lengths for cuDNN's variable-length support.
        seq_lengths = jnp.sum(mask.astype(jnp.int32), axis=1)
    else:
        seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32)

    try:
        y, h_n, c_n = jax_lstm(
            inputs,
            h_0,
            c_0,
            weights,
            seq_lengths,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            dropout=0.0,
            bidirectional=False,
        )
    except Exception:
        # Any runtime failure of the fused kernel falls back to the
        # generic implementation rather than crashing the layer.
        raise NotImplementedError

    # y: (batch, seq_len, hidden), h_n/c_n: (1, batch, hidden)
    h_n = h_n.squeeze(0)
    c_n = c_n.squeeze(0)

    if mask is not None:
        # With masking, h_n already holds the state at each sample's
        # last valid step, whereas y[:, -1] would include padding.
        last_output = h_n
    else:
        last_output = y[:, -1]

    if not return_sequences:
        # NOTE(review): outputs keeps a singleton time axis here —
        # confirm callers expect (batch, 1, hidden) rather than
        # (batch, hidden) when return_sequences=False.
        outputs = last_output[:, jnp.newaxis, :]
    else:
        outputs = y

    if go_backwards and return_sequences:
        # Restore original time order after the earlier flip.
        outputs = jnp.flip(outputs, axis=1)

    return last_output, outputs, [h_n, c_n]
220361
221362
222363def gru (* args , ** kwargs ):
0 commit comments