
Aggregate logs in evaluate #483


Closed
wants to merge 6 commits into from

Conversation

LarsKue (Contributor) commented May 21, 2025

Fixes #481

@LarsKue requested a review from Copilot May 21, 2025 19:56
@LarsKue self-assigned this May 21, 2025
@LarsKue added the "fix" label (Pull request that fixes a bug) May 21, 2025
@LarsKue requested a review from stefanradev93 May 21, 2025 19:57

codecov bot commented May 21, 2025

Copilot AI (Contributor) left a comment

Pull Request Overview

Implements cumulative logging and averaging in the evaluate method for all backend approximators to fix metrics aggregation issues.

  • Introduce _aggregate_logs and _mean_logs helpers to accumulate and normalize batch metrics.
  • Override evaluate in Torch, TensorFlow, and JAX approximators using their respective *EpochIterator and callback flows.
  • Wire up Keras CallbackList and ensure per-batch callbacks with aggregated logs.
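
The helper bodies themselves are not quoted in this review. A minimal sketch of what dict-based accumulation and averaging could look like, assuming each batch's logs are a flat dict mapping metric names to scalar values (the real helpers live on the backend approximator classes, and their exact signatures are assumptions here):

```python
# Minimal sketch, not the PR's actual code: accumulate per-batch logs and
# average them once at the end of the evaluation run.
def _aggregate_logs(aggregated, logs):
    """Accumulate per-batch metric values into a running sum."""
    if aggregated is None:
        return dict(logs)
    return {key: aggregated[key] + value for key, value in logs.items()}


def _mean_logs(aggregated, num_batches):
    """Average the accumulated metrics over the number of evaluated batches."""
    return {key: value / num_batches for key, value in aggregated.items()}
```

With helpers like these, each batch's logs are folded into a running total inside the evaluation loop and normalized once after the last batch, instead of overwriting the logs at every step.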

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File descriptions:
  • bayesflow/approximators/backend_approximators/torch_approximator.py: Added _aggregate_logs, _mean_logs, and updated evaluate with TorchEpochIterator and callbacks.
  • bayesflow/approximators/backend_approximators/tensorflow_approximator.py: Added _aggregate_logs, _mean_logs, and updated evaluate with TFEpochIterator and callbacks.
  • bayesflow/approximators/backend_approximators/jax_approximator.py: Added _aggregate_logs, _mean_logs, and updated evaluate with JAXEpochIterator, state sync, and callbacks.
Comments suppressed due to low confidence (2)

bayesflow/approximators/backend_approximators/jax_approximator.py:87

  • [nitpick] The loop variable 'iterator' shadows the epoch_iterator and may be confusing; consider renaming it to 'batch_data' or similar to clarify its purpose.
for step, iterator in epoch_iterator:

bayesflow/approximators/backend_approximators/tensorflow_approximator.py:31

  • Add unit tests for this new evaluate implementation to verify that log aggregation and averaging behave as expected across multiple batches.
def evaluate(
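
A test along the lines this comment asks for might look like the sketch below. It exercises the sum-then-average semantics described in the overview directly, rather than calling the real TensorFlowApproximator.evaluate, whose exact signature is not quoted in this thread.

```python
# Hedged pytest-style sketch: verify that logs from multiple batches are
# summed and then averaged, instead of the last batch overwriting the rest.
def test_logs_are_summed_and_averaged_across_batches():
    batch_logs = [{"loss": 1.0}, {"loss": 2.0}, {"loss": 3.0}]

    aggregated = None
    for logs in batch_logs:
        if aggregated is None:
            aggregated = dict(logs)
        else:
            aggregated = {key: aggregated[key] + value for key, value in logs.items()}

    mean_logs = {key: value / len(batch_logs) for key, value in aggregated.items()}
    assert mean_logs == {"loss": 2.0}
```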

@LarsKue requested a review from Copilot May 21, 2025 20:11
Copilot AI (Contributor) left a comment

Pull Request Overview

Adds an aggregate option to the evaluate method in all backend approximators, allowing batch-wise metrics to be summed and averaged rather than overwritten each step.

  • Introduce _aggregate_fn and _reduce_fn in evaluate to accumulate and average metrics.
  • Add aggregate and return_dict parameters to control output format.
  • Ensure consistency in callback invocation and test function setup across Torch, TensorFlow, and JAX backends.
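
A rough sketch of the behaviour described above, written as a free function with hypothetical names; the real overrides are backend-specific methods that also handle epoch iterators, compiled test functions, and callback wiring, and the defaults shown here are placeholders:

```python
# Sketch only: optional sum-then-average aggregation of per-batch logs,
# controlled by `aggregate`, with `return_dict` selecting the output format.
def evaluate_with_aggregation(test_step, dataset, aggregate=True, return_dict=True):
    """Run `test_step` over `dataset`, optionally summing and averaging logs."""

    def _aggregate_fn(aggregated, batch_logs):
        # Sum each metric across batches instead of overwriting it per step.
        if aggregated is None:
            return dict(batch_logs)
        return {key: aggregated[key] + value for key, value in batch_logs.items()}

    def _reduce_fn(aggregated, num_batches):
        # Average the summed metrics over the number of evaluated batches.
        return {key: value / num_batches for key, value in aggregated.items()}

    logs, num_batches = None, 0
    for batch in dataset:
        batch_logs = test_step(batch)
        logs = _aggregate_fn(logs, batch_logs) if aggregate else dict(batch_logs)
        num_batches += 1

    if aggregate and num_batches > 0:
        logs = _reduce_fn(logs, num_batches)
    return logs if return_dict else list((logs or {}).values())
```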

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File descriptions:
  • bayesflow/approximators/backend_approximators/torch_approximator.py: Added evaluate override with optional aggregation logic.
  • bayesflow/approximators/backend_approximators/tensorflow_approximator.py: Added evaluate override with optional aggregation logic.
  • bayesflow/approximators/backend_approximators/jax_approximator.py: Added evaluate override with optional aggregation logic and JAX state handling.
Comments suppressed due to low confidence (4)

bayesflow/approximators/backend_approximators/torch_approximator.py:26

  • Add unit tests covering both aggregate=True and aggregate=False paths to verify that metrics are correctly summed and averaged.
        aggregate=False,

bayesflow/approximators/backend_approximators/torch_approximator.py:29

  • Add a call to self._assert_compile_called("evaluate") at the start of evaluate to ensure the model has been compiled before evaluation.
# TODO: respect compiled trainable state

bayesflow/approximators/backend_approximators/jax_approximator.py:26

  • The default aggregate=True in JAX differs from aggregate=False in the Torch and TensorFlow backends. Align the default value for consistency across backends.
        aggregate=True,

bayesflow/approximators/backend_approximators/torch_approximator.py:16

  • This new evaluate method lacks docstrings for the aggregate and return_dict parameters; please add descriptions and expected behavior.
    def evaluate(
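
A hedged sketch of the docstring this comment asks for is given below. The parameter names come from the review text above; the wording, defaults, and return formats are assumptions. The default for aggregate follows the Torch/TensorFlow backends; as noted above, the JAX backend used a different default in this revision.

```python
# Illustrative docstring only; the class is a stand-in, not bayesflow's API.
class _ApproximatorSketch:
    def evaluate(self, *args, aggregate: bool = False, return_dict: bool = True, **kwargs):
        """Evaluate the approximator on the given data.

        Parameters
        ----------
        aggregate : bool, optional
            If True, per-batch metrics are summed over the evaluation run and
            averaged by the number of batches before being returned. If False,
            each batch overwrites the previous logs, so only the last batch's
            metrics are returned.
        return_dict : bool, optional
            If True, return the metrics as a dict mapping metric names to
            scalar values; otherwise return them in Keras' usual positional
            format.
        """
```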

LarsKue added 4 commits May 21, 2025 16:14
… aggregate-logs-in-evaluate

# Conflicts:
#	bayesflow/approximators/backend_approximators/jax_approximator.py
#	bayesflow/approximators/backend_approximators/tensorflow_approximator.py
#	bayesflow/approximators/backend_approximators/torch_approximator.py
LarsKue (Contributor, Author) commented May 22, 2025

Superseded by #485

@LarsKue closed this May 22, 2025
Labels: fix (Pull request that fixes a bug)