
Commit a4b9858

Documentation (#1)
1 parent 17c9ccd commit a4b9858

File tree

8 files changed: +242, -61 lines changed


.github/workflows/ci.yaml

Lines changed: 0 additions & 27 deletions
@@ -36,30 +36,3 @@ jobs:
      - name: Run lints
        run: |
          make lints
-
-  tests:
-    runs-on: ubuntu-latest
-    needs:
-      - lints
-    strategy:
-      matrix:
-        python-version: [ 3.11, 3.12 ]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
-        with:
-          python-version: ${{ matrix.python-version }}
-      - uses: astral-sh/setup-uv@v5
-        with:
-          version: "latest"
-      - name: Install dependencies
-        run: |
-          uv sync --dev
-      - name: Run tests
-        run: |
-          make tests
-      - name: Upload coverage reports to Codecov
-        uses: codecov/codecov-action@v3
-        env:
-          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

.github/workflows/release.yaml

Lines changed: 3 additions & 3 deletions
@@ -10,11 +10,11 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-        python-version: [3.9]
+        python-version: [3.11]
    steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install pypa/build

README.md

Lines changed: 221 additions & 11 deletions
@@ -1,10 +1,22 @@
# blaxbird [blækbɜːd]

+[![ci](https://github.com/dirmeier/blaxbird/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/blaxbird/actions/workflows/ci.yaml)
+[![version](https://img.shields.io/pypi/v/blaxbird.svg?colorB=black&style=flat)](https://pypi.org/project/blaxbird/)
+
+> A high-level API to build and train NNX models
+
## About

-A high-level API to build and train NNX models.
+`Blaxbird` [blækbɜːd] is a high-level API to easily build NNX models and train them on CPU or GPU.
+Using `blaxbird` one can
+- concisely define models and loss functions without the usual JAX/Flax verbosity,
+- easily define checkpointers that save the best and most current network weights,
+- distribute data and model weights over multiple processes or GPUs,
+- define hooks that are periodically called during training.
+
+## Example

-Define the module
+To use `blaxbird`, one only needs to define a model, a loss function, and train and validation step functions:
```python
import optax
from flax import nnx
@@ -18,32 +30,230 @@ def loss_fn(model, images, labels):
        logits=logits, labels=labels
    ).mean()

-
-def train_step(model, rng_key, batch, **kwargs):
+def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])

-
-def val_step(model, rng_key, batch, **kwargs):
+def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
```

-Define the trainer
+You can then construct (and use) a training function like this:
+
```python
-from jax import random as jr
+import optax
from flax import nnx
+from jax import random as jr

from blaxbird import train_fn

model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
-optimizer = get_optimizer(model)
+optimizer = nnx.Optimizer(model, optax.adam(1e-4))

+train = train_fn(
+    fns=(train_step, val_step),
+    n_steps=100,
+    eval_every_n_steps=10,
+    n_eval_batches=10
+)
+train(jr.key(2), model, optimizer, train_itr, val_itr)
+```
+
+See a self-contained example in [examples/mnist_classification](examples/mnist_classification).
+
+## Usage
+
+`train_fn` is a higher-order function with the following signature:
+
+```python
+def train_fn(
+    *,
+    fns: tuple[Callable, Callable],
+    shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None,
+    n_steps: int,
+    eval_every_n_steps: int,
+    n_eval_batches: int,
+    log_to_wandb: bool = False,
+    hooks: Iterable[Callable] = (),
+) -> Callable:
+    ...
+```
+
+We briefly explain the more ambiguous argument types below.
+
+### `fns`
+
+`fns` is a required argument consisting of a tuple of two functions, a step function and a validation function.
+In the simplest case they look like this:
+
+```python
+def train_step(model, rng_key, batch):
+    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])
+
+def val_step(model, rng_key, batch):
+    return loss_fn(model, batch["image"], batch["label"])
+```
+
+Both `train_step` and `val_step` have the same arguments and argument types:
+- `model` specifies a `nnx.Module`, i.e., a neural network like the CNN shown above.
+- `rng_key` is a `jax.random.key` in case you need to generate random numbers.
+- `batch` is a sample from a data loader (to be specified later; see the sketch below).
+
+The loss function that is called by both computes a *scalar* loss value.
+While `train_step` has to return the loss and gradients, `val_step` only needs
+to return the loss.
+
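For intuition, a batch is just a dictionary of arrays, and a data loader is any iterable that yields such batches. Below is a minimal sketch, assuming in-memory NumPy arrays (`train_images`, `train_labels`, etc. are placeholders); the loaders actually used are shown in [examples/mnist_classification](examples/mnist_classification):

```python
import numpy as np


def dummy_loader(images, labels, batch_size=128):
    """Yield dictionary batches of the form consumed by train_step/val_step."""
    while True:
        idx = np.random.choice(len(images), size=batch_size, replace=False)
        yield {"image": images[idx], "label": labels[idx]}


train_itr = dummy_loader(train_images, train_labels)
val_itr = dummy_loader(val_images, val_labels)
```
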
+### `shardings`
+
+To specify how data and model weights are distributed over devices and processes,
+`blaxbird` uses JAX's [sharding](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) functionality.
+
+`shardings` is again specified by a tuple, one for the model sharding, the other for the data sharding.
+An example is shown below, where we only distribute the data over `num_devices` devices.
+If you don't want to distribute anything, just set the argument to `None` or don't specify it at all.
+
+```python
+def get_sharding():
+    num_devices = jax.local_device_count()
+    mesh = jax.sharding.Mesh(
+        mesh_utils.create_device_mesh((num_devices,)), ("data",)
+    )
+    model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
+    data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
+    return model_sharding, data_sharding
+```
+
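Conceptually, the data sharding places each batch along the "data" axis of the device mesh, while the empty `PartitionSpec()` replicates the model weights on every device. Outside of `blaxbird` the same placement could be done with plain JAX; a minimal sketch, assuming the `get_sharding` function above:

```python
import jax
import numpy as np

model_sharding, data_sharding = get_sharding()

# a host-local batch; its leading (batch) axis is split across the "data" mesh axis,
# assuming the batch size is divisible by the number of devices
batch = {
    "image": np.zeros((128, 28, 28, 1), dtype=np.float32),
    "label": np.zeros((128,), dtype=np.int32),
}
sharded_batch = jax.device_put(batch, data_sharding)
```
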
+### `hooks`
+
+`hooks` is a list of callables which are periodically called during training.
+Each hook has to have the following signature:
+
+```python
+def hook_fn(step, *, model, **kwargs) -> None:
+    ...
+```
+
+It takes an integer `step` specifying the current training iteration and the model itself.
+For instance, if you want to track custom metrics during validation, you could create a hook like this:
+
+```python
+def hook_fn(metrics, val_iter, hook_every_n_steps):
+    def fn(step, *, model, **kwargs):
+        if step % hook_every_n_steps != 0:
+            return
+        for batch in val_iter:
+            logits = model(batch["image"])
+            loss = optax.softmax_cross_entropy_with_integer_labels(
+                logits=logits, labels=batch["label"]
+            ).mean()
+            metrics.update(loss=loss, logits=logits, labels=batch["label"])
+        if jax.process_index() == 0:
+            curr_metrics = ", ".join(
+                [f"{k}: {v}" for k, v in metrics.compute().items()]
+            )
+            logging.info(f"metrics at step {step}: {curr_metrics}")
+        metrics.reset()
+    return fn
+
+metrics = nnx.MultiMetric(
+    accuracy=nnx.metrics.Accuracy(),
+    loss=nnx.metrics.Average("loss"),
+)
+hook = hook_fn(metrics, val_iter, eval_every_n_steps)
+```
+
+This creates a hook function `hook` that, every `eval_every_n_steps` steps, iterates over the validation set,
+computes accuracy and loss, and then logs everything.
+
+To provide multiple hooks to the train function, just pass them as a list, as sketched below.
+
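For example, assuming the metrics hook from above and the checkpointing hook `hook_save` constructed in the next section:

```python
# collect all hooks in a list and hand them to train_fn
hooks = [hook, hook_save]

train = train_fn(
    fns=(train_step, val_step),
    n_steps=100,
    eval_every_n_steps=10,
    n_eval_batches=10,
    hooks=hooks,
)
```
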
+#### A checkpointing `hook`
+
+We provide a convenient hook for checkpointing which can be constructed using
+`get_default_checkpointer`. The checkpointer saves both the `k` checkpoints with the lowest
+validation loss and the last training checkpoint.
+
+The signature of `get_default_checkpointer` is:
+
+```python
+def get_default_checkpointer(
+    outfolder: str,
+    *,
+    save_every_n_steps: int,
+    max_to_keep: int = 5,
+) -> tuple[Callable, Callable, Callable]
+```
+
+Its arguments are:
+- `outfolder`: a folder specifying where to store the checkpoints.
+- `save_every_n_steps`: after how many training steps to store a checkpoint.
+- `max_to_keep`: the number of checkpoints to keep before old ones get removed (so checkpoints don't pile up on disk).
+
+For instance, you would then construct the checkpointing function like this:
+
+```python
+from blaxbird import get_default_checkpointer
+
+hook_save, *_ = get_default_checkpointer(
+    os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
+)
+```
+
+### Restoring a run
+
+You can also use `get_default_checkpointer` to restart the run where you left off.
+`get_default_checkpointer` in fact returns three functions, one for saving checkpoints and two for restoring
+checkpoints:
+
+```python
+from blaxbird import get_default_checkpointer
+
+save, restore_best, restore_last = get_default_checkpointer(
+    os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
+)
+```
+
+You can then do either of:
+
+```python
+model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
+optimizer = nnx.Optimizer(model, optax.adam(1e-4))
+
+model, optimizer = restore_best(model, optimizer)
+model, optimizer = restore_last(model, optimizer)
+```
+
+### Doing training
+
+After having defined train functions, hooks and shardings, you can train your model like this:
+
+```python
train = train_fn(
    fns=(train_step, val_step),
    n_steps=n_steps,
-    n_eval_frequency=n_eval_frequency,
+    eval_every_n_steps=eval_every_n_steps,
    n_eval_batches=n_eval_batches,
+    shardings=(model_sharding, data_sharding),
+    hooks=hooks,
+    log_to_wandb=False,
)
-train(jr.key(2), model, optimizer, train_itr, val_itr)
+train(jr.key(1), model, optimizer, train_itr, val_itr)
+```
+
+A self-contained example that also shows what the data loaders should look like can be found
+in [examples/mnist_classification](examples/mnist_classification).
+
+## Installation
+
+To install the package from PyPI, call:
+
+```bash
+pip install blaxbird
+```
+
+To install the latest GitHub <RELEASE>, just call the following on the command line:
+
+```bash
+pip install git+https://github.com/dirmeier/blaxbird@<RELEASE>
```

## Author

blaxbird/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-"""fll: A high-level API for building and training Flax NNX models."""
+"""blaxbird: A high-level API for building and training Flax NNX models."""

__version__ = "0.0.1"

blaxbird/_src/checkpointer.py

Lines changed: 5 additions & 8 deletions
@@ -9,28 +9,25 @@
def get_default_checkpointer(
    outfolder: str,
    *,
-    save_every_n_steps: int = 1,
+    save_every_n_steps: int,
    max_to_keep: int = 5,
-    best_fn: Callable = lambda x: x["val/loss"],
-    best_mode: str = "min",
) -> tuple[Callable, Callable, Callable]:
    """Construct functions for checkpointing functionality.

    Args:
      outfolder: a path specifying where checkpoints are stored
      save_every_n_steps: how often to store checkpoints
      max_to_keep: number of checkpoints to store before they get deleted
-      best_fn: function that maintains checkpoints using a specific criterion for
-        quality
-      best_mode: use `min`, e.g., if your criterion is a loss function.
-        Use 'max' if the criterion is an ELBO or something.

    Returns:
      returns function to saev and restore checkpoints
    """
    checkpointer = ocp.PyTreeCheckpointer()
    options = ocp.CheckpointManagerOptions(
-        max_to_keep=max_to_keep, create=True, best_mode=best_mode, best_fn=best_fn
+        max_to_keep=max_to_keep,
+        create=True,
+        best_mode="min",
+        best_fn=lambda x: x["val/loss"],
    )
    checkpoint_manager = ocp.CheckpointManager(
        os.path.join(outfolder, "best"),

blaxbird/_src/trainer.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ def _eval_step(model, rng_key, metrics, batch, **kwargs):
def train_fn(
    *,
    fns: tuple[Callable, Callable],
-    shardings: tuple[jax.NamedSharding, jax.NamedSharding],
+    shardings: tuple[jax.NamedSharding, jax.NamedSharding] | None = None,
    n_steps: int,
    eval_every_n_steps: int,
    n_eval_batches: int,
@@ -70,8 +70,8 @@ def train(
      rng_key: a jax.random.key object
      model: a NNX model
      optimizer: a nnx.Optimizer object
-      train_itr: a data laoder
-      val_itr: a data laoder
+      train_itr: a data loader
+      val_itr: a data loader
    """
    # get train and val fns
    step_fn, eval_fn = _step_and_val_fns(fns)
