diff --git a/guides/writing_a_custom_training_loop_in_jax.py b/guides/writing_a_custom_training_loop_in_jax.py index e6ef9889cd..06c7cc9542 100644 --- a/guides/writing_a_custom_training_loop_in_jax.py +++ b/guides/writing_a_custom_training_loop_in_jax.py @@ -277,9 +277,9 @@ def train_step(state, data): # Build optimizer variables. optimizer.build(model.trainable_variables) -trainable_variables = model.trainable_variables -non_trainable_variables = model.non_trainable_variables -optimizer_variables = optimizer.variables +trainable_variables = [v.value for v in model.trainable_variables] +non_trainable_variables = [v.value for v in model.non_trainable_variables] +optimizer_variables = [v.value for v in optimizer.variables] state = trainable_variables, non_trainable_variables, optimizer_variables # Training loop