@@ -211,12 +211,156 @@ def _step(states, current_input):
211211 return last_output , outputs , new_states
212212
213213
214- def cudnn_ok ( * args , ** kwargs ):
215- return False
def _is_gpu_available():
    """Return True if JAX can see at least one GPU device."""
    import jax

    for device in jax.devices():
        if device.platform == "gpu":
            return True
    return False
217218
218- def lstm (* args , ** kwargs ):
219- raise NotImplementedError
219+
def cudnn_ok(
    activation,
    recurrent_activation,
    unroll,
    use_bias=True,
):
    """Return True if the cuDNN-backed LSTM kernel can be used.

    cuDNN only implements the canonical LSTM configuration: tanh
    activation, sigmoid recurrent activation, no unrolling, biases
    enabled — and it requires a GPU to be present.
    """
    from keras.src import activations
    from keras.src import ops

    return (
        # Accept the Keras, jnp and ops aliases of each canonical
        # activation. Fix: `jnp.sigmoid` was missing from the recurrent
        # check even though `jnp.tanh` is accepted for `activation`
        # (the TensorFlow backend accepts `tf.sigmoid` analogously).
        activation in (activations.tanh, jnp.tanh, ops.tanh)
        and recurrent_activation
        in (activations.sigmoid, jnp.sigmoid, ops.sigmoid)  # noqa: E501
        and not unroll
        and use_bias
        and _is_gpu_available()
    )
237+
238+
def _assert_valid_mask(mask):
    """Check that a boolean (batch, time) mask is cuDNN-compatible.

    cuDNN only supports right-padded sequences: every row of `mask` must
    be a contiguous run of `True` followed by a contiguous run of
    `False`, and no row may be entirely `False`.

    Raises:
        ValueError: if the mask is not right-padded or any sequence is
            fully masked out.
    """
    timesteps = mask.shape[1]
    # Number of unmasked steps in each sequence.
    lengths = jnp.sum(mask.astype(jnp.int32), axis=1)
    # The mask each row *would* have if it were right-padded.
    positions = jnp.broadcast_to(jnp.arange(timesteps), mask.shape)
    canonical = positions < lengths[:, None]
    is_right_padded = jnp.all(mask == canonical)
    has_fully_masked = jnp.any(jnp.all(~mask, axis=1))

    if is_right_padded & ~has_fully_masked:
        return
    raise ValueError(
        "You are passing a RNN mask that does not correspond to "
        "right-padded sequences, while using cuDNN, which is not "
        "supported. With cuDNN, RNN masks can only be used for "
        "right-padding, e.g. `[[True, True, False, False]]` would "
        "be a valid mask, but any mask that isn't just contiguous "
        "`True`'s on the left and contiguous `False`'s on the right "
        "would be invalid. You can pass `use_cudnn=False` to your "
        "RNN layer to stop using cuDNN (this may be slower)."
    )
260+
261+
def lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
):
    """Run a single-layer LSTM through the cuDNN-backed JAX kernel.

    Raises `NotImplementedError` whenever the cuDNN path cannot be used
    (wrong activations, no GPU, `jax.experimental.rnn` unavailable, or a
    runtime failure in the kernel) so the caller can fall back to the
    generic implementation.

    Returns:
        Tuple of `(last_output, outputs, [h_n, c_n])` matching the
        generic LSTM's contract.
    """
    if not cudnn_ok(
        activation,
        recurrent_activation,
        unroll,
        use_bias=bias is not None,
    ):
        raise NotImplementedError

    try:
        from jax.experimental.rnn import lstm as jax_lstm
    except ImportError:
        # jax.experimental.rnn is not present in every JAX build.
        raise NotImplementedError

    input_size = kernel.shape[0]
    hidden_size = recurrent_kernel.shape[0]
    batch_size = inputs.shape[0]

    # Transpose Keras kernels to cuDNN layout and flatten.
    # Gate order [i, f, c, o] matches cuDNN [i, f, g, o].
    W_ih = jnp.asarray(kernel).T
    W_hh = jnp.asarray(recurrent_kernel).T

    if bias is not None:
        b_ih = jnp.asarray(bias)
    else:
        b_ih = jnp.zeros(4 * hidden_size)
    # Keras fuses the two cuDNN biases into one; the second stays zero.
    b_hh = jnp.zeros_like(b_ih)

    # cuDNN flat weight order: [W_ih, W_hh, b_ih, b_hh]
    weights = jnp.concatenate(
        [W_ih.ravel(), W_hh.ravel(), b_ih.ravel(), b_hh.ravel()]
    )

    # cuDNN expects (num_layers * num_directions, batch, hidden)
    h_0 = jnp.asarray(initial_state_h)
    c_0 = jnp.asarray(initial_state_c)
    if h_0.ndim == 2:
        h_0 = h_0[jnp.newaxis]
        c_0 = c_0[jnp.newaxis]

    if go_backwards:
        if mask is not None:
            # Fix: flipping a right-padded sequence along time moves the
            # padding to the front, so the flipped mask is no longer
            # right-padded and `_assert_valid_mask` below would reject
            # every valid mask with a misleading error (and the kernel
            # would consume padding first). Fall back to the generic
            # implementation for this combination instead.
            raise NotImplementedError
        inputs = jnp.flip(inputs, axis=1)

    if mask is not None:
        mask = jnp.asarray(mask).astype(jnp.bool_)
        # A (batch, time, 1) mask is squeezed to (batch, time).
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        _assert_valid_mask(mask)
        seq_lengths = jnp.sum(mask.astype(jnp.int32), axis=1)
    else:
        seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32)

    try:
        y, h_n, c_n = jax_lstm(
            inputs,
            h_0,
            c_0,
            weights,
            seq_lengths,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            dropout=0.0,
            bidirectional=False,
        )
    except Exception as e:
        # Any runtime failure (missing cuDNN, unsupported dtype, ...)
        # triggers the generic fallback; chain the original exception
        # so the real cause is not lost when debugging.
        raise NotImplementedError from e

    # y: (batch, seq_len, hidden), h_n/c_n: (1, batch, hidden)
    h_n = h_n.squeeze(0)
    c_n = c_n.squeeze(0)

    if mask is not None:
        # With per-sequence lengths, the final hidden state already
        # corresponds to each sequence's last valid step.
        last_output = h_n
    else:
        last_output = y[:, -1]

    if not return_sequences:
        outputs = last_output[:, jnp.newaxis, :]
    else:
        outputs = y

    if go_backwards and return_sequences:
        # Restore the original time order of the emitted sequence.
        outputs = jnp.flip(outputs, axis=1)

    return last_output, outputs, [h_n, c_n]
220364
221365
222366def gru (* args , ** kwargs ):
0 commit comments