Skip to content

Commit 43341fa

Browse files
committed
Add more details to index.rst
- Explain how CLU metrics work - Explain what goes into jit vs. non-jit context Followups needed: explain gotchas for multi-device metrics
1 parent f80bf7d commit 43341fa

File tree

1 file changed

+44
-6
lines changed

1 file changed

+44
-6
lines changed

docs/index.rst

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,24 @@ metrax Documentation
33

44
**metrax** provides common evaluation metric implementations for JAX.
55

6+
Available Metrics
7+
-----------------
8+
9+
.. toctree::
10+
:maxdepth: 2
11+
12+
Metrax metrics: <metrax>
13+
614
Getting Started
715
---------------
816

9-
Metrics are based on `clu.Metric`.
17+
All metrics inherit from the `clu.Metric
18+
<https://github.com/google/CommonLoopUtils/blob/main/clu/metrics.py>`_ base
19+
class, which provides a functional API consisting of three main methods:
20+
21+
* ``from_model_output`` to create the metric dataclass from inputs
22+
* ``merge`` to update the results
23+
* ``compute`` to get the final result
1024

1125
.. code-block::
1226
@@ -30,11 +44,35 @@ Metrics are based on `clu.Metric`.
3044
# Get result:
3145
result = metric.compute()
3246
47+
Integrate into your training loop
48+
---------------------------------
3349

34-
Metrax API
35-
==========
50+
All Metrax metrics are jittable (they can be used within a ``jax.jit``
51+
function). If your custom metric uses standard JAX operations and no dynamic
52+
shapes, it should be jittable. You can test with the following:
3653

37-
.. toctree::
38-
:maxdepth: 2
54+
.. code-block::
55+
56+
logits = jnp.ones((2, 3))
57+
labels = jnp.ones((2, 3))
58+
jax.jit(metrics.MSE.from_model_output)(logits, labels)
59+
60+
Jittable metrics can be added directly to your train or eval step.
61+
Non-jittable metrics need to go outside the jitted function.
62+
63+
.. code-block::
64+
65+
@jax.jit
66+
def eval_step(logits, labels):
67+
...
68+
outputs['mse'] = metrics.MSE.from_model_output(logits, labels)
69+
outputs['rmse'] = metrics.RMSE.from_model_output(logits, labels)
70+
return outputs
71+
72+
def run_eval():
73+
for logits, labels in eval_dataset:
74+
# Jittable metrics
75+
outputs = eval_step(logits, labels)
76+
# Non-jittable metrics
77+
outputs['sequence_match'] = metrics.SequenceMatch.from_model_outputs(logits, labels)
3978
40-
metrax API <metrax>

0 commit comments

Comments
 (0)