Skip to content

Commit fd98be5

Browse files
committed
Pin optax<0.2.7 to avoid tree flatten_up_to regression
optax 0.2.8 has a bug where its optimizer state contains None entries that jax.tree.map's flatten_up_to no longer accepts as tree prefixes, breaking gradient updates in test_hf_gpt2_serialize.
1 parent ab71e00 commit fd98be5

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

lib/levanter/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ dependencies = [
3939
"tokenizers>=0.15.2",
4040
"transformers>=4.57.1,<5.0",
4141
"chex>=0.1.86",
42-
"optax>=0.1.9",
42+
"optax>=0.1.9,<0.2.7",
4343
"wandb>=0.17.8",
4444
"draccus>=0.11.5",
4545
"pyarrow>=11.0.0",

uv.lock

Lines changed: 5 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)