-
Notifications
You must be signed in to change notification settings - Fork 69
Aggregate logs in evaluate
#483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAttention: Patch coverage is
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Implements cumulative logging and averaging in the evaluate
method for all backend approximators to fix metrics aggregation issues.
- Introduce
_aggregate_logs
and_mean_logs
helpers to accumulate and normalize batch metrics. - Override
evaluate
in Torch, TensorFlow, and JAX approximators using their respective*EpochIterator
and callback flows. - Wire up Keras
CallbackList
and ensure per-batch callbacks with aggregated logs.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
bayesflow/approximators/backend_approximators/torch_approximator.py | Added _aggregate_logs , _mean_logs , and updated evaluate with TorchEpochIterator and callbacks. |
bayesflow/approximators/backend_approximators/tensorflow_approximator.py | Added _aggregate_logs , _mean_logs , and updated evaluate with TFEpochIterator and callbacks. |
bayesflow/approximators/backend_approximators/jax_approximator.py | Added _aggregate_logs , _mean_logs , and updated evaluate with JAXEpochIterator, state sync, and callbacks. |
Comments suppressed due to low confidence (2)
bayesflow/approximators/backend_approximators/jax_approximator.py:87
- [nitpick] The loop variable 'iterator' shadows the epoch_iterator and may be confusing; consider renaming it to 'batch_data' or similar to clarify its purpose.
for step, iterator in epoch_iterator:
bayesflow/approximators/backend_approximators/tensorflow_approximator.py:31
- Add unit tests for this new evaluate implementation to verify that log aggregation and averaging behave as expected across multiple batches.
def evaluate(
bayesflow/approximators/backend_approximators/jax_approximator.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Adds an aggregate
option to the evaluate
method in all backend approximators, allowing batch-wise metrics to be summed and averaged rather than overwritten each step.
- Introduce
_aggregate_fn
and_reduce_fn
inevaluate
to accumulate and average metrics. - Add
aggregate
andreturn_dict
parameters to control output format. - Ensure consistency in callback invocation and test function setup across Torch, TensorFlow, and JAX backends.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
bayesflow/approximators/backend_approximators/torch_approximator.py | Added evaluate override with optional aggregation logic |
bayesflow/approximators/backend_approximators/tensorflow_approximator.py | Added evaluate override with optional aggregation logic |
bayesflow/approximators/backend_approximators/jax_approximator.py | Added evaluate override with optional aggregation logic and JAX state handling |
Comments suppressed due to low confidence (4)
bayesflow/approximators/backend_approximators/torch_approximator.py:26
- Add unit tests covering both
aggregate=True
andaggregate=False
paths to verify that metrics are correctly summed and averaged.
aggregate=False,
bayesflow/approximators/backend_approximators/torch_approximator.py:29
- Add a call to
self._assert_compile_called("evaluate")
at the start ofevaluate
to ensure the model has been compiled before evaluation.
# TODO: respect compiled trainable state
bayesflow/approximators/backend_approximators/jax_approximator.py:26
- The default
aggregate=True
in JAX differs fromaggregate=False
in the Torch and TensorFlow backends. Align the default value for consistency across backends.
aggregate=True,
bayesflow/approximators/backend_approximators/torch_approximator.py:16
- This new
evaluate
method lacks docstrings for theaggregate
andreturn_dict
parameters; please add descriptions and expected behavior.
def evaluate(
… aggregate-logs-in-evaluate # Conflicts: # bayesflow/approximators/backend_approximators/jax_approximator.py # bayesflow/approximators/backend_approximators/tensorflow_approximator.py # bayesflow/approximators/backend_approximators/torch_approximator.py
Superceded by #485 |
Fixes #481