Commit fd98be5
committed
Pin optax<0.2.7 to avoid tree flatten_up_to regression
optax 0.2.8 has a bug where its optimizer state contains None entries
that jax.tree.map's flatten_up_to no longer accepts as tree prefixes,
breaking gradient updates in test_hf_gpt2_serialize.1 parent ab71e00 commit fd98be5
2 files changed
Lines changed: 6 additions & 5 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
42 | | - | |
| 42 | + | |
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
| |||
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
0 commit comments