Correctly track train / validation losses #485


Merged — 9 commits merged into dev from track-losses on Jun 1, 2025

Conversation

@LarsKue (Contributor) commented May 22, 2025

This PR:

Future TODO:

  • We have to re-enable tracking of custom metrics by keeping a keras.metrics.Mean tracker object on the approximator for each custom metric the user passes. We also need to call update_state on each of those trackers in the compute_metrics or train/test_step method (see the sketch below).
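
A rough sketch of that idea (hypothetical names, not the current bayesflow API): one Mean tracker per user-supplied metric, updated on every batch.

import keras

class TrackedApproximator(keras.Model):
    """Illustrative only, not the bayesflow implementation."""

    def __init__(self, custom_metric_fns, **kwargs):
        super().__init__(**kwargs)
        self.custom_metric_fns = custom_metric_fns
        # One running-mean tracker per user-supplied metric.
        self.custom_trackers = {
            name: keras.metrics.Mean(name=name) for name in custom_metric_fns
        }

    def update_custom_trackers(self, targets, predictions):
        # Would be called from compute_metrics (or train_step / test_step) for every batch.
        for name, fn in self.custom_metric_fns.items():
            self.custom_trackers[name].update_state(fn(targets, predictions))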

@vpratz Do you think you have capacity to take care of the TODO?

@LarsKue self-assigned this May 22, 2025
@LarsKue added the "fix" label (Pull request that fixes a bug) May 22, 2025
@Copilot (Copilot AI) left a comment


Pull Request Overview

This PR refines loss tracking by simplifying metric computation and ensuring that the validation loss is correctly tracked.

  • Updated continuous approximator to return only the computed loss.
  • Modified torch and tensorflow approximators to update the loss tracker during the validation step.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

  • bayesflow/approximators/continuous_approximator.py — Simplified compute_metrics by removing nested metrics.
  • bayesflow/approximators/backend_approximators/torch_approximator.py — Updated test_step to update the loss tracker.
  • bayesflow/approximators/backend_approximators/tensorflow_approximator.py — Updated test_step to update the loss tracker.
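
For intuition, a minimal sketch of what updating a loss tracker in test_step can look like in a Keras 3 model (illustrative class, names, and loss; not the actual bayesflow code):

import keras

class ExampleApproximator(keras.Model):
    """Illustrative only."""

    def __init__(self, network, **kwargs):
        super().__init__(**kwargs)
        self.network = network
        # Running mean over validation batches; Keras resets it before each evaluation.
        self.loss_tracker = keras.metrics.Mean(name="loss")

    def test_step(self, data):
        x, y = data
        y_pred = self.network(x, training=False)
        loss = keras.ops.mean(keras.losses.mean_squared_error(y, y_pred))
        # Without this update, the reported val_loss would reflect only the last batch.
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}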

codecov bot commented May 22, 2025

@elseml mentioned this pull request May 23, 2025
@vpratz (Collaborator) commented May 23, 2025

Do I read it correctly that this PR does not modify the shuffling and therefore does not fix this aspect of #481?

Regarding the TODO: I'm not familiar with this part of the code base yet, so I cannot really judge how complicated it is and when I will find the time, but I can try to take a closer look at it in the coming weeks.

@stefanradev93 (Contributor) commented

Shall we still merge this PR, as it provides a pretty important fix, and open another one pertaining to general metrics and aggregation?

@vpratz (Collaborator) commented May 24, 2025

I'm still a bit confused by the description. Could you quickly re-explain which problem this PR addresses? I'm currently unable to judge whether the fix is worth a temporary regression until we get to the TODO.

@stefanradev93 (Contributor) commented

It fixes the validation loss being computed and displayed from only the last validation batch. Additionally, it removes the duplicate printing of loss / inference_net/loss and val_loss / val_inference_net/loss.

@vpratz (Collaborator) commented May 24, 2025

Ahh, thanks a lot! I will do a proper review in the next few days, then. Or, if you have already thoroughly tested the changes, you could also merge right away and open an issue regarding the tracking of custom metrics.

@vpratz (Collaborator) commented May 27, 2025

I now understand better what is going on, and the changes look good to me. I will try to add the trackers for the other metrics today, and let you know if I encounter any difficulties...

@vpratz (Collaborator) commented May 27, 2025

Quick question @LarsKue @stefanradev93: The individual losses were removed in this PR. I would add them as metrics when more than one loss is present (not when there is only one, to avoid a useless duplicate), so that they are tracked individually. This will also display them, which I think is desirable. Do you agree, or would you rather not display that information?
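
A possible shape of that condition (illustrative only; names are not the bayesflow API):

import keras

def build_loss_trackers(loss_names):
    # Always track the total loss; only track individual losses when there is
    # more than one, since a single entry would just duplicate the total.
    trackers = {"loss": keras.metrics.Mean(name="loss")}
    if len(loss_names) > 1:
        for name in loss_names:
            trackers[name] = keras.metrics.Mean(name=name)
    return trackers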

@vpratz (Collaborator) commented May 27, 2025

Short update: I have succeeded in tracking the metrics, but the serialization for custom inference_metrics and summary_metrics seems to be broken (at least in TensorFlow) and might need some restructuring to work nicely. It might make sense to move this to a separate PR, though; I'm currently checking whether the issues are separate or interdependent...

@stefanradev93 (Contributor) commented

> Quick question @LarsKue @stefanradev93: The individual losses were removed in this PR. I would add them as metrics when more than one loss is present (not when there is only one, to avoid a useless duplicate), so that they are tracked individually. This will also display them, which I think is desirable. Do you agree, or would you rather not display that information?

I think this makes sense!

@vpratz removed their request for review May 27, 2025 15:20
@vpratz (Collaborator) commented May 27, 2025

@LarsKue @stefanradev93 I think the changes are ready for review; the issues regarding custom metrics are (as far as I can tell) not related to the changes in this PR.

@vpratz (Collaborator) commented May 29, 2025

I was able to resolve this properly. The reason for the order dependency was taking an unweighted mean of means: in the aggregated metrics, average values obtained from batches of size 32 had the same weight as values obtained from the final batch of size 2.
By setting the sample_weight argument of Metric.update_state to the batch size, we obtain the correct averages. This requires calculating the batch size from the data at hand, which is approximator-specific, so I outsourced this calculation into a private method that each approximator has to override. If you have different design ideas for this, please let me know.
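
To illustrate the weighting (a standalone sketch using the batch sizes from the reproduction below, not bayesflow code):

import keras
import numpy as np

# Per-batch mean losses from batches of sizes 32, 32 and 2 (66 samples, batch_size=32).
batch_losses = [1.0, 1.0, 4.0]
batch_sizes = [32, 32, 2]

# Unweighted mean of means: (1 + 1 + 4) / 3 = 2.0, which over-weights the tiny last batch.
print(np.mean(batch_losses))

# Weighting each update by the batch size recovers the per-sample mean:
# (32*1 + 32*1 + 2*4) / 66 ≈ 1.09
tracker = keras.metrics.Mean(name="loss")
for loss, n in zip(batch_losses, batch_sizes):
    tracker.update_state(loss, sample_weight=n)
print(float(tracker.result()))

The per-approximator helper could then look roughly like this (hypothetical data key, for illustration only):

def _batch_size_from_data(self, data: dict) -> int:
    # Take the leading axis of one of the data tensors, e.g. the inference variables.
    return keras.ops.shape(data["inference_variables"])[0]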

This fixes #481 and supersedes that aspect of #482.
@LarsKue @stefanradev93 Could you please take another look?

Note: the tqdm progress bar will report wrong values for the training loss, as it aggregates the per-batch values naively; the value stored in history will be correct, though.

Code to reproduce
import bayesflow as bf
import keras
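
# A learning rate of 0 keeps the network weights fixed, so the reported losses
# should stay constant across epochs.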

workflow = bf.BasicWorkflow(
    inference_network=bf.networks.CouplingFlow(subnet_kwargs={"dropout": 0.0}),
    inference_variables=["parameters"],
    inference_conditions=["observables"],
    simulator=bf.simulators.GaussianMixture(),
    initial_learning_rate=0.0,
    standardize=[]
)
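
# 66 samples with batch_size=32 produce batches of sizes 32, 32, and 2; the small
# final batch exposes the unweighted mean-of-means bug.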

training_data = workflow.simulate(66)
validation_data = workflow.simulate(66)
history = workflow.fit_offline(
    data=training_data,
    epochs=3,
    batch_size=32,
    validation_data=validation_data,
)
print(history.history)

@stefanradev93 (Contributor) left a comment

Hi Valentin and Lars, I buffed up the docs for the new changes. I also decided to repeat _batch_size_from_data in the backend-specific classes to avoid confusion, since the method is already explicitly needed in these classes. Other than that, I think this PR is ready to merge.

@stefanradev93 merged commit 996a700 into dev on Jun 1, 2025
8 of 9 checks passed
@stefanradev93 deleted the track-losses branch June 1, 2025 14:32