Skip to content

Commit e6d12a5

Browse files
authored
Upgrade to Python 3.11 (#139)
* Upgrade to Python 3.11 Signed-off-by: Fabrice Normandin <[email protected]> * Remove block for Orion in python>=3.11 Signed-off-by: Fabrice Normandin <[email protected]> * Upgrade torch-jax-interop and jax deps Signed-off-by: Fabrice Normandin <[email protected]> * Fix pre-commit issue with py311 Signed-off-by: Fabrice Normandin <[email protected]> * Loosen jax requirement, fix small jax-related bugs Signed-off-by: Fabrice Normandin <[email protected]> * Update mkdocs.yml following mkdocs change Signed-off-by: Fabrice Normandin <[email protected]> * Fix jax extra version in templated pyproject.toml Signed-off-by: Fabrice Normandin <[email protected]> * Upgrade tensor_regression to fix bug in rl test Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in jax_ppo_test.py Signed-off-by: Fabrice Normandin <[email protected]> * Fix default python version used in int. tests Signed-off-by: Fabrice Normandin <[email protected]> * Jaxlib isn't installed explicitly, dont remove it Signed-off-by: Fabrice Normandin <[email protected]> * Upgrade pytorch version, use pytorch index Signed-off-by: Fabrice Normandin <[email protected]> * Also upgrade pytorch in templated pyproject file Signed-off-by: Fabrice Normandin <[email protected]> * Make flax and tomli dependencies explicit Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]>
1 parent b8d205c commit e6d12a5

File tree

12 files changed

+3081
-2886
lines changed

12 files changed

+3081
-2886
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
max-parallel: 4
6161
matrix:
6262
platform: [ubuntu-latest, macos-latest]
63-
python-version: ["3.10"]
63+
python-version: ["3.11"]
6464
steps:
6565
- uses: actions/checkout@v4
6666
- name: Install the latest version of uv

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.10.14
1+
3.11

copier.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ _tasks:
255255
## Jax-related dependencies:
256256
- command: uv remove rejax gymnax gymnasium xtils --no-sync
257257
when: "{{ run_setup and 'jax_ppo' not in examples_to_include }}"
258-
- command: uv remove jax jaxlib torch-jax-interop --no-sync
258+
- command: uv remove jax torch-jax-interop flax --no-sync
259259
when: "{{ run_setup and 'jax_ppo' not in examples_to_include and 'jax_image_classifier' not in examples_to_include }}"
260260

261261
## Huggingface-related dependencies:

mkdocs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ plugins:
9292
- mkdocstrings:
9393
handlers:
9494
python:
95-
import:
95+
inventories:
9696
- https://docs.python-requests.org/en/master/objects.inv
9797
- https://omegaconf.readthedocs.io/en/latest/objects.inv
9898
- https://lightning.ai/docs/pytorch/stable/objects.inv

project/algorithms/callbacks/classification_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import warnings
22
from logging import getLogger as get_logger
3-
from typing import Literal, TypedDict
3+
from typing import Literal, NotRequired, Required, TypedDict
44

55
import lightning
66
import torch
77
import torchmetrics
88
from lightning import LightningModule, Trainer
99
from torch import Tensor
1010
from torchmetrics.classification import MulticlassAccuracy
11-
from typing_extensions import NotRequired, Required, override
11+
from typing_extensions import override
1212

1313
from project.utils.typing_utils.protocols import ClassificationDataModule
1414

project/algorithms/jax_ppo.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040

4141
logger = get_logger(__name__)
4242

43+
# Little compatibility patch for rejax.
44+
if not hasattr(jax, "tree_map"):
45+
setattr(jax, "tree_map", jax.tree.map)
46+
4347
TEnvParams = TypeVar("TEnvParams", bound=gymnax.EnvParams, default=gymnax.EnvParams)
4448
"""Type variable for the env params (`gymnax.EnvParams`)."""
4549

project/algorithms/jax_ppo_test.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
import operator
66
import time
7+
import warnings
78
from collections.abc import Callable, Iterable, Sequence
89
from logging import getLogger
910
from pathlib import Path
@@ -156,9 +157,7 @@ def test_ours(
156157
original_datadir: Path,
157158
):
158159
ts, evaluations = results_ours
159-
tensor_regression.check(
160-
jax.tree.map(operator.methodcaller("__array__"), dataclasses.asdict(evaluations))
161-
)
160+
tensor_regression.check(jax.tree.map(np.asarray, dataclasses.asdict(evaluations)))
162161

163162
eval_rng = rng
164163
if isinstance(seed, int):
@@ -311,7 +310,7 @@ def _visualize_rejax(rejax_algo: rejax.PPO, rejax_ts: Any, eval_rng: chex.PRNGKe
311310
actor = functools.partial(
312311
_actor,
313312
actor_ts=actor_ts,
314-
rms_state=rejax_ts.rms_state,
313+
rms_state=rejax_ts.obs_rms_state, # changed in rejax recently.
315314
normalize_observations=rejax_algo.normalize_observations,
316315
)
317316
render_episode(
@@ -463,14 +462,19 @@ def debug_jit_warnings():
463462
# Temporarily make this particular warning into an error to help future-proof our jax code.
464463
import jax._src.deprecations
465464

466-
val_before = jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated
467-
jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated = True
465+
deprecations_to_trigger_error_for = ["tracer-hash"]
466+
values_before = {}
467+
for dep in deprecations_to_trigger_error_for:
468+
if val := jax._src.deprecations._registered_deprecations.get(dep):
469+
values_before[dep] = val.accelerated
470+
val.accelerated = True
471+
else:
472+
warnings.warn(
473+
f"Couldn't find jax deprecation {dep!r} to set to error, might not exist anymore."
474+
)
468475
yield
469-
jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated = val_before
470-
471-
# train_pure_jax(algo, backend="cpu")
472-
# train_rejax(env=algo.env, env_params=algo.env_params, hp=algo.hp, backend="cpu")
473-
# train_lightning(algo, accelerator="cpu")
476+
for dep, previous_value in values_before.items():
477+
jax._src.deprecations._registered_deprecations[dep].accelerated = previous_value
474478

475479

476480
@pytest.fixture

project/utils/remote_launcher_plugin_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,13 @@ def test_can_load_configs(command_line_args: str):
9292
# assert isinstance(launcher, remote_launcher_plugin.RemoteSlurmLauncher)
9393
else:
9494
launcher = hydra.utils.instantiate(launcher_config)
95-
assert isinstance(launcher, SlurmLauncher)
95+
# bug: Seems to be some weird reloading of classes happening, causing this test to
96+
# fail when comparing two classes, which have the same name and a probably the same,
97+
# but loaded twice?
98+
assert (
99+
isinstance(launcher, SlurmLauncher)
100+
or type(launcher).__name__ == SlurmLauncher.__name__
101+
)
96102

97103

98104
in_github_CI = os.environ.get("GITHUB_ACTIONS") == "true"

pyproject.toml

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ authors = [
77
{ name = "César Miguel Valdez Córdova", email = "[email protected]" },
88
]
99
dependencies = [
10-
"torch==2.4.1",
10+
"torch>=2.9.1",
1111
"hydra-core>=1.3.2",
1212
"hydra-submitit-launcher>=1.2.0",
1313
"wandb>=0.17.6",
@@ -23,9 +23,9 @@ dependencies = [
2323
"transformers>=4.44.0",
2424
"datasets>=2.21.0",
2525
# Jax-related dependencies:
26-
"jax==0.4.33",
27-
"jaxlib==0.4.33",
28-
"torch-jax-interop>=0.0.7",
26+
"jax",
27+
"flax",
28+
"torch-jax-interop>=0.0.8",
2929
"gymnax @ git+https://www.github.com/lebrice/gymnax@fix-classic-control-rendering",
3030
"rejax>=0.1.0",
3131
"xtils[jitpp] @ git+https://github.com/jessefarebro/xtils",
@@ -34,10 +34,10 @@ dependencies = [
3434
"hydra-colorlog>=1.2.0",
3535
"remote-slurm-executor",
3636
"hydra-auto-schema>=0.0.7",
37-
"hydra-orion-sweeper>=1.6.4 ; python_full_version < '3.11'",
37+
"hydra-orion-sweeper>=1.6.4",
3838
]
3939
readme = "README.md"
40-
requires-python = ">= 3.10"
40+
requires-python = ">= 3.11,<3.14"
4141

4242
[dependency-groups]
4343
dev = [
@@ -53,8 +53,9 @@ dev = [
5353
"pytest-xdist>=3.6.1",
5454
"pytest>=8.3.2",
5555
"ruff>=0.6.0",
56-
"tensor-regression>=0.0.8",
56+
"tensor-regression>=0.1.2",
5757
"copier>=9.5.0",
58+
"tomli>=2.3.0",
5859
]
5960

6061
[project.optional-dependencies]
@@ -71,7 +72,7 @@ docs = [
7172
"mkdocs-macros-plugin>=1.0.5",
7273
"mkdocs-autoref-plugin",
7374
]
74-
gpu = ["jax[cuda12]>=0.4.31; sys_platform == 'linux'"]
75+
gpu = ["jax[cuda13]; sys_platform == 'linux'"]
7576

7677

7778
[tool.pytest.ini_options]
@@ -117,6 +118,18 @@ source = "uv-dynamic-versioning"
117118
[tool.uv]
118119
managed = true
119120

121+
[[tool.uv.index]]
122+
name = "pytorch-cu130"
123+
url = "https://download.pytorch.org/whl/cu130"
124+
explicit = true
125+
126+
120127
[tool.uv.sources]
121128
remote-slurm-executor = { git = "https://github.com/lebrice/remote-slurm-executor", branch = "master" }
122129
mkdocs-autoref-plugin = { git = "https://github.com/lebrice/mkdocs-autoref-plugin", branch = "master" }
130+
torch = [
131+
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
132+
]
133+
torchvision = [
134+
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
135+
]

pyproject.toml.jinja

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ authors = [
77
]
88
dependencies = [
99

10-
"torch==2.4.1",
10+
"torch>=2.9.1",
1111
"hydra-core>=1.3.2",
1212
"hydra-submitit-launcher>=1.2.0",
1313
"wandb>=0.17.6",
@@ -23,9 +23,9 @@ dependencies = [
2323
"transformers>=4.44.0",
2424
"datasets>=2.21.0",
2525
# Jax-related dependencies:
26-
"jax==0.4.33",
27-
"jaxlib==0.4.33",
28-
"torch-jax-interop>=0.0.7",
26+
"jax",
27+
"flax",
28+
"torch-jax-interop>=0.0.8",
2929
"gymnax @ git+https://www.github.com/lebrice/gymnax@fix-classic-control-rendering",
3030
"rejax>=0.1.0",
3131
"xtils[jitpp] @ git+https://github.com/jessefarebro/xtils",
@@ -34,10 +34,10 @@ dependencies = [
3434
"hydra-colorlog>=1.2.0",
3535
"remote-slurm-executor",
3636
"hydra-auto-schema>=0.0.7",
37-
"hydra-orion-sweeper>=1.6.4 ; python_full_version < '3.11'",
37+
"hydra-orion-sweeper>=1.6.4",
3838
]
3939
readme = "README.md"
40-
requires-python = ">= {{python_version}}"
40+
requires-python = ">= {{python_version}},<3.14"
4141
4242
[dependency-groups]
4343
dev = [
@@ -53,11 +53,11 @@ dev = [
5353
"pytest-xdist>=3.6.1",
5454
"pytest>=8.3.2",
5555
"ruff>=0.6.0",
56-
"tensor-regression>=0.0.8",
56+
"tensor-regression>=0.1.2",
5757
]
5858

5959
[project.optional-dependencies]
60-
gpu = ["jax[cuda12]>=0.4.31; sys_platform == 'linux'"]
60+
gpu = ["jax[cuda13]; sys_platform == 'linux'"]
6161

6262
[tool.pytest.ini_options]
6363
testpaths = ["{{module_name}}", "tests"]
@@ -82,9 +82,6 @@ lint.select = ["E4", "E7", "E9", "F", "I", "UP"]
8282
[tool.uv]
8383
managed = true
8484

85-
[tool.uv.sources]
86-
remote-slurm-executor = { git = "https://github.com/lebrice/remote-slurm-executor", branch = "master" }
87-
8885

8986
[build-system]
9087
requires = ["hatchling", "uv-dynamic-versioning"]
@@ -98,3 +95,17 @@ packages = ["{{module_name}}"]
9895

9996
[tool.hatch.version]
10097
source = "uv-dynamic-versioning"
98+
99+
[[tool.uv.index]]
100+
name = "pytorch-cu130"
101+
url = "https://download.pytorch.org/whl/cu130"
102+
explicit = true
103+
104+
[tool.uv.sources]
105+
remote-slurm-executor = { git = "https://github.com/lebrice/remote-slurm-executor", branch = "master" }
106+
torch = [
107+
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
108+
]
109+
torchvision = [
110+
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
111+
]

0 commit comments

Comments
 (0)