Skip to content

Commit a1e0296

Browse files
committed
Loosen jax requirement, fix small jax-related bugs
Signed-off-by: Fabrice Normandin <[email protected]>
1 parent e47eded commit a1e0296

File tree

6 files changed

+2409
-1829
lines changed

6 files changed

+2409
-1829
lines changed

project/algorithms/jax_ppo.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040

4141
logger = get_logger(__name__)
4242

43+
# Little compatibility patch for rejax.
44+
if not hasattr(jax, "tree_map"):
45+
setattr(jax, "tree_map", jax.tree.map)
46+
4347
TEnvParams = TypeVar("TEnvParams", bound=gymnax.EnvParams, default=gymnax.EnvParams)
4448
"""Type variable for the env params (`gymnax.EnvParams`)."""
4549

project/algorithms/jax_ppo_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,7 @@ def test_ours(
157157
original_datadir: Path,
158158
):
159159
ts, evaluations = results_ours
160-
tensor_regression.check(
161-
jax.tree.map(operator.methodcaller("__array__"), dataclasses.asdict(evaluations))
162-
)
160+
tensor_regression.check(jax.tree.map(np.asarray, dataclasses.asdict(evaluations)))
163161

164162
eval_rng = rng
165163
if isinstance(seed, int):

project/utils/remote_launcher_plugin_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,13 @@ def test_can_load_configs(command_line_args: str):
9292
# assert isinstance(launcher, remote_launcher_plugin.RemoteSlurmLauncher)
9393
else:
9494
launcher = hydra.utils.instantiate(launcher_config)
95-
assert isinstance(launcher, SlurmLauncher)
95+
# bug: Seems to be some weird reloading of classes happening, causing this test to
96+
# fail when comparing two classes, which have the same name and a probably the same,
97+
# but loaded twice?
98+
assert (
99+
isinstance(launcher, SlurmLauncher)
100+
or type(launcher).__name__ == SlurmLauncher.__name__
101+
)
96102

97103

98104
in_github_CI = os.environ.get("GITHUB_ACTIONS") == "true"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ dependencies = [
2323
"transformers>=4.44.0",
2424
"datasets>=2.21.0",
2525
# Jax-related dependencies:
26-
"jax>=0.6.0",
26+
"jax",
2727
"torch-jax-interop>=0.0.8",
2828
"gymnax @ git+https://www.github.com/lebrice/gymnax@fix-classic-control-rendering",
2929
"rejax>=0.1.0",

pyproject.toml.jinja

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ dependencies = [
2323
"transformers>=4.44.0",
2424
"datasets>=2.21.0",
2525
# Jax-related dependencies:
26-
"jax==0.4.33",
27-
"jaxlib==0.4.33",
28-
"torch-jax-interop>=0.0.7",
26+
"jax",
27+
"torch-jax-interop>=0.0.8",
2928
"gymnax @ git+https://www.github.com/lebrice/gymnax@fix-classic-control-rendering",
3029
"rejax>=0.1.0",
3130
"xtils[jitpp] @ git+https://github.com/jessefarebro/xtils",
@@ -37,7 +36,7 @@ dependencies = [
3736
"hydra-orion-sweeper>=1.6.4",
3837
]
3938
readme = "README.md"
40-
requires-python = ">= {{python_version}}"
39+
requires-python = ">= {{python_version}},<3.14"
4140
4241
[dependency-groups]
4342
dev = [

0 commit comments

Comments
 (0)