
Commit 996a700

Authored by LarsKue, vpratz, and stefanradev93
Correctly track train / validation losses (#485)
* Correctly track train / validation losses
* Remove MMD from two moons test
* Re-enable metrics in continuous approximator, add trackers
* Re-add custom metrics to two_moons test
* Take batch size into account when aggregating metrics (illustrated below)
* Add docs to backend approximator interfaces
* Add small doc improvements
* Fix typehints in docs

Co-authored-by: Valentin Pratz <[email protected]>
Co-authored-by: stefanradev93 <[email protected]>
1 parent 01aadf1 commit 996a700
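A note on the batch-size weighting mentioned above: when the final batch of an epoch is smaller than the rest, an unweighted mean of per-batch losses is biased toward that batch. A minimal, self-contained sketch with made-up numbers:

```python
import numpy as np

# Per-batch mean losses; the last batch of the epoch is smaller.
batch_losses = np.array([1.0, 1.0, 4.0])
batch_sizes = np.array([128, 128, 16])

unweighted = batch_losses.mean()                          # 2.0: overweights the small batch
weighted = np.average(batch_losses, weights=batch_sizes)  # ~1.18: the true per-sample mean
print(unweighted, weighted)
```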

7 files changed: 444 additions & 32 deletions

bayesflow/approximators/approximator.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -1,5 +1,3 @@
-from collections.abc import Mapping
-
 import multiprocessing as mp
 
 import keras
@@ -22,7 +20,7 @@ def build_adapter(cls, **kwargs) -> Adapter:
         # implemented by each respective architecture
         raise NotImplementedError
 
-    def build_from_data(self, data: Mapping[str, any]) -> None:
+    def build_from_data(self, data: dict[str, any]) -> None:
         self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
         self.built = True
```

bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 170 additions & 17 deletions
```diff
@@ -5,9 +5,40 @@
 
 
 class JAXApproximator(keras.Model):
+    """
+    Base class for approximators using JAX and Keras' stateless training interface.
+
+    This class enables stateless training and evaluation steps with JAX, supporting
+    JAX-compatible gradient computation and variable updates through the `StatelessScope`.
+
+    Notes
+    -----
+    Subclasses must implement:
+    - compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]
+    - _batch_size_from_data(self, data: dict[str, any]) -> int
+    """
+
     # noinspection PyMethodOverriding
     def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
-        # implemented by each respective architecture
+        """
+        Compute and return a dictionary of metrics for the current batch.
+
+        This method is expected to be implemented by each subclass to compute
+        task-specific metrics using JAX arrays. It is compatible with stateless
+        execution and must be differentiable under JAX's `grad` system.
+
+        Parameters
+        ----------
+        *args : tuple
+            Positional arguments passed to the metric computation function.
+        **kwargs
+            Keyword arguments passed to the metric computation function.
+
+        Returns
+        -------
+        dict of str to jax.Array
+            Dictionary containing named metric values as JAX arrays.
+        """
         raise NotImplementedError
 
     def stateless_compute_metrics(
@@ -19,17 +50,34 @@ def stateless_compute_metrics(
         stage: str = "training",
     ) -> (jax.Array, tuple):
         """
-        Things we do for jax:
-        1. Accept trainable variables as the first argument
-           (can be at any position as indicated by the argnum parameter
-           in autograd, but needs to be an explicit arg)
-        2. Accept, potentially modify, and return other state variables
-        3. Return just the loss tensor as the first value
-        4. Return all other values in a tuple as the second value
-
-        This ensures:
-        1. The function is stateless
-        2. The function can be differentiated with jax autograd
+        Stateless computation of metrics required for JAX autograd.
+
+        This method performs a stateless forward pass using the given model
+        variables and returns both the loss and auxiliary information for
+        further updates.
+
+        Parameters
+        ----------
+        trainable_variables : Any
+            Current values of the trainable weights.
+        non_trainable_variables : Any
+            Current values of non-trainable variables (e.g., batch norm statistics).
+        metrics_variables : Any
+            Current values of metric tracking variables.
+        data : dict of str to any
+            Input data dictionary passed to `compute_metrics`.
+        stage : str, default="training"
+            Whether the computation is for "training" or "validation".
+
+        Returns
+        -------
+        loss : jax.Array
+            Scalar loss tensor for gradient computation.
+        aux : tuple
+            Tuple containing:
+            - metrics (dict of str to jax.Array)
+            - updated non-trainable variables
+            - updated metrics variables
         """
         state_mapping = []
         state_mapping.extend(zip(self.trainable_variables, trainable_variables))
@@ -48,19 +96,55 @@ def stateless_compute_metrics(
         return metrics["loss"], (metrics, non_trainable_variables, metrics_variables)
 
     def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
+        """
+        Stateless validation step compatible with JAX.
+
+        Parameters
+        ----------
+        state : tuple
+            Tuple of (trainable_variables, non_trainable_variables, metrics_variables).
+        data : dict of str to any
+            Input data for validation.
+
+        Returns
+        -------
+        metrics : dict of str to jax.Array
+            Dictionary of computed evaluation metrics.
+        state : tuple
+            Updated state tuple after evaluation.
+        """
         trainable_variables, non_trainable_variables, metrics_variables = state
 
         loss, aux = self.stateless_compute_metrics(
             trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="validation"
         )
         metrics, non_trainable_variables, metrics_variables = aux
 
-        metrics_variables = self._update_loss(loss, metrics_variables)
+        metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))
 
         state = trainable_variables, non_trainable_variables, metrics_variables
         return metrics, state
 
     def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
+        """
+        Stateless training step compatible with JAX autograd and stateless optimization.
+
+        Computes gradients and applies optimizer updates in a purely functional style.
+
+        Parameters
+        ----------
+        state : tuple
+            Tuple of (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables).
+        data : dict of str to any
+            Input data for training.
+
+        Returns
+        -------
+        metrics : dict of str to jax.Array
+            Dictionary of computed training metrics.
+        state : tuple
+            Updated state tuple after training.
+        """
         trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables = state
 
         grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)
@@ -74,23 +158,92 @@ def stateless_train_step(self, state: tuple, data: dict[str,
             optimizer_variables, grads, trainable_variables
         )
 
-        metrics_variables = self._update_loss(loss, metrics_variables)
+        metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))
 
         state = trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables
         return metrics, state
 
     def test_step(self, *args, **kwargs):
+        """
+        Alias to `stateless_test_step` for compatibility with `keras.Model`.
+
+        Parameters
+        ----------
+        *args, **kwargs : Any
+            Passed through to `stateless_test_step`.
+
+        Returns
+        -------
+        See `stateless_test_step`.
+        """
         return self.stateless_test_step(*args, **kwargs)
 
     def train_step(self, *args, **kwargs):
+        """
+        Alias to `stateless_train_step` for compatibility with `keras.Model`.
+
+        Parameters
+        ----------
+        *args, **kwargs : Any
+            Passed through to `stateless_train_step`.
+
+        Returns
+        -------
+        See `stateless_train_step`.
+        """
         return self.stateless_train_step(*args, **kwargs)
 
-    def _update_loss(self, loss: jax.Array, metrics_variables: any) -> any:
-        # update the loss progress bar, and possibly metrics variables along with it
+    def _update_metrics(self, loss: jax.Array, metrics_variables: any, sample_weight: any = None) -> any:
+        """
+        Updates metric tracking variables in a stateless JAX-compatible way.
+
+        This method updates the loss tracker (and any other Keras metrics)
+        and returns updated metric variable states for downstream use.
+
+        Parameters
+        ----------
+        loss : jax.Array
+            Scalar loss used for metric tracking.
+        metrics_variables : Any
+            Current metric variable states.
+        sample_weight : Any, optional
+            Sample weights to apply during update.
+
+        Returns
+        -------
+        metrics_variables : Any
+            Updated metrics variable states.
+        """
         state_mapping = list(zip(self.metrics_variables, metrics_variables))
         with keras.StatelessScope(state_mapping) as scope:
-            self._loss_tracker.update_state(loss)
+            self._loss_tracker.update_state(loss, sample_weight=sample_weight)
 
+        # JAX is stateless, so we need to return the metrics as state in downstream functions
         metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]
 
        return metrics_variables
+
+    # noinspection PyMethodOverriding
+    def _batch_size_from_data(self, data: any) -> int:
+        """Obtain the batch size from a batch of data.
+
+        To properly weigh the metrics for batches of different sizes, the batch size of a given batch of data is
+        required. As the data structure differs between approximators, each concrete approximator has to specify
+        this method.
+
+        Parameters
+        ----------
+        data :
+            The data that are passed to `compute_metrics` as keyword arguments.
+
+        Returns
+        -------
+        batch_size : int
+            The batch size of the given data.
+        """
+        raise NotImplementedError(
+            "Correct calculation of the metrics requires obtaining the batch size from the supplied data "
+            "for proper weighting of metrics for batches with different sizes. Please implement the "
+            "_batch_size_from_data method for your approximator. For a given batch of data, it should "
+            "return the corresponding batch size."
+        )
```
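To make the subclass contract in the new docstrings concrete, here is a minimal, hypothetical sketch of the two methods a concrete approximator must provide. `ToyApproximator`, the `x`/`y` data keys, and the single dense layer are invented for illustration and are not part of this commit:

```python
import jax
import keras


class ToyApproximator(JAXApproximator):
    """Hypothetical subclass used only to illustrate the interface."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(1)

    def compute_metrics(self, x, y, stage: str = "training") -> dict[str, jax.Array]:
        # Must return a dict containing at least "loss", and must remain
        # differentiable under jax.value_and_grad.
        loss = keras.ops.mean((self.dense(x) - y) ** 2)
        return {"loss": loss}

    def _batch_size_from_data(self, data: dict[str, any]) -> int:
        # In this toy setup, the leading axis of "x" is the batch dimension.
        return keras.ops.shape(data["x"])[0]
```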

bayesflow/approximators/backend_approximators/numpy_approximator.py

Lines changed: 21 additions & 1 deletion
```diff
@@ -12,7 +12,27 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, np.ndarray]:
 
     def test_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
         kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
-        return self.compute_metrics(**kwargs)
+        metrics = self.compute_metrics(**kwargs)
+        self._update_metrics(metrics, self._batch_size_from_data(data))
+        return metrics
 
     def train_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
         raise NotImplementedError("Numpy backend does not support training.")
+
+    def _update_metrics(self, metrics, sample_weight=None):
+        for name, value in metrics.items():
+            try:
+                metric_index = self.metrics_names.index(name)
+                self.metrics[metric_index].update_state(value, sample_weight=sample_weight)
+            except ValueError:
+                self._metrics.append(keras.metrics.Mean(name=name))
+                self._metrics[-1].update_state(value, sample_weight=sample_weight)
+
+    # noinspection PyMethodOverriding
+    def _batch_size_from_data(self, data: any) -> int:
+        raise NotImplementedError(
+            "Correct calculation of the metrics requires obtaining the batch size from the supplied data "
+            "for proper weighting of metrics for batches with different sizes. Please implement the "
+            "_batch_size_from_data method for your approximator. For a given batch of data, it should "
+            "return the corresponding batch size."
+        )
```
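Why `_update_metrics` passes the batch size as `sample_weight`: `keras.metrics.Mean` accumulates a weighted total and a weighted count, so weighting each per-batch mean loss by its batch size recovers the per-sample mean across the epoch. A standalone check:

```python
import keras

tracker = keras.metrics.Mean(name="loss")
tracker.update_state(1.0, sample_weight=128)  # batch of 128 with mean loss 1.0
tracker.update_state(4.0, sample_weight=16)   # smaller final batch with mean loss 4.0
print(float(tracker.result()))  # (1.0 * 128 + 4.0 * 16) / 144 = 1.33...
```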
