Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true


jobs:
linting:
name: Run linting/pre-commit checks
Expand Down Expand Up @@ -48,6 +47,7 @@ jobs:
version: "latest"
# https://github.com/astral-sh/setup-uv?tab=readme-ov-file#github-authentication-token
github-token: ${{ secrets.GITHUB_TOKEN }}
python-version: "3.12"
- name: Install dependencies
run: uv sync --frozen --extra docs
- name: Build the documentation (strict mode)
Expand All @@ -60,7 +60,7 @@ jobs:
max-parallel: 4
matrix:
platform: [ubuntu-latest, macos-latest]
python-version: ["3.11"]
python-version: ["3.12"]
steps:
- uses: actions/checkout@v4
- name: Install the latest version of uv
Expand Down
5 changes: 1 addition & 4 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ jobs:
version: "latest"
# https://github.com/astral-sh/setup-uv?tab=readme-ov-file#github-authentication-token
github-token: ${{ secrets.GITHUB_TOKEN }}
cache-suffix: "3.10"

- name: Pin python-version
run: uv python pin 3.10
python-version: "3.12"

- name: Install dependencies
run: uv sync --extra docs --frozen
Expand Down
45 changes: 23 additions & 22 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
Expand Down Expand Up @@ -32,7 +32,7 @@ repos:

- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: "v0.6.9"
rev: "v0.14.11"
hooks:
- id: ruff
args: ["--line-length", "99", "--fix"]
Expand All @@ -41,7 +41,7 @@ repos:
# python docstring formatting
- repo: https://github.com/myint/docformatter
# Don't autoupdate until https://github.com/PyCQA/docformatter/issues/293 is fixed
rev: eb1df347edd128b30cd3368dddc3aa65edcfac38
rev: v1.7.7
hooks:
- id: docformatter
language: python
Expand All @@ -59,33 +59,34 @@ repos:

# jupyter notebook cell output clearing
- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
rev: 0.8.2
hooks:
- id: nbstripout
require_serial: true

# md formatting
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
hooks:
- id: mdformat
exclude: "docs/" # terrible, I know, but it's messing up everything with mkdocs fences!
args: ["--number"]
additional_dependencies:
- mdformat-gfm
- mdformat-tables
- mdformat_frontmatter
- mdformat-toc
- mdformat-config
- mdformat-black
# see https://github.com/KyleKing/mdformat-mkdocs
# Doesn't seem to work!
- mdformat-mkdocs[recommended]>=2.1.0
require_serial: true
# TODO: Getting errors with python 3.13. Turning off for now.
# - repo: https://github.com/executablebooks/mdformat
# rev: 1.0.0
# hooks:
# - id: mdformat
# exclude: "docs/" # terrible, I know, but it's messing up everything with mkdocs fences!
# args: ["--number"]
# additional_dependencies:
# - mdformat-gfm
# - mdformat-tables
# - mdformat_frontmatter
# - mdformat-toc
# - mdformat-config
# - mdformat-black
# # see https://github.com/KyleKing/mdformat-mkdocs
# # Doesn't seem to work!
# - mdformat-mkdocs[recommended]>=2.1.0
# require_serial: true

# word spelling linter
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.1
hooks:
- id: codespell
args:
Expand Down
1 change: 0 additions & 1 deletion .python-version

This file was deleted.

6 changes: 2 additions & 4 deletions copier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,9 @@ python_version:
question: "What Python version do you want to use?"
required: true
choices:
- "3.10"
- "3.11"
- "3.12"
# - "3.13" # todo: There seem to be some dependency issues with python 3.13 that need fixing.
default: "3.10"
- "3.13" # todo: There seem to be some dependency issues with python 3.13 that need fixing.
default: "3.12"

# IDEA: Simplify the repo creation part of the form for new users.
# However, does that maybe make too many assumptions about how people will use this?
Expand Down
5 changes: 2 additions & 3 deletions project/algorithms/callbacks/classification_metrics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import warnings
from logging import getLogger as get_logger
from typing import Literal, NotRequired, Required, TypedDict
from typing import Literal, NotRequired, Required, TypedDict, override

import lightning
import torch
import torchmetrics
from lightning import LightningModule, Trainer
from torch import Tensor
from torchmetrics.classification import MulticlassAccuracy
from typing_extensions import override

from project.utils.typing_utils.protocols import ClassificationDataModule

Expand All @@ -17,7 +16,7 @@

class ClassificationOutputs(TypedDict, total=False):
"""The outputs that should be minimally returned from the training/val/test_step of
classification LightningModules so that metrics can be added aumatically by the
classification LightningModules so that metrics can be added automatically by the
`ClassificationMetricsCallback`."""

loss: NotRequired[torch.Tensor | float]
Expand Down
11 changes: 3 additions & 8 deletions project/algorithms/callbacks/samples_per_second.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Any, Generic, Literal
from typing import Any, Literal, override

import lightning
import optree
Expand All @@ -8,18 +8,13 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing_extensions import TypeVar, override

from project.utils.typing_utils import NestedMapping, is_sequence_of

BatchType = TypeVar(
"BatchType",
bound=torch.Tensor | tuple[torch.Tensor, ...] | NestedMapping[str, torch.Tensor],
contravariant=True,
)
type _Batch = torch.Tensor | tuple[torch.Tensor, ...] | NestedMapping[str, torch.Tensor]


class MeasureSamplesPerSecondCallback(lightning.Callback, Generic[BatchType]):
class MeasureSamplesPerSecondCallback[BatchType: _Batch](lightning.Callback):
def __init__(self, num_optimizers: int | None = None):
super().__init__()
self.last_step_times: dict[Literal["train", "val", "test"], float] = {}
Expand Down
15 changes: 6 additions & 9 deletions project/algorithms/jax_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from collections.abc import Callable, Mapping, Sequence
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, Generic, TypedDict
from typing import Any, TypedDict

import chex
import flax.linen
Expand Down Expand Up @@ -77,7 +77,7 @@ class AdvantageMinibatch(flax.struct.PyTreeNode):
targets: chex.Array


class TrajectoryCollectionState(Generic[TEnvState], flax.struct.PyTreeNode):
class TrajectoryCollectionState[TEnvState: gymnax.EnvState](flax.struct.PyTreeNode):
"""Struct containing the state related to the collection of data from the environment."""

last_obs: jax.Array
Expand All @@ -88,7 +88,8 @@ class TrajectoryCollectionState(Generic[TEnvState], flax.struct.PyTreeNode):
rng: chex.PRNGKey


class PPOState(Generic[TEnvState], flax.struct.PyTreeNode):
@flax.struct.dataclass
class PPOState[TEnvState: gymnax.EnvState]:
"""Contains all the state of the `JaxRLExample` algorithm."""

actor_ts: TrainState
Expand All @@ -97,10 +98,7 @@ class PPOState(Generic[TEnvState], flax.struct.PyTreeNode):
data_collection_state: TrajectoryCollectionState[TEnvState]


T = TypeVar("T")


def field(
def field[T](
*,
default: T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], T] | dataclasses._MISSING_TYPE = dataclasses.MISSING,
Expand Down Expand Up @@ -207,10 +205,9 @@ def get_error_from_ppo_eval_metrics(metrics: EvalMetrics) -> tuple[str, float]:
)


class JaxRLExample(
class JaxRLExample[TEnvState: gymnax.EnvState, TEnvParams: gymnax.EnvParams](
flax.struct.PyTreeNode,
JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics],
Generic[TEnvState, TEnvParams],
):
"""Example of an RL algorithm written in Jax: PPO, based on `rejax.PPO`.

Expand Down
5 changes: 2 additions & 3 deletions project/algorithms/jax_ppo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Callable, Iterable, Sequence
from logging import getLogger
from pathlib import Path
from typing import Any
from typing import Any, override

import chex
import gymnax
Expand All @@ -27,7 +27,6 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT
from tensor_regression import TensorRegressionFixture
from torch.utils.data import DataLoader
from typing_extensions import override

from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback
from project.main_test import experiment_commands_to_test
Expand Down Expand Up @@ -462,7 +461,7 @@ def debug_jit_warnings():
# Temporarily make this particular warning into an error to help future-proof our jax code.
import jax._src.deprecations

deprecations_to_trigger_error_for = ["tracer-hash"]
deprecations_to_trigger_error_for = []
values_before = {}
for dep in deprecations_to_trigger_error_for:
if val := jax._src.deprecations._registered_deprecations.get(dep):
Expand Down
8 changes: 3 additions & 5 deletions project/algorithms/lightning_module_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from __future__ import annotations

import abc
import copy
import dataclasses
from abc import ABC
from collections.abc import Mapping
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar, overload
from typing import Any, Literal, overload

import lightning
import pytest
Expand All @@ -28,11 +28,9 @@

logger = get_logger(__name__)

LightningModuleType = TypeVar("LightningModuleType", bound=LightningModule)


@pytest.mark.incremental # https://docs.pytest.org/en/stable/example/simple.html#incremental-testing-test-steps
class LightningModuleTests(Generic[LightningModuleType], ABC):
class LightningModuleTests[LightningModuleType: LightningModule](abc.ABC):
"""Suite of generic tests for a LightningModule.

Simply inherit from this class and decorate the class with the appropriate markers to get a set
Expand Down
12 changes: 3 additions & 9 deletions project/algorithms/llm_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate

import datasets
import datasets.distributed
Expand Down Expand Up @@ -517,10 +517,7 @@ def get_hash_of(config_dataclass) -> str:
return hashlib.md5(vals_string.encode()).hexdigest()


V = TypeVar("V")


def flatten_dict(d: NestedMapping[str, V]) -> dict[str, V]:
def flatten_dict[V](d: NestedMapping[str, V]) -> dict[str, V]:
result = {}
for k, v in d.items():
if isinstance(v, Mapping):
Expand All @@ -530,10 +527,7 @@ def flatten_dict(d: NestedMapping[str, V]) -> dict[str, V]:
return result


P = ParamSpec("P")


def _try_to_load_prepared_dataset_from(
def _try_to_load_prepared_dataset_from[**P](
dataset_path: Path,
_load_from_disk_fn: Callable[Concatenate[Path, P], Dataset | DatasetDict] = load_from_disk,
*_load_from_disk_args: P.args,
Expand Down
3 changes: 1 addition & 2 deletions project/algorithms/text_classifier_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from typing import Any, override

import lightning
import pytest
Expand All @@ -9,7 +9,6 @@
from tensor_regression import TensorRegressionFixture
from torch import Tensor
from transformers import PreTrainedModel
from typing_extensions import override

from project.algorithms.text_classifier import TextClassifier
from project.datamodules.text.text_classification import TextClassificationDataModule
Expand Down
6 changes: 2 additions & 4 deletions project/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
from contextlib import contextmanager
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, Literal, TypeVar
from typing import Any, Literal

import hydra.errors
import lightning
Expand Down Expand Up @@ -554,9 +554,7 @@ def _longest_common_prefix(strings: list[str]):
return shortest[:i]
return shortest

T = TypeVar("T")

def _assert_type(v: Any, t: type[T]) -> T:
def _assert_type[T](v: Any, t: type[T]) -> T:
assert isinstance(v, t)
return v

Expand Down
Loading
Loading