Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
max-parallel: 4
matrix:
platform: [ubuntu-latest, macos-latest]
python-version: ["3.10"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v4
- name: Install the latest version of uv
Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.10.14
3.11
2 changes: 1 addition & 1 deletion copier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ _tasks:
## Jax-related dependencies:
- command: uv remove rejax gymnax gymnasium xtils --no-sync
when: "{{ run_setup and 'jax_ppo' not in examples_to_include }}"
- command: uv remove jax jaxlib torch-jax-interop --no-sync
- command: uv remove jax torch-jax-interop flax --no-sync
when: "{{ run_setup and 'jax_ppo' not in examples_to_include and 'jax_image_classifier' not in examples_to_include }}"

## Huggingface-related dependencies:
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ plugins:
- mkdocstrings:
handlers:
python:
import:
inventories:
- https://docs.python-requests.org/en/master/objects.inv
- https://omegaconf.readthedocs.io/en/latest/objects.inv
- https://lightning.ai/docs/pytorch/stable/objects.inv
Expand Down
4 changes: 2 additions & 2 deletions project/algorithms/callbacks/classification_metrics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import warnings
from logging import getLogger as get_logger
from typing import Literal, TypedDict
from typing import Literal, NotRequired, Required, TypedDict

import lightning
import torch
import torchmetrics
from lightning import LightningModule, Trainer
from torch import Tensor
from torchmetrics.classification import MulticlassAccuracy
from typing_extensions import NotRequired, Required, override
from typing_extensions import override

from project.utils.typing_utils.protocols import ClassificationDataModule

Expand Down
4 changes: 4 additions & 0 deletions project/algorithms/jax_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@

logger = get_logger(__name__)

# Little compatibility patch for rejax.
if not hasattr(jax, "tree_map"):
setattr(jax, "tree_map", jax.tree.map)

TEnvParams = TypeVar("TEnvParams", bound=gymnax.EnvParams, default=gymnax.EnvParams)
"""Type variable for the env params (`gymnax.EnvParams`)."""

Expand Down
26 changes: 15 additions & 11 deletions project/algorithms/jax_ppo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import operator
import time
import warnings
from collections.abc import Callable, Iterable, Sequence
from logging import getLogger
from pathlib import Path
Expand Down Expand Up @@ -156,9 +157,7 @@ def test_ours(
original_datadir: Path,
):
ts, evaluations = results_ours
tensor_regression.check(
jax.tree.map(operator.methodcaller("__array__"), dataclasses.asdict(evaluations))
)
tensor_regression.check(jax.tree.map(np.asarray, dataclasses.asdict(evaluations)))

eval_rng = rng
if isinstance(seed, int):
Expand Down Expand Up @@ -311,7 +310,7 @@ def _visualize_rejax(rejax_algo: rejax.PPO, rejax_ts: Any, eval_rng: chex.PRNGKe
actor = functools.partial(
_actor,
actor_ts=actor_ts,
rms_state=rejax_ts.rms_state,
rms_state=rejax_ts.obs_rms_state, # changed in rejax recently.
normalize_observations=rejax_algo.normalize_observations,
)
render_episode(
Expand Down Expand Up @@ -463,14 +462,19 @@ def debug_jit_warnings():
# Temporarily make this particular warning into an error to help future-proof our jax code.
import jax._src.deprecations

val_before = jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated
jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated = True
deprecations_to_trigger_error_for = ["tracer-hash"]
values_before = {}
for dep in deprecations_to_trigger_error_for:
if val := jax._src.deprecations._registered_deprecations.get(dep):
values_before[dep] = val.accelerated
val.accelerated = True
else:
warnings.warn(
f"Couldn't find jax deprecation {dep!r} to set to error, might not exist anymore."
)
yield
jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated = val_before

# train_pure_jax(algo, backend="cpu")
# train_rejax(env=algo.env, env_params=algo.env_params, hp=algo.hp, backend="cpu")
# train_lightning(algo, accelerator="cpu")
for dep, previous_value in values_before.items():
jax._src.deprecations._registered_deprecations[dep].accelerated = previous_value


@pytest.fixture
Expand Down
8 changes: 7 additions & 1 deletion project/utils/remote_launcher_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,13 @@ def test_can_load_configs(command_line_args: str):
# assert isinstance(launcher, remote_launcher_plugin.RemoteSlurmLauncher)
else:
launcher = hydra.utils.instantiate(launcher_config)
assert isinstance(launcher, SlurmLauncher)
# bug: Seems to be some weird reloading of classes happening, causing this test to
# fail when comparing two classes, which have the same name and a probably the same,
# but loaded twice?
assert (
isinstance(launcher, SlurmLauncher)
or type(launcher).__name__ == SlurmLauncher.__name__
)


in_github_CI = os.environ.get("GITHUB_ACTIONS") == "true"
Expand Down
29 changes: 21 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
{ name = "César Miguel Valdez Córdova", email = "[email protected]" },
]
dependencies = [
"torch==2.4.1",
"torch>=2.9.1",
"hydra-core>=1.3.2",
"hydra-submitit-launcher>=1.2.0",
"wandb>=0.17.6",
Expand All @@ -23,9 +23,9 @@ dependencies = [
"transformers>=4.44.0",
"datasets>=2.21.0",
# Jax-related dependencies:
"jax==0.4.33",
"jaxlib==0.4.33",
"torch-jax-interop>=0.0.7",
"jax",
"flax",
"torch-jax-interop>=0.0.8",
"gymnax @ git+https://www.github.com/lebrice/gymnax@fix-classic-control-rendering",
"rejax>=0.1.0",
"xtils[jitpp] @ git+https://github.com/jessefarebro/xtils",
Expand All @@ -34,10 +34,10 @@ dependencies = [
"hydra-colorlog>=1.2.0",
"remote-slurm-executor",
"hydra-auto-schema>=0.0.7",
"hydra-orion-sweeper>=1.6.4 ; python_full_version < '3.11'",
"hydra-orion-sweeper>=1.6.4",
]
readme = "README.md"
requires-python = ">= 3.10"
requires-python = ">= 3.11,<3.14"

[dependency-groups]
dev = [
Expand All @@ -53,8 +53,9 @@ dev = [
"pytest-xdist>=3.6.1",
"pytest>=8.3.2",
"ruff>=0.6.0",
"tensor-regression>=0.0.8",
"tensor-regression>=0.1.2",
"copier>=9.5.0",
"tomli>=2.3.0",
]

[project.optional-dependencies]
Expand All @@ -71,7 +72,7 @@ docs = [
"mkdocs-macros-plugin>=1.0.5",
"mkdocs-autoref-plugin",
]
gpu = ["jax[cuda12]>=0.4.31; sys_platform == 'linux'"]
gpu = ["jax[cuda13]; sys_platform == 'linux'"]


[tool.pytest.ini_options]
Expand Down Expand Up @@ -117,6 +118,18 @@ source = "uv-dynamic-versioning"
[tool.uv]
managed = true

[[tool.uv.index]]
name = "pytorch-cu130"
url = "https://download.pytorch.org/whl/cu130"
explicit = true


[tool.uv.sources]
remote-slurm-executor = { git = "https://github.com/lebrice/remote-slurm-executor", branch = "master" }
mkdocs-autoref-plugin = { git = "https://github.com/lebrice/mkdocs-autoref-plugin", branch = "master" }
torch = [
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
33 changes: 22 additions & 11 deletions pyproject.toml.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
]
dependencies = [

"torch==2.4.1",
"torch>=2.9.1",
"hydra-core>=1.3.2",
"hydra-submitit-launcher>=1.2.0",
"wandb>=0.17.6",
Expand All @@ -23,9 +23,9 @@ dependencies = [
"transformers>=4.44.0",
"datasets>=2.21.0",
# Jax-related dependencies:
"jax==0.4.33",
"jaxlib==0.4.33",
"torch-jax-interop>=0.0.7",
"jax",
"flax",
"torch-jax-interop>=0.0.8",
"gymnax @ git+https://www.github.com/lebrice/gymnax@fix-classic-control-rendering",
"rejax>=0.1.0",
"xtils[jitpp] @ git+https://github.com/jessefarebro/xtils",
Expand All @@ -34,10 +34,10 @@ dependencies = [
"hydra-colorlog>=1.2.0",
"remote-slurm-executor",
"hydra-auto-schema>=0.0.7",
"hydra-orion-sweeper>=1.6.4 ; python_full_version < '3.11'",
"hydra-orion-sweeper>=1.6.4",
]
readme = "README.md"
requires-python = ">= {{python_version}}"
requires-python = ">= {{python_version}},<3.14"

[dependency-groups]
dev = [
Expand All @@ -53,11 +53,11 @@ dev = [
"pytest-xdist>=3.6.1",
"pytest>=8.3.2",
"ruff>=0.6.0",
"tensor-regression>=0.0.8",
"tensor-regression>=0.1.2",
]

[project.optional-dependencies]
gpu = ["jax[cuda12]>=0.4.31; sys_platform == 'linux'"]
gpu = ["jax[cuda13]; sys_platform == 'linux'"]

[tool.pytest.ini_options]
testpaths = ["{{module_name}}", "tests"]
Expand All @@ -82,9 +82,6 @@ lint.select = ["E4", "E7", "E9", "F", "I", "UP"]
[tool.uv]
managed = true

[tool.uv.sources]
remote-slurm-executor = { git = "https://github.com/lebrice/remote-slurm-executor", branch = "master" }


[build-system]
requires = ["hatchling", "uv-dynamic-versioning"]
Expand All @@ -98,3 +95,17 @@ packages = ["{{module_name}}"]

[tool.hatch.version]
source = "uv-dynamic-versioning"

[[tool.uv.index]]
name = "pytorch-cu130"
url = "https://download.pytorch.org/whl/cu130"
explicit = true

[tool.uv.sources]
remote-slurm-executor = { git = "https://github.com/lebrice/remote-slurm-executor", branch = "master" }
torch = [
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
9 changes: 3 additions & 6 deletions tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
]

_DEFAULT_PYTHON_VERSION = "3.10"
_DEFAULT_PYTHON_VERSION = "3.11"
"""The default choice of python version in the copier.yaml file."""

_project_fixture_scope = "module"
Expand Down Expand Up @@ -212,11 +212,8 @@ def project_from_template(
[
# These can be very slow but are super important!
# don't run these unless --slow argument is passed to pytest, to save some time.
# TODO: This seems to be the only one that works in the CI atm, because:
# - UV seems unable to download other python versions?
# - Python 3.11 and 3.12 aren't able to install orion atm.
"3.10",
pytest.param("3.11", marks=pytest.mark.slow),
# "3.10",
"3.11",
pytest.param("3.12", marks=pytest.mark.slow),
pytest.param(
"3.13",
Expand Down
Loading