Commit d2f3497
authored
Updated demo to work with newer version of JAX
Since new versions of JAX don't support the abstraction using `__jax_array__` anymore, the old method of passing `keras.Variables` to `jax.jit` compiled functions doesn't work.
This change fixes that by manually extracting the underlying jax arrays1 parent 76da690 commit d2f3497
1 file changed
+3
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
277 | 277 | | |
278 | 278 | | |
279 | 279 | | |
280 | | - | |
281 | | - | |
282 | | - | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
283 | 283 | | |
284 | 284 | | |
285 | 285 | | |
| |||
0 commit comments