@@ -3,10 +3,24 @@ metrax Documentation
 
 **metrax** provides common evaluation metric implementations for JAX.
 
+Available Metrics
+-----------------
+
+.. toctree::
+   :maxdepth: 2
+
+   Metrax metrics <metrax>
+
 Getting Started
 ---------------
 
-Metrics are based on `clu.Metric`.
+All metrics inherit from the `clu.Metric
+<https://github.com/google/CommonLoopUtils/blob/main/clu/metrics.py>`_ base
+class, which provides a functional API consisting of three main methods:
+
+* ``from_model_output`` to create the metric dataclass from model outputs
+* ``merge`` to accumulate results across batches
+* ``compute`` to get the final result
 
 .. code-block::
 
@@ -30,11 +44,35 @@ Metrics are based on `clu.Metric`.
     # Get result:
     result = metric.compute()
 
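The create/merge/compute lifecycle described above can be illustrated with a plain-Python analogue. This is a minimal sketch of the ``clu.Metric``-style functional pattern, not metrax's actual implementation; the ``MSE`` class here is illustrative only:

```python
from dataclasses import dataclass

@dataclass
class MSE:
    # Running sum of squared errors and number of elements seen so far.
    total: float
    count: int

    @classmethod
    def from_model_output(cls, predictions, labels):
        # Create a metric instance from one batch of model outputs.
        total = sum((p - l) ** 2 for p, l in zip(predictions, labels))
        return cls(total=total, count=len(predictions))

    def merge(self, other):
        # Combine the running statistics of two metric instances.
        return MSE(total=self.total + other.total, count=self.count + other.count)

    def compute(self):
        # Final mean squared error over everything merged so far.
        return self.total / self.count

# Accumulate over two batches, then compute the final value.
metric = MSE.from_model_output([1.0, 2.0], [0.0, 2.0])
metric = metric.merge(MSE.from_model_output([3.0], [1.0]))
print(metric.compute())  # (1 + 0 + 4) / 3
```

Because each method returns a new instance rather than mutating state, the same pattern composes cleanly across batches and devices.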
+Integrate Into Your Training Loop
+---------------------------------
 
-Metrax API
-==========
+All Metrax metrics are jittable (they can be used within a ``jax.jit``
+function). If your custom metric uses only standard JAX operations and no
+dynamic shapes, it should also be jittable. You can test this with the
+following:
 
-.. toctree::
-   :maxdepth: 2
+.. code-block::
+
+    logits = jnp.ones((2, 3))
+    labels = jnp.ones((2, 3))
+    jax.jit(metrics.MSE.from_model_output)(logits, labels)
+
+Jittable metrics can be added directly to your train or eval step, while
+non-jittable metrics must be computed outside the jitted function.
+
+.. code-block::
+
+    @jax.jit
+    def eval_step(logits, labels):
+        ...
+        outputs['mse'] = metrics.MSE.from_model_output(logits, labels)
+        outputs['rmse'] = metrics.RMSE.from_model_output(logits, labels)
+        return outputs
+
+    def run_eval():
+        for logits, labels in eval_dataset:
+            # Jittable metrics
+            outputs = eval_step(logits, labels)
+            # Non-jittable metrics
+            outputs['sequence_match'] = metrics.SequenceMatch.from_model_output(logits, labels)
 
-   metrax API <metrax>