Skip to content

Commit e190484

Browse files
authored
Enable testing the templating in GitHub CI (#129)
* Enable testing the templating in GitHub CI Signed-off-by: Fabrice Normandin <[email protected]> * Don't try to install jax[cuda12] on MacOS Signed-off-by: Fabrice Normandin <[email protected]> * Replace rm --verbose with rm -v for MacOS (?) Signed-off-by: Fabrice Normandin <[email protected]> * Add data folder to .gitignore Signed-off-by: Fabrice Normandin <[email protected]> * Temporarily set git config in new project in CI Signed-off-by: Fabrice Normandin <[email protected]> * Temporarily set global git config in GitHub CI Signed-off-by: Fabrice Normandin <[email protected]> * Fix error when user.name or user.email is not set Signed-off-by: Fabrice Normandin <[email protected]> * Fix fixture (missing 'git' command) Signed-off-by: Fabrice Normandin <[email protected]> * Add the --no-sync arg to uv remove in copier steps Signed-off-by: Fabrice Normandin <[email protected]> * Fix fixture scope issues causing test conflicts Signed-off-by: Fabrice Normandin <[email protected]> * Minor touchups Signed-off-by: Fabrice Normandin <[email protected]> * Simplify project_root fixture in test_template.py Signed-off-by: Fabrice Normandin <[email protected]> * Fix issue with jax[cuda12] on MacOS Signed-off-by: Fabrice Normandin <[email protected]> * Use optree when jax isn't selected, fix test Signed-off-by: Fabrice Normandin <[email protected]> * Remove other unnecessary reference to jax Signed-off-by: Fabrice Normandin <[email protected]> * Remove jax_typing_utils.py (use xtils.jitpp) Signed-off-by: Fabrice Normandin <[email protected]> * Fix tiny issues with jax ppo example Signed-off-by: Fabrice Normandin <[email protected]> * Fix `jax.jit` uses on methods Signed-off-by: Fabrice Normandin <[email protected]> * Don't fail if no changes to add Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]>
1 parent d804a93 commit e190484

File tree

11 files changed

+280
-264
lines changed

11 files changed

+280
-264
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,4 @@ lightning_logs
171171
.schemas
172172
tests/.regression_files/**.png
173173
tests/.regression_files/**/*.gif
174+
data

copier.yaml

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -175,73 +175,73 @@ _tasks:
175175
176176
# Remove unwanted examples:
177177
- command: |
178-
rm --verbose {{module_name}}/algorithms/image_classifier*.py
179-
rm --verbose {{module_name}}/configs/algorithm/image_classifier.yaml
180-
rm --verbose {{module_name}}/configs/experiment/example.yaml
181-
rm --verbose {{module_name}}/configs/experiment/cluster_sweep_example.yaml
182-
rm --verbose {{module_name}}/configs/experiment/local_sweep_example.yaml
183-
rm --verbose {{module_name}}/configs/experiment/profiling.yaml
178+
rm -v {{module_name}}/algorithms/image_classifier*.py
179+
rm -v {{module_name}}/configs/algorithm/image_classifier.yaml
180+
rm -v {{module_name}}/configs/experiment/example.yaml
181+
rm -v {{module_name}}/configs/experiment/cluster_sweep_example.yaml
182+
rm -v {{module_name}}/configs/experiment/local_sweep_example.yaml
183+
rm -v {{module_name}}/configs/experiment/profiling.yaml
184184
git add {{module_name}}
185185
git commit -m "Remove image classification example."
186186
when: "{{ run_setup and 'image_classifier' not in examples_to_include }}"
187187
188188
- command: |
189-
rm --verbose {{module_name}}/algorithms/jax_image_classifier*.py
190-
rm --verbose {{module_name}}/configs/algorithm/jax_image_classifier.yaml
191-
rm --verbose {{module_name}}/configs/algorithm/network/jax_cnn.yaml
192-
rm --verbose {{module_name}}/configs/algorithm/network/jax_fcnet.yaml
189+
rm -v {{module_name}}/algorithms/jax_image_classifier*.py
190+
rm -v {{module_name}}/configs/algorithm/jax_image_classifier.yaml
191+
rm -v {{module_name}}/configs/algorithm/network/jax_cnn.yaml
192+
rm -v {{module_name}}/configs/algorithm/network/jax_fcnet.yaml
193193
git add {{module_name}}
194194
git commit -m "Remove jax image classification example."
195195
when: "{{ run_setup and 'jax_image_classifier' not in examples_to_include }}"
196196
197197
# Remove unwanted image classification datamodules and configs
198198
- command: |
199-
rm --verbose {{module_name}}/datamodules/image_classification/image_classification*.py
200-
rm --verbose {{module_name}}/datamodules/image_classification/mnist*.py
201-
rm --verbose {{module_name}}/datamodules/image_classification/fashion_mnist*.py
202-
rm --verbose {{module_name}}/datamodules/image_classification/cifar10*.py
203-
rm --verbose {{module_name}}/datamodules/image_classification/imagenet*.py
204-
rm --verbose {{module_name}}/datamodules/image_classification/inaturalist*.py
205-
rm --verbose {{module_name}}/datamodules/image_classification/__init__.py
206-
rm --verbose {{module_name}}/datamodules/vision*.py
207-
rm --verbose {{module_name}}/configs/datamodule/mnist.yaml
208-
rm --verbose {{module_name}}/configs/datamodule/fashion_mnist.yaml
209-
rm --verbose {{module_name}}/configs/datamodule/cifar10.yaml
210-
rm --verbose {{module_name}}/configs/datamodule/imagenet.yaml
211-
rm --verbose {{module_name}}/configs/datamodule/inaturalist.yaml
212-
rm --verbose {{module_name}}/configs/datamodule/vision.yaml
199+
rm -v {{module_name}}/datamodules/image_classification/image_classification*.py
200+
rm -v {{module_name}}/datamodules/image_classification/mnist*.py
201+
rm -v {{module_name}}/datamodules/image_classification/fashion_mnist*.py
202+
rm -v {{module_name}}/datamodules/image_classification/cifar10*.py
203+
rm -v {{module_name}}/datamodules/image_classification/imagenet*.py
204+
rm -v {{module_name}}/datamodules/image_classification/inaturalist*.py
205+
rm -v {{module_name}}/datamodules/image_classification/__init__.py
206+
rm -v {{module_name}}/datamodules/vision*.py
207+
rm -v {{module_name}}/configs/datamodule/mnist.yaml
208+
rm -v {{module_name}}/configs/datamodule/fashion_mnist.yaml
209+
rm -v {{module_name}}/configs/datamodule/cifar10.yaml
210+
rm -v {{module_name}}/configs/datamodule/imagenet.yaml
211+
rm -v {{module_name}}/configs/datamodule/inaturalist.yaml
212+
rm -v {{module_name}}/configs/datamodule/vision.yaml
213213
rmdir {{module_name}}/datamodules/image_classification
214214
git add {{module_name}}
215215
git commit -m "Remove image classification datamodules and configs."
216216
when: "{{ run_setup and 'image_classifier' not in examples_to_include and 'jax_image_classifier' not in examples_to_include }}"
217217
218218
- command: |
219-
rm --verbose {{module_name}}/algorithms/text_classifier*.py
220-
rm --verbose {{module_name}}/configs/algorithm/text_classifier.yaml
221-
rm --verbose {{module_name}}/configs/experiment/text_classification_example.yaml
222-
rm --verbose {{module_name}}/configs/datamodule/glue_cola.yaml
223-
rm --verbose {{module_name}}/datamodules/text/text_classification*.py
224-
rm --verbose {{module_name}}/datamodules/text/__init__.py
219+
rm -v {{module_name}}/algorithms/text_classifier*.py
220+
rm -v {{module_name}}/configs/algorithm/text_classifier.yaml
221+
rm -v {{module_name}}/configs/experiment/text_classification_example.yaml
222+
rm -v {{module_name}}/configs/datamodule/glue_cola.yaml
223+
rm -v {{module_name}}/datamodules/text/text_classification*.py
224+
rm -v {{module_name}}/datamodules/text/__init__.py
225225
rmdir {{module_name}}/datamodules/text
226226
git add {{module_name}}
227227
git commit -m "Remove text classification example."
228228
when: "{{ run_setup and 'text_classifier' not in examples_to_include }}"
229229
230230
# todo: remove JaxTrainer and project/trainers folder if the JaxPPO example is removed?
231231
- command: |
232-
rm --verbose {{module_name}}/algorithms/jax_ppo*.py
233-
rm --verbose {{module_name}}/trainers/jax_trainer*.py
232+
rm -v {{module_name}}/algorithms/jax_ppo*.py
233+
rm -v {{module_name}}/trainers/jax_trainer*.py
234234
rmdir {{module_name}}/trainers
235-
rm --verbose {{module_name}}/configs/algorithm/jax_ppo.yaml
236-
rm --verbose {{module_name}}/configs/experiment/jax_rl_example.yaml
235+
rm -v {{module_name}}/configs/algorithm/jax_ppo.yaml
236+
rm -v {{module_name}}/configs/experiment/jax_rl_example.yaml
237237
git add {{module_name}}
238238
git commit -m "Remove Jax PPO example and lightning Trainer."
239239
when: "{{ run_setup and 'jax_ppo' not in examples_to_include }}"
240240
241241
- command: |
242-
rm --verbose {{module_name}}/algorithms/llm_finetuning*.py
243-
rm --verbose {{module_name}}/configs/algorithm/llm_finetuning.yaml
244-
rm --verbose {{module_name}}/configs/experiment/llm_finetuning_example.yaml
242+
rm -v {{module_name}}/algorithms/llm_finetuning*.py
243+
rm -v {{module_name}}/configs/algorithm/llm_finetuning.yaml
244+
rm -v {{module_name}}/configs/experiment/llm_finetuning_example.yaml
245245
git add {{module_name}}
246246
git commit -m "Remove LLM fine-tuning example."
247247
when: "{{ run_setup and 'llm_finetuning' not in examples_to_include }}"
@@ -253,15 +253,15 @@ _tasks:
253253
# Remove unneeded dependencies:
254254

255255
## Jax-related dependencies:
256-
- command: uv remove rejax gymnax gymnasium xtils
256+
- command: uv remove rejax gymnax gymnasium xtils --no-sync
257257
when: "{{ run_setup and 'jax_ppo' not in examples_to_include }}"
258-
- command: uv remove jax jaxlib torch-jax-interop
258+
- command: uv remove jax jaxlib torch-jax-interop --no-sync
259259
when: "{{ run_setup and 'jax_ppo' not in examples_to_include and 'jax_image_classifier' not in examples_to_include }}"
260260

261261
## Huggingface-related dependencies:
262-
- command: uv remove evaluate
262+
- command: uv remove evaluate --no-sync
263263
when: "{{run_setup and 'text_classifier' not in examples_to_include }}"
264-
- command: uv remove transformers datasets
264+
- command: uv remove transformers datasets --no-sync
265265
when: "{{ run_setup and 'text_classifier' not in examples_to_include and 'llm_finetuning' not in examples_to_include }}"
266266
- command: |
267267
git add .python-version pyproject.toml uv.lock
@@ -306,7 +306,7 @@ _tasks:
306306
# todo: Causes issues on GitHub CI (asking for user)
307307
- command: |
308308
git add .
309-
git commit -m "Clean project from template"
309+
git commit -m "Clean project from template" || true # don't fail if there are no changes
310310
when: "{{run_setup}}"
311311
# - command: "git commit -m 'Initial commit'"
312312
- "uvx pre-commit install"

project/algorithms/callbacks/samples_per_second.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import time
22
from typing import Any, Generic, Literal
33

4-
import jax
54
import lightning
5+
import optree
66
import torch
77
from lightning import LightningModule, Trainer
88
from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -182,12 +182,14 @@ def log(
182182
)
183183

184184
def get_num_samples(self, batch: BatchType) -> int:
185+
if isinstance(batch, Tensor):
186+
return batch.shape[0]
185187
if is_sequence_of(batch, Tensor):
186188
return batch[0].shape[0]
187189
if isinstance(batch, dict):
188190
return next(
189191
v.shape[0]
190-
for v in jax.tree.leaves(batch)
192+
for v in optree.tree_leaves(batch) # type: ignore
191193
if isinstance(v, torch.Tensor) and v.ndim > 1
192194
)
193195
raise NotImplementedError(

project/algorithms/jax_ppo.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from __future__ import annotations
88

99
import contextlib
10+
import dataclasses
1011
import functools
1112
import operator
12-
from collections.abc import Callable, Sequence
13+
from collections.abc import Callable, Mapping, Sequence
1314
from logging import getLogger as get_logger
1415
from pathlib import Path
1516
from typing import Any, Generic, TypedDict
@@ -21,25 +22,21 @@
2122
import gymnax.environments.spaces
2223
import jax
2324
import jax.numpy as jnp
24-
import lightning
25-
import lightning.pytorch
26-
import lightning.pytorch.loggers
27-
import lightning.pytorch.loggers.wandb
2825
import numpy as np
2926
import optax
3027
from flax.training.train_state import TrainState
3128
from flax.typing import FrozenVariableDict
3229
from gymnax.environments.environment import Environment
3330
from gymnax.visualize.visualizer import Visualizer
31+
from lightning.pytorch.loggers.wandb import WandbLogger
3432
from matplotlib import pyplot as plt
3533
from rejax.algos.mixins import RMSState
3634
from rejax.evaluate import evaluate
3735
from rejax.networks import DiscretePolicy, GaussianPolicy, VNetwork
3836
from typing_extensions import TypeVar
39-
from xtils.jitpp import Static
37+
from xtils.jitpp import Static, jit
4038

4139
from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer, get_error_from_metrics
42-
from project.utils.typing_utils.jax_typing_utils import field, jit
4340

4441
logger = get_logger(__name__)
4542

@@ -96,6 +93,48 @@ class PPOState(Generic[TEnvState], flax.struct.PyTreeNode):
9693
data_collection_state: TrajectoryCollectionState[TEnvState]
9794

9895

96+
T = TypeVar("T")
97+
98+
99+
def field(
100+
*,
101+
default: T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
102+
default_factory: Callable[[], T] | dataclasses._MISSING_TYPE = dataclasses.MISSING,
103+
init=True,
104+
repr=True,
105+
hash=None,
106+
compare=True,
107+
metadata: Mapping[Any, Any] | None = None,
108+
kw_only=dataclasses.MISSING,
109+
pytree_node: bool | None = None,
110+
) -> T:
111+
"""Small Typing fix for `flax.struct.field`.
112+
113+
- Add type annotations so it doesn't drop the signature of the `dataclasses.field` function.
114+
- Make the `pytree_node` has a default value of `False` for ints and bools, and `True` for
115+
everything else.
116+
"""
117+
if pytree_node is None and isinstance(default, int): # note: also includes `bool`.
118+
pytree_node = False
119+
if pytree_node is None:
120+
pytree_node = True
121+
if metadata is None:
122+
metadata = {}
123+
else:
124+
metadata = dict(metadata)
125+
metadata.setdefault("pytree_node", pytree_node)
126+
return dataclasses.field(
127+
default=default,
128+
default_factory=default_factory, # type: ignore
129+
init=init,
130+
repr=repr,
131+
hash=hash,
132+
compare=compare,
133+
metadata=metadata,
134+
kw_only=kw_only,
135+
) # type: ignore
136+
137+
99138
class PPOHParams(flax.struct.PyTreeNode):
100139
"""Hyper-parameters for this PPO example.
101140
@@ -512,20 +551,23 @@ def visualize(self, ts: PPOState, gif_path: str | Path, eval_rng: chex.PRNGKey |
512551
render_episode(
513552
actor=actor,
514553
env=self.env,
515-
env_params=jax.tree.map(lambda v: v.item() if v.ndim == 0 else v, self.env_params),
554+
env_params=jax.tree.map(
555+
lambda v: v if isinstance(v, int | float) else v.item() if v.ndim == 0 else v,
556+
self.env_params,
557+
),
516558
gif_path=Path(gif_path),
517559
rng=eval_rng if eval_rng is not None else ts.rng,
518560
)
519561

520562
## These here aren't currently used. They are here to mirror rejax.PPO where the training loop
521563
# is in the algorithm.
522564

523-
@functools.partial(jit, static_argnames=["skip_initial_evaluation"])
565+
@functools.partial(jax.jit, static_argnames="skip_initial_evaluation")
524566
def train(
525567
self,
526568
rng: jax.Array,
527569
train_state: PPOState[TEnvState] | None = None,
528-
skip_initial_evaluation: bool = False,
570+
skip_initial_evaluation: Static[bool] = False,
529571
) -> tuple[PPOState[TEnvState], EvalMetrics]:
530572
"""Full training loop in jax.
531573
@@ -624,9 +666,9 @@ def _normalize_obs(rms_state: RMSState, obs: jax.Array):
624666
return (obs - rms_state.mean) / jnp.sqrt(rms_state.var + 1e-8)
625667

626668

627-
@functools.partial(jit, static_argnames=["num_minibatches"])
669+
@jit
628670
def shuffle_and_split(
629-
data: AdvantageMinibatch, rng: chex.PRNGKey, num_minibatches: int
671+
data: AdvantageMinibatch, rng: chex.PRNGKey, num_minibatches: Static[int]
630672
) -> AdvantageMinibatch:
631673
assert data.trajectories.obs.shape
632674
iteration_size = data.trajectories.obs.shape[0] * data.trajectories.obs.shape[1]
@@ -639,7 +681,7 @@ def shuffle_and_split(
639681
return jax.tree.map(_shuffle_and_split_fn, data)
640682

641683

642-
@functools.partial(jit, static_argnames=["num_minibatches"])
684+
@jit
643685
def _shuffle_and_split(x: jax.Array, permutation: jax.Array, num_minibatches: Static[int]):
644686
x = x.reshape((x.shape[0] * x.shape[1], *x.shape[2:]))
645687
x = jnp.take(x, permutation, axis=0)
@@ -683,7 +725,7 @@ def get_advantages(
683725
return (advantage, transition_data.value), advantage
684726

685727

686-
@functools.partial(jit, static_argnames=["actor"])
728+
@jit
687729
def actor_loss_fn(
688730
params: FrozenVariableDict,
689731
actor: Static[flax.linen.Module],
@@ -710,7 +752,7 @@ def actor_loss_fn(
710752
return pi_loss - ent_coef * entropy
711753

712754

713-
@functools.partial(jit, static_argnames=["critic"])
755+
@jit
714756
def critic_loss_fn(
715757
params: FrozenVariableDict,
716758
critic: Static[flax.linen.Module],
@@ -829,6 +871,6 @@ def on_fit_end(self, trainer: JaxTrainer, module: JaxRLExample, ts: PPOState):
829871

830872
def log_image(self, gif_path: Path, trainer: JaxTrainer, step: int):
831873
for logger in trainer.loggers:
832-
if isinstance(logger, lightning.pytorch.loggers.wandb.WandbLogger):
874+
if isinstance(logger, WandbLogger):
833875
logger.log_image("render_episode", [str(gif_path)], step=step)
834876
return

project/conftest.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@
7373
from typing import Any, Literal, TypeVar
7474

7575
import hydra.errors
76-
import jax
7776
import lightning
7877
import lightning.pytorch
7978
import lightning.pytorch as pl
8079
import lightning.pytorch.utilities
80+
import optree
8181
import pytest
8282
import tensor_regression.stats
8383
import torch
@@ -377,9 +377,7 @@ def train_dataloader(
377377

378378
# todo: Remove (unused).
379379
@pytest.fixture(scope="session")
380-
def training_batch(
381-
train_dataloader: DataLoader, device: torch.device
382-
) -> tuple[Tensor, ...] | dict[str, Tensor]:
380+
def training_batch(train_dataloader: DataLoader, device: torch.device) -> optree.PyTree[Tensor]:
383381
# Get a batch of data from the dataloader.
384382

385383
# The batch of data will always be the same because the dataloaders are passed a Generator
@@ -391,8 +389,7 @@ def training_batch(
391389
# TODO: This ugliness is because torchvision transforms use the global pytorch RNG!
392390
torch.random.manual_seed(42)
393391
batch = next(dataloader_iterator)
394-
395-
return jax.tree.map(operator.methodcaller("to", device=device), batch)
392+
return optree.tree_map(operator.methodcaller("to", device=device), batch)
396393

397394

398395
@pytest.fixture(autouse=True, scope="function")

0 commit comments

Comments
 (0)