# blaxbird [blækbɜːd]

[![ci](https://github.com/dirmeier/blaxbird/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/blaxbird/actions/workflows/ci.yaml)
[![version](https://img.shields.io/pypi/v/blaxbird.svg?colorB=black&style=flat)](https://pypi.org/project/blaxbird/)

> A high-level API to build and train NNX models

## About

`blaxbird` [blækbɜːd] is a high-level API to easily build NNX models and train them on CPU or GPU.
Using `blaxbird`, one can
- concisely define models and loss functions without the usual JAX/Flax verbosity,
- easily define checkpointers that save the best and most recent network weights,
- distribute data and model weights over multiple processes or GPUs,
- define hooks that are periodically called during training.

## Example

To use `blaxbird`, one only needs to define a model, a loss function, and train and validation step functions:
```python
import optax
from flax import nnx

# the CNN model definition (an nnx.Module) is omitted here; see
# examples/mnist_classification for the full model


def loss_fn(model, images, labels):
    logits = model(images)
    return optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()


def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])


def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
```

You can then construct (and use) a training function like this:

```python
import optax
from flax import nnx
from jax import random as jr

from blaxbird import train_fn

model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

train = train_fn(
    fns=(train_step, val_step),
    n_steps=100,
    eval_every_n_steps=10,
    n_eval_batches=10,
)
# train_itr and val_itr are the train and validation data loaders (see below)
train(jr.key(2), model, optimizer, train_itr, val_itr)
```

See a self-contained example in [examples/mnist_classification](examples/mnist_classification).

## Usage

`train_fn` is a higher-order function with the following signature:

```python
def train_fn(
    *,
    fns: tuple[Callable, Callable],
    shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None,
    n_steps: int,
    eval_every_n_steps: int,
    n_eval_batches: int,
    log_to_wandb: bool = False,
    hooks: Iterable[Callable] = (),
) -> Callable:
    ...
```

We briefly explain the less obvious arguments below.

### `fns`

`fns` is a required argument consisting of a tuple of two functions: a train step function and a validation step function.
In the simplest case they look like this:

```python
def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])


def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
```

Both `train_step` and `val_step` have the same arguments and argument types:
- `model` specifies an `nnx.Module`, i.e., a neural network like the CNN shown above.
- `rng_key` is a `jax.random.key` in case you need to generate random numbers.
- `batch` is a sample from a data loader (to be specified later).

The loss function that is called by both computes a *scalar* loss value.
While `train_step` has to return the loss and its gradients, `val_step` only needs
to return the loss.
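
As a quick sanity check (not part of the library API), you can call the step functions directly on a dummy batch, assuming `model` is the CNN instance from the example above. The dictionary keys match the example; the array shapes and MNIST-like dimensions are assumptions.

```python
import jax.numpy as jnp
from jax import random as jr

# a dummy batch with the keys used above; shapes are assumed (MNIST-like)
batch = {
    "image": jnp.zeros((32, 28, 28, 1)),
    "label": jnp.zeros((32,), dtype=jnp.int32),
}
# train_step returns the scalar loss and the gradients
loss, grads = train_step(model, jr.key(0), batch)
# val_step only returns the scalar loss
val_loss = val_step(model, jr.key(0), batch)
```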

### `shardings`

To specify how data and model weights are distributed over devices and processes,
`blaxbird` uses JAX's [sharding](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) functionality.

`shardings` is again a tuple, one element for the model sharding, the other for the data sharding.
An example is shown below, where we only distribute the data over `num_devices` devices.
If you don't want to distribute anything, just set the argument to `None` or don't specify it at all.

```python
import jax
from jax.experimental import mesh_utils


def get_sharding():
    num_devices = jax.local_device_count()
    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh((num_devices,)), ("data",)
    )
    model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
    data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
    return model_sharding, data_sharding
```
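
The constructed shardings are then passed to `train_fn` via the `shardings` argument, for instance like this (the numeric arguments are placeholder values):

```python
model_sharding, data_sharding = get_sharding()
train = train_fn(
    fns=(train_step, val_step),
    shardings=(model_sharding, data_sharding),
    n_steps=100,
    eval_every_n_steps=10,
    n_eval_batches=10,
)
```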

### `hooks`

`hooks` is a list of callables which are periodically called during training.
Each hook has to have the following signature:

```python
def hook_fn(step, *, model, **kwargs) -> None:
    ...
```

It takes an integer `step` specifying the current training iteration and the model itself.
For instance, if you want to track custom metrics during validation, you could create a hook like this:

```python
import logging

import jax
import optax
from flax import nnx


def hook_fn(metrics, val_iter, hook_every_n_steps):
    def fn(step, *, model, **kwargs):
        if step % hook_every_n_steps != 0:
            return
        for batch in val_iter:
            logits = model(batch["image"])
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits=logits, labels=batch["label"]
            ).mean()
            metrics.update(loss=loss, logits=logits, labels=batch["label"])
        if jax.process_index() == 0:
            curr_metrics = ", ".join(
                [f"{k}: {v}" for k, v in metrics.compute().items()]
            )
            logging.info(f"metrics at step {step}: {curr_metrics}")
        metrics.reset()

    return fn


metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average("loss"),
)
hook = hook_fn(metrics, val_iter, eval_every_n_steps)
```

This creates a hook function `hook` that, every `eval_every_n_steps` steps, iterates over the validation set,
computes accuracy and loss, and then logs everything.

To provide multiple hooks to the train function, just collect them in a list, as sketched below.
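
For instance, assuming you also use the checkpointing hook `hook_save` constructed in the next section:

```python
# all hooks in the list are called periodically during training
hooks = [hook, hook_save]
```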

#### A checkpointing `hook`

We provide a convenient hook for checkpointing which can be constructed using
`get_default_checkpointer`. The checkpointer saves the `k` checkpoints with the lowest
validation loss as well as the most recent training checkpoint.

The signature of `get_default_checkpointer` is:

```python
def get_default_checkpointer(
    outfolder: str,
    *,
    save_every_n_steps: int,
    max_to_keep: int = 5,
) -> tuple[Callable, Callable, Callable]:
    ...
```

Its arguments are:
- `outfolder`: the folder in which to store the checkpoints.
- `save_every_n_steps`: after how many training steps a checkpoint is stored.
- `max_to_keep`: the number of checkpoints to keep before old checkpoints start getting removed (to not clog the disk).

For instance, you would then construct the checkpointing hook like this:

```python
import os

from blaxbird import get_default_checkpointer

hook_save, *_ = get_default_checkpointer(
    os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
)
```

### Restoring a run

You can also use `get_default_checkpointer` to restart a run where you left off.
`get_default_checkpointer` in fact returns three functions: one for saving checkpoints and two for restoring
checkpoints:

```python
import os

from blaxbird import get_default_checkpointer

save, restore_best, restore_last = get_default_checkpointer(
    os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
)
```

You can then do either of:

```python
model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

model, optimizer = restore_best(model, optimizer)
model, optimizer = restore_last(model, optimizer)
```
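
For instance, to resume an interrupted run, you could restore the last checkpoint and call the train
function again. This is only a sketch; it assumes `train`, `train_itr`, and `val_itr` are set up as in the sections above:

```python
# continue training from the most recently saved model and optimizer state
model, optimizer = restore_last(model, optimizer)
train(jr.key(1), model, optimizer, train_itr, val_itr)
```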

### Training

After having defined the step functions, hooks, and shardings, you can train your model like this:

```python
train = train_fn(
    fns=(train_step, val_step),
    n_steps=n_steps,
    eval_every_n_steps=eval_every_n_steps,
    n_eval_batches=n_eval_batches,
    shardings=(model_sharding, data_sharding),
    hooks=hooks,
    log_to_wandb=False,
)
train(jr.key(1), model, optimizer, train_itr, val_itr)
```

A self-contained example that also shows what the data loaders should look like can be found
in [examples/mnist_classification](examples/mnist_classification).

## Installation

To install the package from PyPI, call:

```bash
pip install blaxbird
```

To install the latest GitHub `<RELEASE>`, just call the following on the command line:

```bash
pip install git+https://github.com/dirmeier/blaxbird@<RELEASE>
```

## Author