1 change: 1 addition & 0 deletions .gitignore
@@ -171,3 +171,4 @@ lightning_logs
 .schemas
 tests/.regression_files/**.png
 tests/.regression_files/**/*.gif
+data
84 changes: 42 additions & 42 deletions copier.yaml
@@ -175,73 +175,73 @@ _tasks:
 
   # Remove unwanted examples:
   - command: |
-      rm --verbose {{module_name}}/algorithms/image_classifier*.py
-      rm --verbose {{module_name}}/configs/algorithm/image_classifier.yaml
-      rm --verbose {{module_name}}/configs/experiment/example.yaml
-      rm --verbose {{module_name}}/configs/experiment/cluster_sweep_example.yaml
-      rm --verbose {{module_name}}/configs/experiment/local_sweep_example.yaml
-      rm --verbose {{module_name}}/configs/experiment/profiling.yaml
+      rm -v {{module_name}}/algorithms/image_classifier*.py
+      rm -v {{module_name}}/configs/algorithm/image_classifier.yaml
+      rm -v {{module_name}}/configs/experiment/example.yaml
+      rm -v {{module_name}}/configs/experiment/cluster_sweep_example.yaml
+      rm -v {{module_name}}/configs/experiment/local_sweep_example.yaml
+      rm -v {{module_name}}/configs/experiment/profiling.yaml
       git add {{module_name}}
       git commit -m "Remove image classification example."
     when: "{{ run_setup and 'image_classifier' not in examples_to_include }}"
 
   - command: |
-      rm --verbose {{module_name}}/algorithms/jax_image_classifier*.py
-      rm --verbose {{module_name}}/configs/algorithm/jax_image_classifier.yaml
-      rm --verbose {{module_name}}/configs/algorithm/network/jax_cnn.yaml
-      rm --verbose {{module_name}}/configs/algorithm/network/jax_fcnet.yaml
+      rm -v {{module_name}}/algorithms/jax_image_classifier*.py
+      rm -v {{module_name}}/configs/algorithm/jax_image_classifier.yaml
+      rm -v {{module_name}}/configs/algorithm/network/jax_cnn.yaml
+      rm -v {{module_name}}/configs/algorithm/network/jax_fcnet.yaml
       git add {{module_name}}
       git commit -m "Remove jax image classification example."
     when: "{{ run_setup and 'jax_image_classifier' not in examples_to_include }}"
 
   # Remove unwanted image classification datamodules and configs
   - command: |
-      rm --verbose {{module_name}}/datamodules/image_classification/image_classification*.py
-      rm --verbose {{module_name}}/datamodules/image_classification/mnist*.py
-      rm --verbose {{module_name}}/datamodules/image_classification/fashion_mnist*.py
-      rm --verbose {{module_name}}/datamodules/image_classification/cifar10*.py
-      rm --verbose {{module_name}}/datamodules/image_classification/imagenet*.py
-      rm --verbose {{module_name}}/datamodules/image_classification/inaturalist*.py
-      rm --verbose {{module_name}}/datamodules/image_classification/__init__.py
-      rm --verbose {{module_name}}/datamodules/vision*.py
-      rm --verbose {{module_name}}/configs/datamodule/mnist.yaml
-      rm --verbose {{module_name}}/configs/datamodule/fashion_mnist.yaml
-      rm --verbose {{module_name}}/configs/datamodule/cifar10.yaml
-      rm --verbose {{module_name}}/configs/datamodule/imagenet.yaml
-      rm --verbose {{module_name}}/configs/datamodule/inaturalist.yaml
-      rm --verbose {{module_name}}/configs/datamodule/vision.yaml
+      rm -v {{module_name}}/datamodules/image_classification/image_classification*.py
+      rm -v {{module_name}}/datamodules/image_classification/mnist*.py
+      rm -v {{module_name}}/datamodules/image_classification/fashion_mnist*.py
+      rm -v {{module_name}}/datamodules/image_classification/cifar10*.py
+      rm -v {{module_name}}/datamodules/image_classification/imagenet*.py
+      rm -v {{module_name}}/datamodules/image_classification/inaturalist*.py
+      rm -v {{module_name}}/datamodules/image_classification/__init__.py
+      rm -v {{module_name}}/datamodules/vision*.py
+      rm -v {{module_name}}/configs/datamodule/mnist.yaml
+      rm -v {{module_name}}/configs/datamodule/fashion_mnist.yaml
+      rm -v {{module_name}}/configs/datamodule/cifar10.yaml
+      rm -v {{module_name}}/configs/datamodule/imagenet.yaml
+      rm -v {{module_name}}/configs/datamodule/inaturalist.yaml
+      rm -v {{module_name}}/configs/datamodule/vision.yaml
       rmdir {{module_name}}/datamodules/image_classification
       git add {{module_name}}
       git commit -m "Remove image classification datamodules and configs."
     when: "{{ run_setup and 'image_classifier' not in examples_to_include and 'jax_image_classifier' not in examples_to_include }}"
 
   - command: |
-      rm --verbose {{module_name}}/algorithms/text_classifier*.py
-      rm --verbose {{module_name}}/configs/algorithm/text_classifier.yaml
-      rm --verbose {{module_name}}/configs/experiment/text_classification_example.yaml
-      rm --verbose {{module_name}}/configs/datamodule/glue_cola.yaml
-      rm --verbose {{module_name}}/datamodules/text/text_classification*.py
-      rm --verbose {{module_name}}/datamodules/text/__init__.py
+      rm -v {{module_name}}/algorithms/text_classifier*.py
+      rm -v {{module_name}}/configs/algorithm/text_classifier.yaml
+      rm -v {{module_name}}/configs/experiment/text_classification_example.yaml
+      rm -v {{module_name}}/configs/datamodule/glue_cola.yaml
+      rm -v {{module_name}}/datamodules/text/text_classification*.py
+      rm -v {{module_name}}/datamodules/text/__init__.py
       rmdir {{module_name}}/datamodules/text
       git add {{module_name}}
       git commit -m "Remove text classification example."
     when: "{{ run_setup and 'text_classifier' not in examples_to_include }}"
 
   # todo: remove JaxTrainer and project/trainers folder if the JaxPPO example is removed?
   - command: |
-      rm --verbose {{module_name}}/algorithms/jax_ppo*.py
-      rm --verbose {{module_name}}/trainers/jax_trainer*.py
+      rm -v {{module_name}}/algorithms/jax_ppo*.py
+      rm -v {{module_name}}/trainers/jax_trainer*.py
       rmdir {{module_name}}/trainers
-      rm --verbose {{module_name}}/configs/algorithm/jax_ppo.yaml
-      rm --verbose {{module_name}}/configs/experiment/jax_rl_example.yaml
+      rm -v {{module_name}}/configs/algorithm/jax_ppo.yaml
+      rm -v {{module_name}}/configs/experiment/jax_rl_example.yaml
       git add {{module_name}}
       git commit -m "Remove Jax PPO example and lightning Trainer."
     when: "{{ run_setup and 'jax_ppo' not in examples_to_include }}"
 
   - command: |
-      rm --verbose {{module_name}}/algorithms/llm_finetuning*.py
-      rm --verbose {{module_name}}/configs/algorithm/llm_finetuning.yaml
-      rm --verbose {{module_name}}/configs/experiment/llm_finetuning_example.yaml
+      rm -v {{module_name}}/algorithms/llm_finetuning*.py
+      rm -v {{module_name}}/configs/algorithm/llm_finetuning.yaml
+      rm -v {{module_name}}/configs/experiment/llm_finetuning_example.yaml
       git add {{module_name}}
       git commit -m "Remove LLM fine-tuning example."
     when: "{{ run_setup and 'llm_finetuning' not in examples_to_include }}"
@@ -253,15 +253,15 @@ _tasks:
   # Remove unneeded dependencies:
 
   ## Jax-related dependencies:
-  - command: uv remove rejax gymnax gymnasium xtils
+  - command: uv remove rejax gymnax gymnasium xtils --no-sync
     when: "{{ run_setup and 'jax_ppo' not in examples_to_include }}"
-  - command: uv remove jax jaxlib torch-jax-interop
+  - command: uv remove jax jaxlib torch-jax-interop --no-sync
     when: "{{ run_setup and 'jax_ppo' not in examples_to_include and 'jax_image_classifier' not in examples_to_include }}"
 
   ## Huggingface-related dependencies:
-  - command: uv remove evaluate
+  - command: uv remove evaluate --no-sync
     when: "{{run_setup and 'text_classifier' not in examples_to_include }}"
-  - command: uv remove transformers datasets
+  - command: uv remove transformers datasets --no-sync
     when: "{{ run_setup and 'text_classifier' not in examples_to_include and 'llm_finetuning' not in examples_to_include }}"
   - command: |
       git add .python-version pyproject.toml uv.lock
@@ -306,7 +306,7 @@ _tasks:
   # todo: Causes issues on GitHub CI (asking for user)
   - command: |
       git add .
-      git commit -m "Clean project from template"
+      git commit -m "Clean project from template" || true # don't fail if there are no changes
     when: "{{run_setup}}"
   # - command: "git commit -m 'Initial commit'"
   - "uvx pre-commit install"
6 changes: 4 additions & 2 deletions project/algorithms/callbacks/samples_per_second.py
@@ -1,8 +1,8 @@
 import time
 from typing import Any, Generic, Literal
 
-import jax
 import lightning
+import optree
 import torch
 from lightning import LightningModule, Trainer
 from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -182,12 +182,14 @@ def log(
         )
 
     def get_num_samples(self, batch: BatchType) -> int:
         if isinstance(batch, Tensor):
             return batch.shape[0]
         if is_sequence_of(batch, Tensor):
             return batch[0].shape[0]
         if isinstance(batch, dict):
             return next(
                 v.shape[0]
-                for v in jax.tree.leaves(batch)
+                for v in optree.tree_leaves(batch)  # type: ignore
                 if isinstance(v, torch.Tensor) and v.ndim > 1
             )
         raise NotImplementedError(
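Note: this file's change swaps `jax.tree.leaves` for `optree.tree_leaves`, dropping the heavyweight jax import from the callback. A minimal sketch of the behaviour relied on in `get_num_samples` above (the batch contents are made up for illustration):

import optree
import torch

# A nested batch, like the dict batches handled by get_num_samples above.
batch = {"label": torch.zeros(32, dtype=torch.long), "pixels": torch.zeros(32, 3, 32, 32)}

# optree.tree_leaves flattens any nested container into a flat list of leaf
# values, mirroring jax.tree.leaves without importing jax.
leaves = optree.tree_leaves(batch)
batch_size = next(v.shape[0] for v in leaves if isinstance(v, torch.Tensor) and v.ndim > 1)
print(batch_size)  # 32 (from the first multi-dimensional tensor leaf)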
74 changes: 58 additions & 16 deletions project/algorithms/jax_ppo.py
@@ -7,9 +7,10 @@
 from __future__ import annotations
 
 import contextlib
+import dataclasses
 import functools
 import operator
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Mapping, Sequence
 from logging import getLogger as get_logger
 from pathlib import Path
 from typing import Any, Generic, TypedDict
@@ -21,25 +22,21 @@
 import gymnax.environments.spaces
 import jax
 import jax.numpy as jnp
-import lightning
-import lightning.pytorch
-import lightning.pytorch.loggers
-import lightning.pytorch.loggers.wandb
 import numpy as np
 import optax
 from flax.training.train_state import TrainState
 from flax.typing import FrozenVariableDict
 from gymnax.environments.environment import Environment
 from gymnax.visualize.visualizer import Visualizer
+from lightning.pytorch.loggers.wandb import WandbLogger
 from matplotlib import pyplot as plt
 from rejax.algos.mixins import RMSState
 from rejax.evaluate import evaluate
 from rejax.networks import DiscretePolicy, GaussianPolicy, VNetwork
 from typing_extensions import TypeVar
-from xtils.jitpp import Static
+from xtils.jitpp import Static, jit
 
 from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer, get_error_from_metrics
-from project.utils.typing_utils.jax_typing_utils import field, jit
 
 logger = get_logger(__name__)
 
@@ -96,6 +93,48 @@ class PPOState(Generic[TEnvState], flax.struct.PyTreeNode):
     data_collection_state: TrajectoryCollectionState[TEnvState]
 
 
+T = TypeVar("T")
+
+
+def field(
+    *,
+    default: T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
+    default_factory: Callable[[], T] | dataclasses._MISSING_TYPE = dataclasses.MISSING,
+    init=True,
+    repr=True,
+    hash=None,
+    compare=True,
+    metadata: Mapping[Any, Any] | None = None,
+    kw_only=dataclasses.MISSING,
+    pytree_node: bool | None = None,
+) -> T:
+    """Small typing fix for `flax.struct.field`.
+
+    - Add type annotations so it doesn't drop the signature of the `dataclasses.field` function.
+    - Make `pytree_node` default to `False` for ints and bools, and to `True` for everything
+      else.
+    """
+    if pytree_node is None and isinstance(default, int):  # note: also includes `bool`.
+        pytree_node = False
+    if pytree_node is None:
+        pytree_node = True
+    if metadata is None:
+        metadata = {}
+    else:
+        metadata = dict(metadata)
+    metadata.setdefault("pytree_node", pytree_node)
+    return dataclasses.field(
+        default=default,
+        default_factory=default_factory,  # type: ignore
+        init=init,
+        repr=repr,
+        hash=hash,
+        compare=compare,
+        metadata=metadata,
+        kw_only=kw_only,
+    )  # type: ignore
+
+
 class PPOHParams(flax.struct.PyTreeNode):
     """Hyper-parameters for this PPO example.
 
@@ -512,20 +551,23 @@ def visualize(self, ts: PPOState, gif_path: str | Path, eval_rng: chex.PRNGKey |
         render_episode(
             actor=actor,
             env=self.env,
-            env_params=jax.tree.map(lambda v: v.item() if v.ndim == 0 else v, self.env_params),
+            env_params=jax.tree.map(
+                lambda v: v if isinstance(v, int | float) else v.item() if v.ndim == 0 else v,
+                self.env_params,
+            ),
             gif_path=Path(gif_path),
             rng=eval_rng if eval_rng is not None else ts.rng,
         )
 
     ## These here aren't currently used. They are here to mirror rejax.PPO where the training loop
     # is in the algorithm.
 
-    @functools.partial(jit, static_argnames=["skip_initial_evaluation"])
+    @functools.partial(jax.jit, static_argnames="skip_initial_evaluation")
     def train(
         self,
         rng: jax.Array,
         train_state: PPOState[TEnvState] | None = None,
-        skip_initial_evaluation: bool = False,
+        skip_initial_evaluation: Static[bool] = False,
     ) -> tuple[PPOState[TEnvState], EvalMetrics]:
         """Full training loop in jax.
 
@@ -624,9 +666,9 @@ def _normalize_obs(rms_state: RMSState, obs: jax.Array):
     return (obs - rms_state.mean) / jnp.sqrt(rms_state.var + 1e-8)
 
 
-@functools.partial(jit, static_argnames=["num_minibatches"])
+@jit
 def shuffle_and_split(
-    data: AdvantageMinibatch, rng: chex.PRNGKey, num_minibatches: int
+    data: AdvantageMinibatch, rng: chex.PRNGKey, num_minibatches: Static[int]
 ) -> AdvantageMinibatch:
     assert data.trajectories.obs.shape
     iteration_size = data.trajectories.obs.shape[0] * data.trajectories.obs.shape[1]
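Note: the recurring change from `@functools.partial(jit, static_argnames=[...])` to a bare `@jit` works because `xtils.jitpp.jit` picks up static arguments from the `Static[...]` annotations in the signature itself, as the new `num_minibatches: Static[int]` parameter shows. A rough sketch of the pattern, assuming only what this diff shows about `xtils.jitpp` (`scale` is a made-up example):

import jax
import jax.numpy as jnp
from xtils.jitpp import Static, jit

@jit
def scale(x: jax.Array, factor: Static[int]) -> jax.Array:
    # `factor` is declared static through its annotation, so it is treated as a
    # compile-time constant rather than a traced array argument.
    return x * factor

print(scale(jnp.arange(4), 3))  # [0 3 6 9]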
@@ -639,7 +681,7 @@ def shuffle_and_split(
     return jax.tree.map(_shuffle_and_split_fn, data)
 
 
-@functools.partial(jit, static_argnames=["num_minibatches"])
+@jit
 def _shuffle_and_split(x: jax.Array, permutation: jax.Array, num_minibatches: Static[int]):
     x = x.reshape((x.shape[0] * x.shape[1], *x.shape[2:]))
     x = jnp.take(x, permutation, axis=0)
@@ -683,7 +725,7 @@ def get_advantages(
     return (advantage, transition_data.value), advantage
 
 
-@functools.partial(jit, static_argnames=["actor"])
+@jit
def actor_loss_fn(
     params: FrozenVariableDict,
     actor: Static[flax.linen.Module],
@@ -710,7 +752,7 @@ def actor_loss_fn(
     return pi_loss - ent_coef * entropy
 
 
-@functools.partial(jit, static_argnames=["critic"])
+@jit
 def critic_loss_fn(
     params: FrozenVariableDict,
     critic: Static[flax.linen.Module],
@@ -829,6 +871,6 @@ def on_fit_end(self, trainer: JaxTrainer, module: JaxRLExample, ts: PPOState):
 
     def log_image(self, gif_path: Path, trainer: JaxTrainer, step: int):
         for logger in trainer.loggers:
-            if isinstance(logger, lightning.pytorch.loggers.wandb.WandbLogger):
+            if isinstance(logger, WandbLogger):
                 logger.log_image("render_episode", [str(gif_path)], step=step)
                 return
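Note: the `field` helper added in this file changes how `flax.struct.PyTreeNode` dataclasses are flattened. A small illustration of the defaulting rule it implements (`Config` is hypothetical; `field` is the helper defined in the diff above):

import flax.struct
import jax

class Config(flax.struct.PyTreeNode):
    learning_rate: float = field(default=3e-4)  # float default: pytree_node=True, traced as a leaf
    normalize: bool = field(default=False)  # bool default: pytree_node=False, static metadata

cfg = Config()
# Only the float is a pytree leaf; the bool becomes static auxiliary data, so
# jitted functions recompile when it changes instead of tracing it.
print(jax.tree.leaves(cfg))  # [0.0003]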
9 changes: 3 additions & 6 deletions project/conftest.py
@@ -73,11 +73,11 @@
 from typing import Any, Literal, TypeVar
 
 import hydra.errors
-import jax
 import lightning
 import lightning.pytorch
 import lightning.pytorch as pl
 import lightning.pytorch.utilities
+import optree
 import pytest
 import tensor_regression.stats
 import torch
@@ -377,9 +377,7 @@ def train_dataloader(
 
 # todo: Remove (unused).
 @pytest.fixture(scope="session")
-def training_batch(
-    train_dataloader: DataLoader, device: torch.device
-) -> tuple[Tensor, ...] | dict[str, Tensor]:
+def training_batch(train_dataloader: DataLoader, device: torch.device) -> optree.PyTree[Tensor]:
     # Get a batch of data from the dataloader.
 
     # The batch of data will always be the same because the dataloaders are passed a Generator
@@ -391,8 +389,7 @@ def training_batch(
     # TODO: This ugliness is because torchvision transforms use the global pytorch RNG!
     torch.random.manual_seed(42)
     batch = next(dataloader_iterator)
-
-    return jax.tree.map(operator.methodcaller("to", device=device), batch)
+    return optree.tree_map(operator.methodcaller("to", device=device), batch)
 
 
 @pytest.fixture(autouse=True, scope="function")
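Note: same replacement pattern as in the callback above; here `optree.tree_map` stands in for `jax.tree.map` to move an arbitrarily nested batch to a device while preserving its structure. A minimal sketch (the batch is made up for illustration):

import operator

import optree
import torch

batch = {"x": torch.zeros(8, 3), "y": (torch.zeros(8), torch.ones(8))}

# Apply `.to(device=...)` to every tensor leaf, keeping the container structure.
moved = optree.tree_map(operator.methodcaller("to", device="cpu"), batch)
print(optree.tree_structure(moved) == optree.tree_structure(batch))  # True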