diff --git a/.gitignore b/.gitignore
index 5f55b305a..dc92303bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,3 +171,4 @@ lightning_logs
 .schemas
 tests/.regression_files/**.png
 tests/.regression_files/**/*.gif
+data
diff --git a/copier.yaml b/copier.yaml
index 396476196..118e66803 100644
--- a/copier.yaml
+++ b/copier.yaml
@@ -175,53 +175,53 @@ _tasks:
   # Remove unwanted examples:
   - command: |
-      rm --verbose {{module_name}}/algorithms/image_classifier*.py
-      rm --verbose {{module_name}}/configs/algorithm/image_classifier.yaml
-      rm --verbose {{module_name}}/configs/experiment/example.yaml
-      rm --verbose {{module_name}}/configs/experiment/cluster_sweep_example.yaml
-      rm --verbose {{module_name}}/configs/experiment/local_sweep_example.yaml
-      rm --verbose {{module_name}}/configs/experiment/profiling.yaml
+      rm -v {{module_name}}/algorithms/image_classifier*.py
+      rm -v {{module_name}}/configs/algorithm/image_classifier.yaml
+      rm -v {{module_name}}/configs/experiment/example.yaml
+      rm -v {{module_name}}/configs/experiment/cluster_sweep_example.yaml
+      rm -v {{module_name}}/configs/experiment/local_sweep_example.yaml
+      rm -v {{module_name}}/configs/experiment/profiling.yaml
       git add {{module_name}}
       git commit -m "Remove image classification example."
     when: "{{ run_setup and 'image_classifier' not in examples_to_include }}"
   - command: |
-      rm --verbose {{module_name}}/algorithms/jax_image_classifier*.py
-      rm --verbose {{module_name}}/configs/algorithm/jax_image_classifier.yaml
-      rm --verbose {{module_name}}/configs/algorithm/network/jax_cnn.yaml
-      rm --verbose {{module_name}}/configs/algorithm/network/jax_fcnet.yaml
+      rm -v {{module_name}}/algorithms/jax_image_classifier*.py
+      rm -v {{module_name}}/configs/algorithm/jax_image_classifier.yaml
+      rm -v {{module_name}}/configs/algorithm/network/jax_cnn.yaml
+      rm -v {{module_name}}/configs/algorithm/network/jax_fcnet.yaml
       git add {{module_name}}
       git commit -m "Remove jax image classification example."
when: "{{ run_setup and 'jax_image_classifier' not in examples_to_include }}" # Remove unwanted image classification datamodules and configs - command: | - rm --verbose {{module_name}}/datamodules/image_classification/image_classification*.py - rm --verbose {{module_name}}/datamodules/image_classification/mnist*.py - rm --verbose {{module_name}}/datamodules/image_classification/fashion_mnist*.py - rm --verbose {{module_name}}/datamodules/image_classification/cifar10*.py - rm --verbose {{module_name}}/datamodules/image_classification/imagenet*.py - rm --verbose {{module_name}}/datamodules/image_classification/inaturalist*.py - rm --verbose {{module_name}}/datamodules/image_classification/__init__.py - rm --verbose {{module_name}}/datamodules/vision*.py - rm --verbose {{module_name}}/configs/datamodule/mnist.yaml - rm --verbose {{module_name}}/configs/datamodule/fashion_mnist.yaml - rm --verbose {{module_name}}/configs/datamodule/cifar10.yaml - rm --verbose {{module_name}}/configs/datamodule/imagenet.yaml - rm --verbose {{module_name}}/configs/datamodule/inaturalist.yaml - rm --verbose {{module_name}}/configs/datamodule/vision.yaml + rm -v {{module_name}}/datamodules/image_classification/image_classification*.py + rm -v {{module_name}}/datamodules/image_classification/mnist*.py + rm -v {{module_name}}/datamodules/image_classification/fashion_mnist*.py + rm -v {{module_name}}/datamodules/image_classification/cifar10*.py + rm -v {{module_name}}/datamodules/image_classification/imagenet*.py + rm -v {{module_name}}/datamodules/image_classification/inaturalist*.py + rm -v {{module_name}}/datamodules/image_classification/__init__.py + rm -v {{module_name}}/datamodules/vision*.py + rm -v {{module_name}}/configs/datamodule/mnist.yaml + rm -v {{module_name}}/configs/datamodule/fashion_mnist.yaml + rm -v {{module_name}}/configs/datamodule/cifar10.yaml + rm -v {{module_name}}/configs/datamodule/imagenet.yaml + rm -v {{module_name}}/configs/datamodule/inaturalist.yaml + rm -v {{module_name}}/configs/datamodule/vision.yaml rmdir {{module_name}}/datamodules/image_classification git add {{module_name}} git commit -m "Remove image classification datamodules and configs." when: "{{ run_setup and 'image_classifier' not in examples_to_include and 'jax_image_classifier' not in examples_to_include }}" - command: | - rm --verbose {{module_name}}/algorithms/text_classifier*.py - rm --verbose {{module_name}}/configs/algorithm/text_classifier.yaml - rm --verbose {{module_name}}/configs/experiment/text_classification_example.yaml - rm --verbose {{module_name}}/configs/datamodule/glue_cola.yaml - rm --verbose {{module_name}}/datamodules/text/text_classification*.py - rm --verbose {{module_name}}/datamodules/text/__init__.py + rm -v {{module_name}}/algorithms/text_classifier*.py + rm -v {{module_name}}/configs/algorithm/text_classifier.yaml + rm -v {{module_name}}/configs/experiment/text_classification_example.yaml + rm -v {{module_name}}/configs/datamodule/glue_cola.yaml + rm -v {{module_name}}/datamodules/text/text_classification*.py + rm -v {{module_name}}/datamodules/text/__init__.py rmdir {{module_name}}/datamodules/text git add {{module_name}} git commit -m "Remove text classification example." @@ -229,19 +229,19 @@ _tasks: # todo: remove JaxTrainer and project/trainers folder if the JaxPPO example is removed? 
   - command: |
-      rm --verbose {{module_name}}/algorithms/jax_ppo*.py
-      rm --verbose {{module_name}}/trainers/jax_trainer*.py
+      rm -v {{module_name}}/algorithms/jax_ppo*.py
+      rm -v {{module_name}}/trainers/jax_trainer*.py
       rmdir {{module_name}}/trainers
-      rm --verbose {{module_name}}/configs/algorithm/jax_ppo.yaml
-      rm --verbose {{module_name}}/configs/experiment/jax_rl_example.yaml
+      rm -v {{module_name}}/configs/algorithm/jax_ppo.yaml
+      rm -v {{module_name}}/configs/experiment/jax_rl_example.yaml
       git add {{module_name}}
       git commit -m "Remove Jax PPO example and lightning Trainer."
     when: "{{ run_setup and 'jax_ppo' not in examples_to_include }}"
   - command: |
-      rm --verbose {{module_name}}/algorithms/llm_finetuning*.py
-      rm --verbose {{module_name}}/configs/algorithm/llm_finetuning.yaml
-      rm --verbose {{module_name}}/configs/experiment/llm_finetuning_example.yaml
+      rm -v {{module_name}}/algorithms/llm_finetuning*.py
+      rm -v {{module_name}}/configs/algorithm/llm_finetuning.yaml
+      rm -v {{module_name}}/configs/experiment/llm_finetuning_example.yaml
       git add {{module_name}}
       git commit -m "Remove LLM fine-tuning example."
     when: "{{ run_setup and 'llm_finetuning' not in examples_to_include }}"
@@ -253,15 +253,15 @@ _tasks:
   # Remove unneeded dependencies:
   ## Jax-related dependencies:
-  - command: uv remove rejax gymnax gymnasium xtils
+  - command: uv remove rejax gymnax gymnasium xtils --no-sync
     when: "{{ run_setup and 'jax_ppo' not in examples_to_include }}"
-  - command: uv remove jax jaxlib torch-jax-interop
+  - command: uv remove jax jaxlib torch-jax-interop --no-sync
     when: "{{ run_setup and 'jax_ppo' not in examples_to_include and 'jax_image_classifier' not in examples_to_include }}"
   ## Huggingface-related dependencies:
-  - command: uv remove evaluate
+  - command: uv remove evaluate --no-sync
     when: "{{run_setup and 'text_classifier' not in examples_to_include }}"
-  - command: uv remove transformers datasets
+  - command: uv remove transformers datasets --no-sync
     when: "{{ run_setup and 'text_classifier' not in examples_to_include and 'llm_finetuning' not in examples_to_include }}"
   - command: |
       git add .python-version pyproject.toml uv.lock
@@ -306,7 +306,7 @@ _tasks:
   # todo: Causes issues on GitHub CI (asking for user)
   - command: |
       git add .
- git commit -m "Clean project from template" + git commit -m "Clean project from template" || true # don't fail if there are no changes when: "{{run_setup}}" # - command: "git commit -m 'Initial commit'" - "uvx pre-commit install" diff --git a/project/algorithms/callbacks/samples_per_second.py b/project/algorithms/callbacks/samples_per_second.py index a6ad820d1..d19c6d93e 100644 --- a/project/algorithms/callbacks/samples_per_second.py +++ b/project/algorithms/callbacks/samples_per_second.py @@ -1,8 +1,8 @@ import time from typing import Any, Generic, Literal -import jax import lightning +import optree import torch from lightning import LightningModule, Trainer from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -182,12 +182,14 @@ def log( ) def get_num_samples(self, batch: BatchType) -> int: + if isinstance(batch, Tensor): + return batch.shape[0] if is_sequence_of(batch, Tensor): return batch[0].shape[0] if isinstance(batch, dict): return next( v.shape[0] - for v in jax.tree.leaves(batch) + for v in optree.tree_leaves(batch) # type: ignore if isinstance(v, torch.Tensor) and v.ndim > 1 ) raise NotImplementedError( diff --git a/project/algorithms/jax_ppo.py b/project/algorithms/jax_ppo.py index 32b8cd4ca..707fed6b7 100644 --- a/project/algorithms/jax_ppo.py +++ b/project/algorithms/jax_ppo.py @@ -7,9 +7,10 @@ from __future__ import annotations import contextlib +import dataclasses import functools import operator -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from logging import getLogger as get_logger from pathlib import Path from typing import Any, Generic, TypedDict @@ -21,25 +22,21 @@ import gymnax.environments.spaces import jax import jax.numpy as jnp -import lightning -import lightning.pytorch -import lightning.pytorch.loggers -import lightning.pytorch.loggers.wandb import numpy as np import optax from flax.training.train_state import TrainState from flax.typing import FrozenVariableDict from gymnax.environments.environment import Environment from gymnax.visualize.visualizer import Visualizer +from lightning.pytorch.loggers.wandb import WandbLogger from matplotlib import pyplot as plt from rejax.algos.mixins import RMSState from rejax.evaluate import evaluate from rejax.networks import DiscretePolicy, GaussianPolicy, VNetwork from typing_extensions import TypeVar -from xtils.jitpp import Static +from xtils.jitpp import Static, jit from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer, get_error_from_metrics -from project.utils.typing_utils.jax_typing_utils import field, jit logger = get_logger(__name__) @@ -96,6 +93,48 @@ class PPOState(Generic[TEnvState], flax.struct.PyTreeNode): data_collection_state: TrajectoryCollectionState[TEnvState] +T = TypeVar("T") + + +def field( + *, + default: T | dataclasses._MISSING_TYPE = dataclasses.MISSING, + default_factory: Callable[[], T] | dataclasses._MISSING_TYPE = dataclasses.MISSING, + init=True, + repr=True, + hash=None, + compare=True, + metadata: Mapping[Any, Any] | None = None, + kw_only=dataclasses.MISSING, + pytree_node: bool | None = None, +) -> T: + """Small Typing fix for `flax.struct.field`. + + - Add type annotations so it doesn't drop the signature of the `dataclasses.field` function. + - Make the `pytree_node` has a default value of `False` for ints and bools, and `True` for + everything else. + """ + if pytree_node is None and isinstance(default, int): # note: also includes `bool`. 
diff --git a/project/algorithms/jax_ppo.py b/project/algorithms/jax_ppo.py
index 32b8cd4ca..707fed6b7 100644
--- a/project/algorithms/jax_ppo.py
+++ b/project/algorithms/jax_ppo.py
@@ -7,9 +7,10 @@
 from __future__ import annotations
 
 import contextlib
+import dataclasses
 import functools
 import operator
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Mapping, Sequence
 from logging import getLogger as get_logger
 from pathlib import Path
 from typing import Any, Generic, TypedDict
@@ -21,25 +22,21 @@
 import gymnax.environments.spaces
 import jax
 import jax.numpy as jnp
-import lightning
-import lightning.pytorch
-import lightning.pytorch.loggers
-import lightning.pytorch.loggers.wandb
 import numpy as np
 import optax
 from flax.training.train_state import TrainState
 from flax.typing import FrozenVariableDict
 from gymnax.environments.environment import Environment
 from gymnax.visualize.visualizer import Visualizer
+from lightning.pytorch.loggers.wandb import WandbLogger
 from matplotlib import pyplot as plt
 from rejax.algos.mixins import RMSState
 from rejax.evaluate import evaluate
 from rejax.networks import DiscretePolicy, GaussianPolicy, VNetwork
 from typing_extensions import TypeVar
-from xtils.jitpp import Static
+from xtils.jitpp import Static, jit
 
 from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer, get_error_from_metrics
-from project.utils.typing_utils.jax_typing_utils import field, jit
 
 logger = get_logger(__name__)
@@ -96,6 +93,48 @@ class PPOState(Generic[TEnvState], flax.struct.PyTreeNode):
     data_collection_state: TrajectoryCollectionState[TEnvState]
 
 
+T = TypeVar("T")
+
+
+def field(
+    *,
+    default: T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
+    default_factory: Callable[[], T] | dataclasses._MISSING_TYPE = dataclasses.MISSING,
+    init=True,
+    repr=True,
+    hash=None,
+    compare=True,
+    metadata: Mapping[Any, Any] | None = None,
+    kw_only=dataclasses.MISSING,
+    pytree_node: bool | None = None,
+) -> T:
+    """Small typing fix for `flax.struct.field`.
+
+    - Add type annotations so it doesn't drop the signature of the `dataclasses.field` function.
+    - Make `pytree_node` default to `False` for ints and bools, and `True` for
+      everything else.
+    """
+    if pytree_node is None and isinstance(default, int):  # note: also includes `bool`.
+        pytree_node = False
+    if pytree_node is None:
+        pytree_node = True
+    if metadata is None:
+        metadata = {}
+    else:
+        metadata = dict(metadata)
+    metadata.setdefault("pytree_node", pytree_node)
+    return dataclasses.field(
+        default=default,
+        default_factory=default_factory,  # type: ignore
+        init=init,
+        repr=repr,
+        hash=hash,
+        compare=compare,
+        metadata=metadata,
+        kw_only=kw_only,
+    )  # type: ignore
+
+
 class PPOHParams(flax.struct.PyTreeNode):
     """Hyper-parameters for this PPO example.
 
@@ -512,7 +551,10 @@ def visualize(self, ts: PPOState, gif_path: str | Path, eval_rng: chex.PRNGKey |
         render_episode(
             actor=actor,
             env=self.env,
-            env_params=jax.tree.map(lambda v: v.item() if v.ndim == 0 else v, self.env_params),
+            env_params=jax.tree.map(
+                lambda v: v if isinstance(v, int | float) else v.item() if v.ndim == 0 else v,
+                self.env_params,
+            ),
             gif_path=Path(gif_path),
             rng=eval_rng if eval_rng is not None else ts.rng,
         )
@@ -520,12 +562,12 @@ def visualize(self, ts: PPOState, gif_path: str | Path, eval_rng: chex.PRNGKey |
     ## These here aren't currently used. They are here to mirror rejax.PPO where the training loop
     # is in the algorithm.
 
-    @functools.partial(jit, static_argnames=["skip_initial_evaluation"])
+    @functools.partial(jax.jit, static_argnames="skip_initial_evaluation")
     def train(
         self,
         rng: jax.Array,
         train_state: PPOState[TEnvState] | None = None,
-        skip_initial_evaluation: bool = False,
+        skip_initial_evaluation: Static[bool] = False,
     ) -> tuple[PPOState[TEnvState], EvalMetrics]:
         """Full training loop in jax.
 
@@ -624,9 +666,9 @@ def _normalize_obs(rms_state: RMSState, obs: jax.Array):
     return (obs - rms_state.mean) / jnp.sqrt(rms_state.var + 1e-8)
 
 
-@functools.partial(jit, static_argnames=["num_minibatches"])
+@jit
 def shuffle_and_split(
-    data: AdvantageMinibatch, rng: chex.PRNGKey, num_minibatches: int
+    data: AdvantageMinibatch, rng: chex.PRNGKey, num_minibatches: Static[int]
 ) -> AdvantageMinibatch:
     assert data.trajectories.obs.shape
     iteration_size = data.trajectories.obs.shape[0] * data.trajectories.obs.shape[1]
@@ -639,7 +681,7 @@ def shuffle_and_split(
     return jax.tree.map(_shuffle_and_split_fn, data)
 
 
-@functools.partial(jit, static_argnames=["num_minibatches"])
+@jit
 def _shuffle_and_split(x: jax.Array, permutation: jax.Array, num_minibatches: Static[int]):
     x = x.reshape((x.shape[0] * x.shape[1], *x.shape[2:]))
     x = jnp.take(x, permutation, axis=0)
@@ -683,7 +725,7 @@ def get_advantages(
     return (advantage, transition_data.value), advantage
 
 
-@functools.partial(jit, static_argnames=["actor"])
+@jit
 def actor_loss_fn(
     params: FrozenVariableDict,
     actor: Static[flax.linen.Module],
@@ -710,7 +752,7 @@ def actor_loss_fn(
     return pi_loss - ent_coef * entropy
 
 
-@functools.partial(jit, static_argnames=["critic"])
+@jit
 def critic_loss_fn(
     params: FrozenVariableDict,
     critic: Static[flax.linen.Module],
@@ -829,6 +871,6 @@ def on_fit_end(self, trainer: JaxTrainer, module: JaxRLExample, ts: PPOState):
 
     def log_image(self, gif_path: Path, trainer: JaxTrainer, step: int):
         for logger in trainer.loggers:
-            if isinstance(logger, lightning.pytorch.loggers.wandb.WandbLogger):
+            if isinstance(logger, WandbLogger):
                 logger.log_image("render_episode", [str(gif_path)], step=step)
         return
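The jax_ppo.py changes above replace `functools.partial(jit, static_argnames=[...])` with `xtils.jitpp.jit`, which reads staticness from `Static[...]` annotations instead of a separately maintained name list. A sketch of the pattern, assuming `xtils.jitpp` behaves exactly as the diff uses it (the function itself is made up for illustration):

```python
# Sketch (assumes xtils.jitpp as used in the diff): arguments annotated
# with Static[...] are treated as static/compile-time arguments, so the
# decorator and the signature can no longer drift out of sync.
import jax
import jax.numpy as jnp
from xtils.jitpp import Static, jit


@jit
def repeat_mean(x: jax.Array, n: Static[int]) -> jax.Array:
    # `n` is static, so it can safely determine the output shape.
    return jnp.full((n,), x.mean())
```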
diff --git a/project/conftest.py b/project/conftest.py
index 7e39fb587..69dfbd0b1 100644
--- a/project/conftest.py
+++ b/project/conftest.py
@@ -73,11 +73,11 @@
 from typing import Any, Literal, TypeVar
 
 import hydra.errors
-import jax
 import lightning
 import lightning.pytorch
 import lightning.pytorch as pl
 import lightning.pytorch.utilities
+import optree
 import pytest
 import tensor_regression.stats
 import torch
@@ -377,9 +377,7 @@ def train_dataloader(
 
 # todo: Remove (unused).
 @pytest.fixture(scope="session")
-def training_batch(
-    train_dataloader: DataLoader, device: torch.device
-) -> tuple[Tensor, ...] | dict[str, Tensor]:
+def training_batch(train_dataloader: DataLoader, device: torch.device) -> optree.PyTree[Tensor]:
     # Get a batch of data from the dataloader.
 
     # The batch of data will always be the same because the dataloaders are passed a Generator
@@ -391,8 +389,7 @@ def training_batch(
     # TODO: This ugliness is because torchvision transforms use the global pytorch RNG!
     torch.random.manual_seed(42)
     batch = next(dataloader_iterator)
-
-    return jax.tree.map(operator.methodcaller("to", device=device), batch)
+    return optree.tree_map(operator.methodcaller("to", device=device), batch)
 
 
 @pytest.fixture(autouse=True, scope="function")
diff --git a/project/trainers/jax_trainer.py b/project/trainers/jax_trainer.py
index f82692c45..3c0642a78 100644
--- a/project/trainers/jax_trainer.py
+++ b/project/trainers/jax_trainer.py
@@ -12,8 +12,6 @@
 from typing import Any, Protocol, runtime_checkable
 
 import chex
-import flax.core
-import flax.linen
 import flax.struct
 import jax
 import jax.experimental
@@ -23,10 +21,10 @@
 import lightning.pytorch.loggers
 from hydra.core.hydra_config import HydraConfig
 from typing_extensions import TypeVar
+from xtils.jitpp import Static
 
 from project.configs.config import Config
 from project.experiment import instantiate_trainer, train_and_evaluate
-from project.utils.typing_utils.jax_typing_utils import jit
 
 Ts = TypeVar("Ts", bound=flax.struct.PyTreeNode, default=flax.struct.PyTreeNode)
 """Type Variable for the training state."""
@@ -257,13 +255,16 @@ class JaxTrainer(flax.struct.PyTreeNode):
 
     verbose: bool = flax.struct.field(pytree_node=False, default=False)
 
-    @functools.partial(jit, static_argnames=["skip_initial_evaluation"])
+    @functools.partial(
+        jax.jit,
+        static_argnames="skip_initial_evaluation",
+    )
     def fit(
         self,
         algo: JaxModule[Ts, _B, _MetricsT],
         rng: chex.PRNGKey,
         train_state: Ts | None = None,
-        skip_initial_evaluation: bool = False,
+        skip_initial_evaluation: Static[bool] = False,
     ) -> tuple[Ts, _MetricsT]:
         """Full training loop in pure jax (a lot faster than pytorch-lightning).
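In jax_trainer.py, `fit` now uses plain `jax.jit` with `static_argnames` (the `Static[bool]` annotation is kept as type-level documentation). For reference, a self-contained sketch of how a static argument behaves under `jax.jit` (the function is made up for illustration):

```python
# Sketch: `skip` is static, so jax compiles one specialization per value,
# and plain Python control flow on it works inside the traced function.
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="skip")
def maybe_normalize(x: jax.Array, skip: bool = False) -> jax.Array:
    if skip:  # fine: `skip` is a static (compile-time) argument
        return x
    return x / jnp.maximum(jnp.abs(x).max(), 1e-8)
```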
-""" - -from __future__ import annotations - -import dataclasses -from collections.abc import Callable, Iterable, Mapping, Sequence -from typing import Any, Concatenate, Literal, ParamSpec, overload - -import jax -import jax.experimental -from jax._src.sharding_impls import UNSPECIFIED, Device, UnspecifiedValue -from typing_extensions import TypeVar - -P = ParamSpec("P") -Out = TypeVar("Out", covariant=True) - - -# @functools.wraps(jax.jit) -def jit( - fn: Callable[P, Out], - in_shardings: UnspecifiedValue = UNSPECIFIED, - out_shardings: UnspecifiedValue = UNSPECIFIED, - static_argnums: int | Sequence[int] | None = None, - static_argnames: str | Iterable[str] | None = None, - donate_argnums: int | Sequence[int] | None = None, - donate_argnames: str | Iterable[str] | None = None, - keep_unused: bool = False, - device: Device | None = None, - backend: str | None = None, - inline: bool = False, - abstracted_axes: Any | None = None, -) -> Callable[P, Out]: - # Small type hint fix for jax's `jit` (preserves the signature of the callable). - # TODO: Remove once [our PR to Jax](https://github.com/jax-ml/jax/pull/23720) is merged - - return jax.jit( - fn, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - donate_argnames=donate_argnames, - keep_unused=keep_unused, - device=device, - backend=backend, - inline=inline, - abstracted_axes=abstracted_axes, - ) - - -In = TypeVar("In") -Aux = TypeVar("Aux") - - -# @functools.wraps(jax.value_and_grad) -def value_and_grad( - fn: Callable[Concatenate[In, P], tuple[Out, Aux]], - argnums: Literal[0] = 0, - has_aux: Literal[True] = True, -) -> Callable[Concatenate[In, P], tuple[tuple[Out, Aux], In]]: - # Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable). - return jax.value_and_grad(fn, argnums=argnums, has_aux=has_aux) # type: ignore - - -_T = TypeVar("_T") - - -# @functools.wraps(flax.struct.field) -@overload # `default` and `default_factory` are optional and mutually exclusive. -def field( - *, - default: _T, - init: bool = True, - repr: bool = True, - hash: bool | None = None, - compare: bool = True, - metadata: Mapping[Any, Any] | None = None, - kw_only: bool = ..., - pytree_node: bool = True, -) -> _T: ... -@overload -def field( - *, - default_factory: Callable[[], _T], - init: bool = True, - repr: bool = True, - hash: bool | None = None, - compare: bool = True, - metadata: Mapping[Any, Any] | None = None, - kw_only: bool = ..., - pytree_node: bool = True, -) -> _T: ... -@overload -def field( - *, - init: bool = True, - repr: bool = True, - hash: bool | None = None, - compare: bool = True, - metadata: Mapping[Any, Any] | None = None, - kw_only: bool = ..., - pytree_node: bool = True, -) -> Any: ... - - -def field( - *, - default=dataclasses.MISSING, - default_factory=dataclasses.MISSING, - init=True, - repr=True, - hash=None, - compare=True, - metadata: Mapping[Any, Any] | None = None, - kw_only=dataclasses.MISSING, - pytree_node: bool | None = None, -): - """Small Typing fix for `flax.struct.field`. - - - Add type annotations so it doesn't drop the signature of the `dataclasses.field` function. - - Make the `pytree_node` has a default value of `False` for ints and bools, and `True` for - everything else. - """ - if pytree_node is None and isinstance(default, int): # note: also includes `bool`. 
diff --git a/pyproject.toml b/pyproject.toml
index 8382a58f3..a350a8e8b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = [
     "tqdm>=4.66.5",
     "hydra-zen==0.13.1rc1",
     "matplotlib>=3.9.2",
+    "optree>=0.15.0",  # to get tree ops without needing jax as a dependency.
     # Hugging Face dependencies:
     "evaluate>=0.4.2",
     "transformers>=4.44.0",
@@ -70,7 +71,7 @@ docs = [
     "mkdocs-macros-plugin>=1.0.5",
     "mkdocs-autoref-plugin",
 ]
-gpu = ["jax[cuda12]>=0.4.31"]
+gpu = ["jax[cuda12]>=0.4.31; sys_platform == 'linux'"]
 
 [tool.pytest.ini_options]
diff --git a/pyproject.toml.jinja b/pyproject.toml.jinja
index 13a653aed..084c40ba7 100644
--- a/pyproject.toml.jinja
+++ b/pyproject.toml.jinja
@@ -17,6 +17,7 @@ dependencies = [
     "tqdm>=4.66.5",
     "hydra-zen==0.13.1rc1",
     "matplotlib>=3.9.2",
+    "optree>=0.15.0",  # to get tree ops without needing jax as a dependency.
     # Hugging Face dependencies:
     "evaluate>=0.4.2",
     "transformers>=4.44.0",
@@ -56,7 +57,7 @@ dev = [
 ]
 
 [project.optional-dependencies]
-gpu = ["jax[cuda12]>=0.4.31"]
+gpu = ["jax[cuda12]>=0.4.31; sys_platform == 'linux'"]
 
 [tool.pytest.ini_options]
 testpaths = ["{{module_name}}", "tests"]
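The `sys_platform == 'linux'` additions in both pyproject files are standard PEP 508 environment markers: the `gpu` extra now resolves to nothing on macOS/Windows instead of failing on CUDA wheels that only exist for Linux. A quick way to check how such a marker evaluates on the current interpreter (sketch using the `packaging` library, which implements PEP 508):

```python
# Sketch: evaluating the environment marker added above.
from packaging.markers import Marker

marker = Marker("sys_platform == 'linux'")
print(marker.evaluate())  # True on Linux, False on macOS / Windows
```

uv propagates the marker onto each transitive CUDA dependency when re-locking, which is presumably why the Windows-only nvidia-* wheels drop out of uv.lock further down.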
diff --git a/tests/test_template.py b/tests/test_template.py
index f34c0e945..0d1925531 100644
--- a/tests/test_template.py
+++ b/tests/test_template.py
@@ -65,6 +65,13 @@ def test_templated_dependencies_are_same_as_in_project():
     assert sorted(project_toml["project"]["dependencies"]) == sorted(
         project_toml_template["project"]["dependencies"]
     )
+    optional_dependencies: dict[str, list[str]] = project_toml["project"]["optional-dependencies"]
+    optional_dependencies.pop("docs")  # not included in the pyproject.toml.jinja template.
+    optional_dependencies_template: dict[str, list[str]] = project_toml_template["project"][
+        "optional-dependencies"
+    ]
+    for group_name, group_dependencies in optional_dependencies.items():
+        assert sorted(group_dependencies) == sorted(optional_dependencies_template[group_name])
 
 
 all_template_versions: list[str] = subprocess.getoutput("git tag --sort=-creatordate").split("\n")
@@ -121,34 +128,78 @@ def copier_answers(
     )
 
 
-@pytest.fixture(scope=_project_fixture_scope)
+@pytest.fixture(scope="function")
+def project_root(tmp_path: Path):
+    tmp_project_dir = tmp_path / "dummy_project"
+    return tmp_project_dir
+
+
+@pytest.fixture(scope="function")
+def temporarily_set_git_config_for_commits(project_root: Path):
+    # On GitHub Actions, we need to set user.name and user.email for the `git commit` commands
+    # to succeed.
+    if not IN_GITHUB_CLOUD_CI:
+        yield
+        return
+    # note: doesn't actually change anything to cd to the project root since we modify the global
+    # git config (which sucks, but the project root is not yet a git repo at this point, so we
+    # can't set the local config).
+    # git = get_git().with_cwd(project_root)
+    # git_user_name_before = git("config", "--get", "user.name")
+    # git_user_email_before = git("config", "--get", "user.email")
+    git_user_name_before = subprocess.getoutput(
+        "git config --global --get user.name"
+    )
+    git_user_email_before = subprocess.getoutput(
+        "git config --global --get user.email"
+    )
+    try:
+        # git("config", "user.name", "your-name")
+        # git("config", "user.email", "your-email@email.com")
+        subprocess.check_call(("git", "config", "--global", "user.name", "your-name"))
+        subprocess.check_call(("git", "config", "--global", "user.email", "your-email@email.com"))
+        yield
+    finally:
+        # git("config", "user.name", git_user_name_before)
+        # git("config", "user.email", git_user_email_before)
+        if git_user_email_before:
+            subprocess.check_call(
+                ("git", "config", "--global", "user.email", git_user_email_before)
+            )
+        else:
+            subprocess.check_call(("git", "config", "--global", "--unset", "user.email"))
+        if git_user_name_before:
+            subprocess.check_call(("git", "config", "--global", "user.name", git_user_name_before))
+        else:
+            subprocess.check_call(("git", "config", "--global", "--unset", "user.name"))
+    return
+
+
+@pytest.fixture(scope="function")
 def project_from_template(
     template_version_used: str,
-    tmp_path_factory: pytest.TempPathFactory,
+    project_root: Path,
     copier_answers: CopierAnswers,
+    temporarily_set_git_config_for_commits: None,
 ):
     """Fixture that provides the project at a given version."""
-    tmp_project_dir = tmp_path_factory.mktemp(f"project_{template_version_used}_test")
     logger.info(
-        f"Setting up a project at {tmp_project_dir} using the template at version {template_version_used} with answers: {copier_answers}"
+        f"Setting up a project at {project_root} using the template at version "
+        f"{template_version_used} with answers: {copier_answers}"
     )
     with Worker(
         src_path="." if template_version_used == "HEAD" else "gh:mila-iqia/ResearchTemplate",
-        dst_path=tmp_project_dir,
+        dst_path=project_root,
         vcs_ref=template_version_used,
         defaults=True,
         data=dataclasses.asdict(copier_answers),
         unsafe=True,
    ) as worker:
         worker.run_copy()
+    assert worker.dst_path == project_root and project_root.exists()
+    yield project_root
 
-    yield worker.dst_path
-
-@pytest.mark.skipif(
-    IN_GITHUB_CLOUD_CI,
-    reason="TODO: lots of issues on GitHub CI (commit author, can't install other Python versions).",
-)
 @pytest.mark.skipif(sys.platform == "win32", reason="The template isn't supported on Windows.")
 @pytest.mark.parametrize(
     examples_to_include.__name__,
@@ -157,7 +208,7 @@ def project_from_template(
     ids=["none", *examples, "all"],
 )
 @pytest.mark.parametrize(
-    "python_version",
+    python_version.__name__,
     [
         # These can be very slow but are super important!
         # don't run these unless --slow argument is passed to pytest, to save some time.
@@ -239,7 +290,8 @@ def add_lit_autoencoder_module(
     )
     new_module = project_root / project_module / "algorithms" / "lit_autoencoder.py"
     new_module.write_text(
-        textwrap.dedent("""\
+        textwrap.dedent(
+            """\
     import os
     from torch import optim, nn, utils, Tensor
     from torchvision.datasets import MNIST
@@ -270,16 +322,19 @@ def training_step(self, batch, batch_idx):
 
        def configure_optimizers(self):
            optimizer = optim.Adam(self.parameters(), lr=1e-3)
            return optimizer
-    """)
+    """
+        )
     )
     new_algo_config = (
         project_root / project_module / "configs" / "algorithm" / "lit_autoencoder.yaml"
     )
     # Add a Hydra config file for the new module.
     new_algo_config.write_text(
-        textwrap.dedent(f"""\
+        textwrap.dedent(
+            f"""\
     _target_: {project_module}.algorithms.lit_autoencoder.LitAutoEncoder
-    """)
+    """
+        )
     )
 
     yield  # Yield (let the project update if needed)
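One detail worth noting in the `temporarily_set_git_config_for_commits` fixture above: `subprocess.getoutput()` takes a single shell command *string* (it runs with `shell=True`), whereas `check_call()` takes an argument sequence. The save/restore pattern, reduced to a sketch for a single key ("ci-bot" is a placeholder identity, not from the repo):

```python
# Sketch of the fixture's save/restore pattern for one git config key.
import subprocess

before = subprocess.getoutput("git config --global --get user.name")
subprocess.check_call(("git", "config", "--global", "user.name", "ci-bot"))
try:
    ...  # code that needs a configured git identity (e.g. `git commit`)
finally:
    if before:
        subprocess.check_call(("git", "config", "--global", "user.name", before))
    else:
        subprocess.check_call(("git", "config", "--global", "--unset", "user.name"))
```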
"https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -2702,7 +2700,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/8c/ba/aaab2fd2f50f09d6059b654dad092fa1fbe421e4cc85b1422818e6b5346a/nvidia_cuda_nvcc_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d2faca18a3d5dd48865ad259262f7da43358d0940d53554026102d70c14ea2f9", size = 19232443 }, { url = "https://files.pythonhosted.org/packages/13/d5/cbdb3a9ad5f34bc1892aa7b9d0f44f3b11bc2d7e73d6e02ae1db22fa0bee/nvidia_cuda_nvcc_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:3999aa4a42ac8723c09a8aafd06bc4a6ec1a0b05c53bc96c8d6cf195e84f6935", size = 21147787 }, - { url = "https://files.pythonhosted.org/packages/50/04/86aee13dbfaa2b1491bf83db0abc0afe2bfc1647be3b0be0eebf48e32558/nvidia_cuda_nvcc_cu12-12.6.68-py3-none-win_amd64.whl", hash = "sha256:9c0a18d76f0d1de99ba1d5fd70cffb32c0249e4abc42de9c0504e34d90ff421c", size = 16716194 }, ] [[package]] @@ -2719,7 +2716,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -2727,11 +2723,10 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -2740,7 +2735,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -2756,13 +2750,12 @@ name = "nvidia-cusolver-cu12" version = "11.4.5.107" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = 
"nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -2770,11 +2763,10 @@ name = "nvidia-cusparse-cu12" version = "12.1.0.106" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -2793,7 +2785,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/58/8c/69c9e39cd6bfa813852a94e9bd3c075045e2707d163e9dc2326c82d2c330/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b", size = 19253287 }, { url = "https://files.pythonhosted.org/packages/a8/48/a9775d377cb95585fb188b469387f58ba6738e268de22eae2ad4cedb2c41/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab", size = 19725597 }, - { url = "https://files.pythonhosted.org/packages/00/d5/02af3b39427ed71e8c40b6912271499ec186a72405bcb7e4ca26ff70678c/nvidia_nvjitlink_cu12-12.6.68-py3-none-win_amd64.whl", hash = "sha256:a55744c98d70317c5e23db14866a8cc2b733f7324509e941fc96276f9f37801d", size = 161730369 }, ] [[package]] @@ -2843,6 +2834,75 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a3/8b/7032a6788205e9da398a8a33e1030ee9a22bd9289126e5afed9aac33bcde/optax-0.2.3-py3-none-any.whl", hash = "sha256:083e603dcd731d7e74d99f71c12f77937dd53f79001b4c09c290e4f47dd2e94f", size = 289647 }, ] +[[package]] +name = "optree" +version = "0.15.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/10/37411ac8cf8cb07b9db9ddc3e36f84869fd1cabcee3d6af8d347c28744f2/optree-0.15.0.tar.gz", hash = "sha256:d00a45e3b192093ef2cd32bf0d541ecbfc93c1bd73a5f3fe36293499f28a50cf", size = 171403 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/07/eaae03b46385dd8b9987433d63352a9d2ebf00df3bda785501665eac9b79/optree-0.15.0-cp310-cp310-macosx_10_9_universal2.whl", 
hash = "sha256:6e73e390520a545ebcaa0b77fd77943a85d1952df658268129e6c523d4d38972", size = 609468 }, + { url = "https://files.pythonhosted.org/packages/aa/d5/34605464c764853aa45c969c0673975897732ff0a975ad3f3ba461af2e3f/optree-0.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c45593a818c67b72fd0beaeaa6410fa3c5debd39af500127fa367f8ee1f4bd8e", size = 329906 }, + { url = "https://files.pythonhosted.org/packages/f9/27/3d42845627a39f130c65c1a84a60b769557942da04e5eebf843bb24c30f4/optree-0.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4e440de109529ce919d0a0a4fa234d3b949da6f99630c9406c9f21160800831", size = 360986 }, + { url = "https://files.pythonhosted.org/packages/60/1f/285429c597e5c56c42732716400e399a0c4a45d97f11420f8bc0840560c0/optree-0.15.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7614ad2f7bde7b905c897011be573d89a9cb5cf851784ee8efb0020d8e067b27", size = 406114 }, + { url = "https://files.pythonhosted.org/packages/14/6d/33fc164efc396622f0cd0a9c9de67c14c8161cadc10f73b2187d7e5d786d/optree-0.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:655ab99f9f9570fbb124f81fdf7e480250b59b1f1d9bd07df04c8751eecc1450", size = 403859 }, + { url = "https://files.pythonhosted.org/packages/d0/e0/5e69e4988f9d98deaeaec7141f1e5cbef3c18300ae32e0a4d0037a1ea366/optree-0.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e63b965b62f461513983095750fd1331cad5674153bf3811bd7e2748044df4cd", size = 374122 }, + { url = "https://files.pythonhosted.org/packages/4b/f5/8b4ea051730c461e8957652ae58a895e5cc740b162adfe12f15d144d7c76/optree-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14e515b011d965bd3f7aeb021bb523265cb49fde47be0033ba5601e386fff90a", size = 397070 }, + { url = "https://files.pythonhosted.org/packages/e2/2f/70e2bbe99e8dbc2004ac4fc314946e3ab00ccb28a71daea6a4d67a15d8c4/optree-0.15.0-cp310-cp310-win32.whl", hash = "sha256:27031f507828c18606047e695129e9ec9678cd4321f57856da59c7fcc8f8666c", size = 268873 }, + { url = "https://files.pythonhosted.org/packages/ee/15/2db908ee685ef73340224abdc8d312c8d31836d808af0ae12e6a85e1b9ca/optree-0.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0392bebcd24fc70ca9a397c1eb2373727fa775e1007f27f3983c50f16a98e45", size = 297297 }, + { url = "https://files.pythonhosted.org/packages/cc/14/443f9dd63d158b5d01ba0e834d3aaa224d081d3459793e21764ef73619f8/optree-0.15.0-cp310-cp310-win_arm64.whl", hash = "sha256:c3122f73eca03e38712ceee16a6acf75d5244ba8b8d1adf5cd6613d1a60a6c26", size = 296510 }, + { url = "https://files.pythonhosted.org/packages/e8/89/1267444a074b6e4402b5399b73b930a7b86cde054a41cecb9694be726a92/optree-0.15.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c15d98e6f587badb9df67d67fa914fcfa0b63db2db270951915c563816c29f3d", size = 629067 }, + { url = "https://files.pythonhosted.org/packages/98/a5/f8d6c278ce72b2ed8c1ebac968c3c652832bd2d9e65ec81fe6a21082c313/optree-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f8d58949ef132beb3a025ace512a71a0fcf92e0e5ef350f289f33a782ae6cb85", size = 338192 }, + { url = "https://files.pythonhosted.org/packages/c8/88/3508c7ed217a37d35e2f30187dd2073c8130c09f80e47d3a0c9cadf42b14/optree-0.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f71d4759de0c4abc132dab69d1aa6ea4561ba748efabeee7b25db57c08652b79", size = 373749 }, + { url = 
"https://files.pythonhosted.org/packages/e1/05/f20f4ee6164e0e2b13e8cd588ba46f80fc0e8945a585d34f7250bcf7c31c/optree-0.15.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ba65d4c48d76bd5caac7f0b1b8db55223c1c3707d26f6d1d2ff18baf6f81850", size = 421870 }, + { url = "https://files.pythonhosted.org/packages/1f/e8/fcaeb5de1e53349dc45867706cd2a79c45500868e1bc010904511c4d7d58/optree-0.15.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aad3878acdb082701e5f77a153cd86af8819659bfa7e27debd0dc1a52f16c365", size = 418929 }, + { url = "https://files.pythonhosted.org/packages/b8/16/11d0954c39c3526e2d4198628abba80cb2858c6963e52aebd89c38fac93a/optree-0.15.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6676b8c3f4cd4c8d8d052b66767a9e4cf852627bf256da6e49d2c38a95f07712", size = 386787 }, + { url = "https://files.pythonhosted.org/packages/aa/3d/52a75740d6c449073d4bb54da382f6368553f285fb5a680b27dd198dd839/optree-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1f185b0d21bc4dda1f4fd03f5ba9e2bc9d28ca14bce3ce3d36b5817140a345e", size = 410434 }, + { url = "https://files.pythonhosted.org/packages/ef/0f/fe199a99fbb3c3f33c1a98328d4bf9a1b42a72df3aa2de0c9046873075b7/optree-0.15.0-cp311-cp311-win32.whl", hash = "sha256:927b579a76c13b9328580c09dd4a9947646531f0a371a170a785002c50dedb94", size = 274360 }, + { url = "https://files.pythonhosted.org/packages/b0/86/9743be6eac8cc5ef69fa2b6585a36254aca0815714f57a0763bcfa774906/optree-0.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:d6525d6a550a1030957e5205e57a415d608a9f7561154e0fb29670e967424578", size = 306609 }, + { url = "https://files.pythonhosted.org/packages/87/41/4b4130f921c886ec83ba53809da83bce13039653be37c56ea6b9b7c4bd2b/optree-0.15.0-cp311-cp311-win_arm64.whl", hash = "sha256:081e8bed7583b625819659d68288365bd4348b3c4281935a6ecfa53c93619b13", size = 304507 }, + { url = "https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ba2eee9de9d57e145b4c1a71749f7f8b8fe1c645abbb306d4a26cfa45a9cdbb5", size = 639476 }, + { url = "https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4aad5023686cd7caad68d70ad3706b82cfe9ae8ff9a13c08c1edef2a9b4c9d72", size = 342569 }, + { url = "https://files.pythonhosted.org/packages/0c/e2/d2ee348f26cbfba68d51f971112db1e4560f60db95598c88ce63d4c99204/optree-0.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9810e84466025da55ce19ac6b2b79a5cb2c0c1349d318a17504f6e44528221f8", size = 369181 }, + { url = "https://files.pythonhosted.org/packages/ec/f8/dafdc4bc40d699b0a8b3ae9cdd5c33984a9cbf1f964151f9608e24e409d9/optree-0.15.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:20b07d8a097b810d68b0ee35f287c1f0b7c9844133ada613a92cc10bade9cdbe", size = 416191 }, + { url = "https://files.pythonhosted.org/packages/81/aa/c2027d6314fa62787639ff2bba42884fd80a65f53b4a3bdc4fd87700649d/optree-0.15.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0304ec416258edebe2cd2a1ef71770e43405d5e7366ecbc134c520b4ab44d155", size = 413041 }, + { url = "https://files.pythonhosted.org/packages/ac/17/01050e00e74291925796309b38dfbbffc03e2893dc929dd9e48d361456ee/optree-0.15.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:759a72e6dcca3e7239d202a253e1e8e44e8df5033a5e178df585778ac85ddd13", size = 381367 }, + { url = "https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01a0dc75c594c884d0ca502b8d169cec538e19a70883d2e5f5b9b08fce740958", size = 404769 }, + { url = "https://files.pythonhosted.org/packages/60/e6/abc48777d38c0ab429b84c91fabfa76c64991bc98ef10538d6fc6d8e88f4/optree-0.15.0-cp312-cp312-win32.whl", hash = "sha256:7e10e5c2a8110f5f4fbc999ff8580d1db3a915f851f63f602fff3bbd250ffa20", size = 275492 }, + { url = "https://files.pythonhosted.org/packages/25/33/cd41ab38ef313874eb2000f1037ccce001dd680873713cc2d1a2ae5d0041/optree-0.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:def5b08f219c31edd029b47624e689ffa07747b0694222156f28a28d341d29ac", size = 307368 }, + { url = "https://files.pythonhosted.org/packages/df/04/9ed5f07d78b303c4bcadcf3ec763358ea64472d1a005dc1249278df66600/optree-0.15.0-cp312-cp312-win_arm64.whl", hash = "sha256:8ec6d3040b1cbfe3f0bc045a3302ee9f9e329c2cd96e928360d22e1cfd9d973a", size = 300733 }, + { url = "https://files.pythonhosted.org/packages/70/e3/99135565340ac34857b6edbb3df6e15eb35aa7160f900f12d83f71d38eeb/optree-0.15.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4ab606720ae319cb43da47c71d7d5fa7cfbb6a02e6da4857331e6f93800c970e", size = 647666 }, + { url = "https://files.pythonhosted.org/packages/93/fc/c9c04494d2ab54f98f8d8c18cb5095a81ad23fb64105cb05e926bb3d1a0c/optree-0.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9cfc5771115f85b0bfa8f72cce1599186fd6a0ea71c8154d8b2751d9170be428", size = 346205 }, + { url = "https://files.pythonhosted.org/packages/b8/c8/25484ec6784435d63e64e87a7dca32760eed4ca60ddd4be1ca14ed421e08/optree-0.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f958a20a311854aaab8bdd0f124aab5b9848f07976b54da3e95526a491aa860", size = 372427 }, + { url = "https://files.pythonhosted.org/packages/a3/27/615a2987137fd8add27e62217527e49f7fd2ec391bbec354dbc59a0cd0af/optree-0.15.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:47ce7e9d81eaed5a05004df1fa279d2608e063dd5eb236e9c95803b4fa0a286c", size = 421611 }, + { url = "https://files.pythonhosted.org/packages/98/e7/0106ee7ebec2e4cfa09f0a3b857e03c80bc80521227f4d64c289ac9601e8/optree-0.15.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c6d6ab3717d48e0e747d9e348e23be1fa0f8a812f73632face6303c438d259ba", size = 415804 }, + { url = "https://files.pythonhosted.org/packages/47/0e/a40ccedb0bac4a894b2b0d17d915b7e8cd3fdc44a5104026cec4102a53fb/optree-0.15.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c7d101a15be39a9c7c4afae9f0bb85f682eb7d719117e2f9e5fb39c9f6f2c92", size = 386208 }, + { url = "https://files.pythonhosted.org/packages/c0/f6/1d98163b283048d80cb0c8e67d086a1e0ecabe35004d7180d0f28303f611/optree-0.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aae337ab30b45a096eb5b4ffc3ad8909731617543a7eb288e0b297b9d10a241f", size = 409276 }, + { url = "https://files.pythonhosted.org/packages/e4/65/4275cd763bafb22c41443b7c7bfd5986b9c97f984537669c356e2c7008a1/optree-0.15.0-cp313-cp313-win32.whl", hash = "sha256:eb9c51d728485f5908111191b5403a3f9bc310d121a981f29fad45750b9ff89c", size = 277924 }, + { url = 
"https://files.pythonhosted.org/packages/b2/fe/3d9f5463609ac6cb5857d0040739ae7709e1fbc134c58f0dd7cb3d2d93d5/optree-0.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:7f00e6f011f021ae470efe070ec4d2339fb1a8cd0dcdd16fe3dab782a47aba45", size = 309617 }, + { url = "https://files.pythonhosted.org/packages/7d/bb/b7a443be665a91307cd78196018bc45bb0c6ab04fe265f009698311c7bc7/optree-0.15.0-cp313-cp313-win_arm64.whl", hash = "sha256:17990fbc7f4c461de7ae546fc5661f6a248c3dcee966c89c2e2e5ad7f6228bae", size = 303390 }, + { url = "https://files.pythonhosted.org/packages/30/8d/c6977786cc40626f24068d9bd6acbff377b9bed0cd64b7faff1bc3f2d60c/optree-0.15.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b31c88af70e3f5c14ff2aacd38c4076e6cde98f75169fe0bb59543f01bfb9719", size = 747468 }, + { url = "https://files.pythonhosted.org/packages/5c/3f/e5b50c7af3712a97d0341744cd8ebd8ab19cd38561e26caa7a8a10fdf702/optree-0.15.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:bc440f81f738d9c822030c3b4f53b6dec9ceb52410f02fd06b9338dc25a8447f", size = 392638 }, + { url = "https://files.pythonhosted.org/packages/54/3b/1aab1a9dba5c8eb85810ba2683b42b057b1ce1de9bb7ca08c5d35448979e/optree-0.15.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76ffc2dd8c754e95495163dde55b38dc37e6712b6a3bc7f2190b0547a2c403bb", size = 390030 }, + { url = "https://files.pythonhosted.org/packages/fc/95/195be32055ce48cdf8beb31eda74f4643c6a443f1367346c6a8f6606c9eb/optree-0.15.0-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9fa9fb0197cd7b5f2b1fa7e05d30946b3b79bcfc3608fe54dbfc67969895cac9", size = 437350 }, + { url = "https://files.pythonhosted.org/packages/02/01/9eefcdd9112aebc52db668a87f1df2f45c784e34e5a35182d3dad731fc92/optree-0.15.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6828639b01ba1177c04875dd9529d938d7b28122c97e7ae14ec41c68ec22826c", size = 434500 }, + { url = "https://files.pythonhosted.org/packages/6d/b3/0749c3a752b4921d54aa88eb1246315a008b5199c6d0e6127bd672a21c87/optree-0.15.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:93c74eed0f52818c30212dba4867f5672e498567bad49dcdffbe8db6703a0d65", size = 404306 }, + { url = "https://files.pythonhosted.org/packages/04/91/1538085ace361f99280a26f9ccdc6023b22e7d0e23fdb849e2ef104fb54f/optree-0.15.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:12188f6832c29dac37385a2f42fce961e303349909cff6d40e21cb27a8d09023", size = 423468 }, + { url = "https://files.pythonhosted.org/packages/c8/ec/c4d2203a6f1206790ea8b794df943a16918649b23cd03ec6beb7f75a81df/optree-0.15.0-cp313-cp313t-win32.whl", hash = "sha256:d7b8ce7d13580985922dcfbda515da3f004cd7cb1b03320b96ea32d8cfd76392", size = 308837 }, + { url = "https://files.pythonhosted.org/packages/66/db/567ea8d2874eb15e9a8a82598ab84a6b809ca16df083b7650deffb8a6171/optree-0.15.0-cp313-cp313t-win_amd64.whl", hash = "sha256:daccdb583abaab14346f0af316ee570152a5c058e7b9fb09d8f8171fe751f2b3", size = 344774 }, + { url = "https://files.pythonhosted.org/packages/ae/67/26804ab0222febadab6e27af4f4e607fad2d9520bb32422842509b79cf63/optree-0.15.0-cp313-cp313t-win_arm64.whl", hash = "sha256:e0162a36a6cedb0829efe980d0b370d4e5970fdb28a6609daa2c906d547add5f", size = 337502 }, + { url = "https://files.pythonhosted.org/packages/69/2c/dddb2166b43bbbb7a83d17a0585ec6df2f5312cec9a81816ad26f9b779f8/optree-0.15.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b30673fe30d4d77eef18534420491c27837f0b55dfe18107cfd9eca39a62de3b", size = 
+    { url = "https://files.pythonhosted.org/packages/69/2c/dddb2166b43bbbb7a83d17a0585ec6df2f5312cec9a81816ad26f9b779f8/optree-0.15.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b30673fe30d4d77eef18534420491c27837f0b55dfe18107cfd9eca39a62de3b", size = 336150 },
+    { url = "https://files.pythonhosted.org/packages/92/72/f5b995baf6d38bc6211ef806e94fc1dff4acc49de809fdf7e139603d78f5/optree-0.15.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0f378d08b8a09f7e495c49cd94141c1acebc2aa7d567d7dd2cb44a707f29268", size = 364766 },
+    { url = "https://files.pythonhosted.org/packages/2b/e4/9d6142910afa81a8ad3bcf67d42953c843ae1f5118c522bfc564bfa89ffa/optree-0.15.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90dae741d683cbc47cba16a1b4af3c0d5d8c1042efb7c4aec7664a4f3f07eca2", size = 400303 },
+    { url = "https://files.pythonhosted.org/packages/7f/64/4791bfd4bef08f159d285c1f3fc8d8cc5dc9bdb6c50290d4c382fee7605c/optree-0.15.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cf790dd21dcaa0857888c03233276f5513821abfe605964e825837a30a24f0d7", size = 299596 },
+    { url = "https://files.pythonhosted.org/packages/e7/df/b490f1bdd65b0d798ac9b46c145cefaf7724c6e8b90fb0371cd3f7fb4c60/optree-0.15.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:21afadec56475f2a13670b8ecf7b767af4feb3ba5bd3a246cbbd8c1822e2a664", size = 346694 },
+    { url = "https://files.pythonhosted.org/packages/cd/6c/4ea0616ef4e3a3fff5490ebdc8083fa8f3a3a3105ab93222e4fd97fc4f2a/optree-0.15.0-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a39bccc63223e040f36eb8b413fa1f94a190289eb82e7b384ed32d95d1ffd67", size = 377442 },
+    { url = "https://files.pythonhosted.org/packages/99/5d/8cf3c44adf9e20ea267bd218ee829e715c1987b877b7bce74b37986997e4/optree-0.15.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06aed485ab9c94f5b45a18f956bcb89bf6bad29632421da69da268cb49adb37b", size = 413200 },
+    { url = "https://files.pythonhosted.org/packages/6f/3c/1cc96fb1faaae6f5e0d34556a68aac18be5b0d8ac675b7dd236e4c82c6b2/optree-0.15.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:07e9d75867ca39cce98375249b83a2033b0313cbfa32cbd06f93f7bc15104afc", size = 308873 },
+]
+
 [[package]]
 name = "orbax-checkpoint"
 version = "0.6.4"
@@ -3924,6 +3984,7 @@ dependencies = [
     { name = "jaxlib" },
     { name = "lightning" },
     { name = "matplotlib" },
+    { name = "optree" },
     { name = "rejax" },
     { name = "remote-slurm-executor" },
     { name = "rich" },
@@ -3951,7 +4012,7 @@ docs = [
     { name = "mkdocstrings", extra = ["python"] },
 ]
 gpu = [
-    { name = "jax", extra = ["cuda12"] },
+    { name = "jax", extra = ["cuda12"], marker = "sys_platform == 'linux'" },
 ]
 
 [package.dev-dependencies]
@@ -3986,7 +4047,7 @@ requires-dist = [
     { name = "hydra-submitit-launcher", specifier = ">=1.2.0" },
     { name = "hydra-zen", specifier = "==0.13.1rc1" },
     { name = "jax", specifier = "==0.4.33" },
-    { name = "jax", extras = ["cuda12"], marker = "extra == 'gpu'", specifier = ">=0.4.31" },
+    { name = "jax", extras = ["cuda12"], marker = "sys_platform == 'linux' and extra == 'gpu'", specifier = ">=0.4.31" },
     { name = "jaxlib", specifier = "==0.4.33" },
     { name = "lightning", specifier = ">=2.4.0" },
     { name = "matplotlib", specifier = ">=3.9.2" },
@@ -4000,6 +4061,7 @@ requires-dist = [
     { name = "mkdocs-section-index", marker = "extra == 'docs'", specifier = ">=0.3.9" },
     { name = "mkdocs-video", marker = "extra == 'docs'", specifier = ">=1.5.0" },
     { name = "mkdocstrings", extras = ["python"], marker = "extra == 'docs'", specifier = ">=0.26.2" },
+    { name = "optree", specifier = ">=0.15.0" },
     { name = "rejax", specifier = ">=0.1.0" },
"https://github.com/lebrice/remote-slurm-executor?branch=master" }, { name = "rich", specifier = ">=13.7.1" },