Correctly track train / validation losses #485

Merged · 9 commits · Jun 1, 2025
4 changes: 1 addition & 3 deletions bayesflow/approximators/approximator.py
@@ -1,5 +1,3 @@
- from collections.abc import Mapping

import multiprocessing as mp

import keras
@@ -22,7 +20,7 @@ def build_adapter(cls, **kwargs) -> Adapter:
# implemented by each respective architecture
raise NotImplementedError

- def build_from_data(self, data: Mapping[str, any]) -> None:
+ def build_from_data(self, data: dict[str, any]) -> None:
self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
self.built = True

187 changes: 170 additions & 17 deletions bayesflow/approximators/backend_approximators/jax_approximator.py
@@ -5,9 +5,40 @@


class JAXApproximator(keras.Model):
"""
Base class for approximators using JAX and Keras' stateless training interface.

This class enables stateless training and evaluation steps with JAX, supporting
JAX-compatible gradient computation and variable updates through the `StatelessScope`.

Notes
-----
Subclasses must implement:
- compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]
- _batch_size_from_data(self, data: dict[str, any]) -> int
"""

# noinspection PyMethodOverriding
def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
- # implemented by each respective architecture
"""
Compute and return a dictionary of metrics for the current batch.

This method is expected to be implemented by each subclass to compute
task-specific metrics using JAX arrays. It is compatible with stateless
execution and must be differentiable under JAX's `grad` system.

Parameters
----------
*args : tuple
Positional arguments passed to the metric computation function.
**kwargs
Keyword arguments passed to the metric computation function.

Returns
-------
dict of str to jax.Array
Dictionary containing named metric values as JAX arrays.
"""
raise NotImplementedError
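
As a rough sketch (hypothetical subclass, method signature, and attribute names, not taken from this diff), a concrete approximator's compute_metrics only has to return a dictionary that contains a "loss" entry computed from the batch:

import keras

class MyApproximator(JAXApproximator):
    def compute_metrics(self, inference_variables, inference_conditions=None, stage="training"):
        # hypothetical forward pass through a network owned by the subclass
        predictions = self.network(inference_conditions)
        # mean squared error as a stand-in loss; any differentiable scalar works
        loss = keras.ops.mean(keras.ops.square(predictions - inference_variables))
        return {"loss": loss}
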

def stateless_compute_metrics(
@@ -19,17 +50,34 @@
stage: str = "training",
) -> (jax.Array, tuple):
"""
Things we do for jax:
1. Accept trainable variables as the first argument
(can be at any position as indicated by the argnum parameter
in autograd, but needs to be an explicit arg)
2. Accept, potentially modify, and return other state variables
3. Return just the loss tensor as the first value
4. Return all other values in a tuple as the second value

This ensures:
1. The function is stateless
2. The function can be differentiated with jax autograd
Stateless computation of metrics required for JAX autograd.

This method performs a stateless forward pass using the given model
variables and returns both the loss and auxiliary information for
further updates.

Parameters
----------
trainable_variables : Any
Current values of the trainable weights.
non_trainable_variables : Any
Current values of non-trainable variables (e.g., batch norm statistics).
metrics_variables : Any
Current values of metric tracking variables.
data : dict of str to any
Input data dictionary passed to `compute_metrics`.
stage : str, default="training"
Whether the computation is for "training" or "validation".

Returns
-------
loss : jax.Array
Scalar loss tensor for gradient computation.
aux : tuple
Tuple containing:
- metrics (dict of str to jax.Array)
- updated non-trainable variables
- updated metrics variables
"""
state_mapping = []
state_mapping.extend(zip(self.trainable_variables, trainable_variables))
@@ -48,19 +96,55 @@
return metrics["loss"], (metrics, non_trainable_variables, metrics_variables)

def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
"""
Stateless validation step compatible with JAX.

Parameters
----------
state : tuple
Tuple of (trainable_variables, non_trainable_variables, metrics_variables).
data : dict of str to any
Input data for validation.

Returns
-------
metrics : dict of str to jax.Array
Dictionary of computed evaluation metrics.
state : tuple
Updated state tuple after evaluation.
"""
trainable_variables, non_trainable_variables, metrics_variables = state

loss, aux = self.stateless_compute_metrics(
trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="validation"
)
metrics, non_trainable_variables, metrics_variables = aux

- metrics_variables = self._update_loss(loss, metrics_variables)
+ metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))

state = trainable_variables, non_trainable_variables, metrics_variables
return metrics, state

def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
"""
Stateless training step compatible with JAX autograd and stateless optimization.

Computes gradients and applies optimizer updates in a purely functional style.

Parameters
----------
state : tuple
Tuple of (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables).
data : dict of str to any
Input data for training.

Returns
-------
metrics : dict of str to jax.Array
Dictionary of computed training metrics.
state : tuple
Updated state tuple after training.
"""
trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables = state

grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)
@@ -74,23 +158,92 @@
optimizer_variables, grads, trainable_variables
)

- metrics_variables = self._update_loss(loss, metrics_variables)
+ metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))

state = trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables
return metrics, state

def test_step(self, *args, **kwargs):
"""
Alias to `stateless_test_step` for compatibility with `keras.Model`.

Parameters
----------
*args, **kwargs : Any
Passed through to `stateless_test_step`.

Returns
-------
See `stateless_test_step`.
"""
return self.stateless_test_step(*args, **kwargs)

def train_step(self, *args, **kwargs):
"""
Alias to `stateless_train_step` for compatibility with `keras.Model`.

Parameters
----------
*args, **kwargs : Any
Passed through to `stateless_train_step`.

Returns
-------
See `stateless_train_step`.
"""
return self.stateless_train_step(*args, **kwargs)

- def _update_loss(self, loss: jax.Array, metrics_variables: any) -> any:
-     # update the loss progress bar, and possibly metrics variables along with it
+ def _update_metrics(self, loss: jax.Array, metrics_variables: any, sample_weight: any = None) -> any:
"""
Updates metric tracking variables in a stateless JAX-compatible way.

This method updates the loss tracker (and any other Keras metrics)
and returns updated metric variable states for downstream use.

Parameters
----------
loss : jax.Array
Scalar loss used for metric tracking.
metrics_variables : Any
Current metric variable states.
sample_weight : Any, optional
Sample weights to apply during update.

Returns
-------
metrics_variables : Any
Updated metrics variable states.
"""
state_mapping = list(zip(self.metrics_variables, metrics_variables))
with keras.StatelessScope(state_mapping) as scope:
- self._loss_tracker.update_state(loss)
+ self._loss_tracker.update_state(loss, sample_weight=sample_weight)

# JAX is stateless, so we need to return the metrics as state in downstream functions
metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]

return metrics_variables

# noinspection PyMethodOverriding
def _batch_size_from_data(self, data: any) -> int:
"""Obtain the batch size from a batch of data.

To properly weight the metrics across batches of different sizes, the batch size of each batch of data is
required. Since the data structure differs between approximators, each concrete approximator has to implement
this method.

Parameters
----------
data :
The data that are passed to `compute_metrics` as keyword arguments.

Returns
-------
batch_size : int
The batch size of the given data.
"""
raise NotImplementedError(
"Correct calculation of the metrics requires obtaining the batch size from the supplied data "
"for proper weighting of metrics for batches with different sizes. Please implement the "
"_batch_size_from_data method for your approximator. For a given batch of data, it should "
"return the corresponding batch size."
)
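
As a hedged illustration of the requirement above (the data key "inference_variables" is a hypothetical example, not prescribed by this diff), an approximator whose batches carry per-sample arrays with a leading batch axis could implement the method as:

import keras

def _batch_size_from_data(self, data: dict) -> int:
    # the leading axis of any per-sample array is the batch size
    return keras.ops.shape(data["inference_variables"])[0]
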
bayesflow/approximators/backend_approximators/numpy_approximator.py
@@ -12,7 +12,27 @@

def test_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
- return self.compute_metrics(**kwargs)
+ metrics = self.compute_metrics(**kwargs)
+ self._update_metrics(metrics, self._batch_size_from_data(data))
+ return metrics

def train_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
raise NotImplementedError("Numpy backend does not support training.")

def _update_metrics(self, metrics, sample_weight=None):
for name, value in metrics.items():
try:
metric_index = self.metrics_names.index(name)
self.metrics[metric_index].update_state(value, sample_weight=sample_weight)
except ValueError:
self._metrics.append(keras.metrics.Mean(name=name))
self._metrics[-1].update_state(value, sample_weight=sample_weight)

# noinspection PyMethodOverriding
def _batch_size_from_data(self, data: any) -> int:
raise NotImplementedError(
"Correct calculation of the metrics requires obtaining the batch size from the supplied data "
"for proper weighting of metrics for batches with different sizes. Please implement the "
"_batch_size_from_data method for your approximator. For a given batch of data, it should "
"return the corresponding batch size."
)
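
For context on why the batch size is passed as sample_weight: keras.metrics.Mean averages the per-batch losses weighted by the number of samples, so an epoch whose final batch is smaller reports the true per-sample mean loss rather than a plain average of batch means. A minimal self-contained check, assuming only keras is installed:

import keras

# a batch of 100 samples with mean loss 1.0, then a final batch of 10 samples with mean loss 2.0
tracker = keras.metrics.Mean(name="loss")
tracker.update_state(1.0, sample_weight=100)
tracker.update_state(2.0, sample_weight=10)
print(float(tracker.result()))  # ~1.091, the correct per-sample mean; the unweighted average would be 1.5
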