Skip to content

Commit d2f3497

Browse files
authored
Updated demo to work with newer version of JAX
Since new versions of JAX don't support the abstraction using `__jax_array__` anymore, the old method of passing `keras.Variables` to `jax.jit` compiled functions doesn't work. This change fixes that by manually extracting the underlying jax arrays
1 parent 76da690 commit d2f3497

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

guides/writing_a_custom_training_loop_in_jax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,9 @@ def train_step(state, data):
277277
# Build optimizer variables.
278278
optimizer.build(model.trainable_variables)
279279

280-
trainable_variables = model.trainable_variables
281-
non_trainable_variables = model.non_trainable_variables
282-
optimizer_variables = optimizer.variables
280+
trainable_variables = [v.value for v in model.trainable_variables]
281+
non_trainable_variables = [v.value for v in model.non_trainable_variables]
282+
optimizer_variables = [v.value for v in optimizer.variables]
283283
state = trainable_variables, non_trainable_variables, optimizer_variables
284284

285285
# Training loop

0 commit comments

Comments
 (0)