Correctly track train / validation losses #485
Conversation
Pull Request Overview
This PR refines loss tracking by simplifying metric computation and ensuring that the validation loss is correctly tracked.
- Updated continuous approximator to return only the computed loss.
- Modified torch and tensorflow approximators to update the loss tracker during the validation step.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
File | Description
---|---
`bayesflow/approximators/continuous_approximator.py` | Simplified `compute_metrics` by removing nested metrics.
`bayesflow/approximators/backend_approximators/torch_approximator.py` | Updated `test_step` to update the loss tracker.
`bayesflow/approximators/backend_approximators/tensorflow_approximator.py` | Updated `test_step` to update the loss tracker.
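To illustrate the pattern described above (a minimal, hypothetical Keras 3 sketch, not BayesFlow's actual code), a `test_step` can feed its per-batch loss into a `keras.metrics.Mean` tracker so that `val_loss` is averaged over all validation batches instead of reflecting only the last one:

```python
# Hypothetical sketch, assuming a keras.Model subclass; names like
# TrackedModel are illustrative and not part of BayesFlow.
import keras
import numpy as np


class TrackedModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(1)
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Metrics listed here are reset by Keras at epoch boundaries,
        # and their running results are written to the logs.
        return [self.loss_tracker]

    def test_step(self, data):
        x, y = data
        loss = keras.ops.mean((self.dense(x) - y) ** 2)
        # Weighting by batch size avoids the "mean of means" pitfall
        # discussed further down in this thread.
        self.loss_tracker.update_state(loss, sample_weight=keras.ops.shape(x)[0])
        return {"loss": self.loss_tracker.result()}


model = TrackedModel()
model.compile()
x = np.random.normal(size=(66, 4)).astype("float32")
y = np.random.normal(size=(66, 1)).astype("float32")
print(model.evaluate(x, y, batch_size=32, return_dict=True))
```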
Do I read it correctly that this PR does not modify the shuffling and therefore does not fix that aspect of #481? Regarding the TODO: I'm not familiar with this part of the code base yet, so I cannot really judge how complicated it is or when I will find the time, but I can try to take a closer look at it in the coming weeks.
Shall we still merge this PR, as it provides a pretty important fix, and open another one pertaining to general metrics and aggregation?
I'm still a bit confused by the description. Could you quickly re-explain which problem this PR addresses? I'm currently unable to judge whether the fix is worth a temporary regression until we get to the TODO.
It fixes a bug where the validation loss was calculated and displayed based only on the last validation batch. Additionally, it removes the duplicate printing of `loss` / `inference_net/loss` and `val_loss` / `val_inference_net/loss`.
Ahh, thanks a lot! I will do a proper review in the next few days, then. Or, if you have already thoroughly tested the changes, you might also merge right away and open an issue regarding the tracking of custom metrics.
I now understand better what is going on, and the changes look good to me. I will try to add the trackers for the other metrics today, and let you know if I encounter any difficulties...
Quick question @LarsKue @stefanradev93: The individual losses were removed in the PR. I would add them back as metrics when more than one loss is present (not when there is only one, to avoid a useless duplicate), so that they are individually tracked. This will also display them, which I think is desirable. Do you agree, or would you rather not display that information?
Short update: I have succeeded in tracking the metrics, but the serialization for custom metrics is still causing some issues.
I think this makes sense!
@LarsKue @stefanradev93 I think the changes are ready to review; the issues regarding custom metrics are (as far as I can tell) not related to the changes in this PR.
I was able to resolve this properly. The reason for the order-dependency was taking the unweighted mean of per-batch means: in the aggregated metrics, the mean from a batch of size 32 had the same weight as the mean from the final batch of size 2. This fixes #481 and supersedes that aspect of #482.

Note: code to reproduce:

```python
import bayesflow as bf
import keras

workflow = bf.BasicWorkflow(
    inference_network=bf.networks.CouplingFlow(subnet_kwargs={"dropout": 0.0}),
    inference_variables=["parameters"],
    inference_conditions=["observables"],
    simulator=bf.simulators.GaussianMixture(),
    initial_learning_rate=0.0,
    standardize=[],
)

training_data = workflow.simulate(66)
validation_data = workflow.simulate(66)

history = workflow.fit_offline(
    data=training_data,
    epochs=3,
    batch_size=32,
    validation_data=validation_data,
)

print(history.history)
```
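To make the "unweighted mean of means" issue concrete, here is a small standalone illustration (mine, not part of the PR): with 66 samples and `batch_size=32`, the batches have sizes 32, 32 and 2, and an unweighted average of per-batch means generally differs from the true per-sample mean:

```python
# Standalone illustration of the aggregation bug: averaging per-batch
# means without weighting by batch size.
import numpy as np

rng = np.random.default_rng(0)
per_sample_loss = rng.normal(loc=1.0, scale=0.5, size=66)

# 66 samples with batch_size=32 -> batches of sizes 32, 32 and 2.
batches = np.split(per_sample_loss, [32, 64])
batch_means = [batch.mean() for batch in batches]

# Treats the 2-sample batch as if it were a full 32-sample batch.
unweighted = np.mean(batch_means)
# Matches the true per-sample mean exactly.
weighted = np.average(batch_means, weights=[len(batch) for batch in batches])

print(unweighted, weighted, per_sample_loss.mean())
```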
Hi Valentin and Lars, I buffed up the docs for the new changes. I also decided to repeat `_batch_size_from_data` in the backend-specific classes to avoid confusion, since the method is already explicitly needed in these classes. Other than that, I think this PR is ready to merge.
This PR:
- fixes the validation loss reported by `evaluate` / `test_step` (#483)
- removes the duplicate `loss/..._loss` entries and other nested metrics

Future TODO: create a `keras.metrics.Mean` tracker object on the approximator for each custom metric the user passes. We also need to call `update_state` on each of those trackers in the `compute_metrics` or `train/test_step` method (see the sketch below).

@vpratz Do you think you have capacity to take care of the TODO?
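For what it's worth, here is a minimal sketch of how the TODO could look, assuming custom metrics are passed as plain callables; the class and attribute names below are illustrative guesses, not the actual BayesFlow API:

```python
# Hedged sketch of the TODO: one keras.metrics.Mean tracker per custom
# metric, updated in test_step. All names here are hypothetical.
import keras
import numpy as np


def mean_absolute_error(y_true, y_pred):
    return keras.ops.mean(keras.ops.abs(y_true - y_pred))


class MetricTrackingModel(keras.Model):
    def __init__(self, custom_metrics=(), **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(1)
        self.custom_metric_fns = list(custom_metrics)
        # One Mean tracker per custom metric, so per-batch values are
        # averaged over the epoch rather than overwritten by the last batch.
        self.trackers = {
            fn.__name__: keras.metrics.Mean(name=fn.__name__)
            for fn in self.custom_metric_fns
        }

    @property
    def metrics(self):
        # Registering the trackers lets Keras reset them between epochs.
        return list(self.trackers.values())

    def test_step(self, data):
        x, y = data
        y_pred = self.dense(x)
        batch_size = keras.ops.shape(x)[0]
        logs = {}
        for fn in self.custom_metric_fns:
            tracker = self.trackers[fn.__name__]
            # Weight by batch size to keep the aggregate order-independent.
            tracker.update_state(fn(y, y_pred), sample_weight=batch_size)
            logs[fn.__name__] = tracker.result()
        return logs


model = MetricTrackingModel(custom_metrics=[mean_absolute_error])
model.compile()
x = np.random.normal(size=(66, 4)).astype("float32")
y = np.random.normal(size=(66, 1)).astype("float32")
print(model.evaluate(x, y, batch_size=32, return_dict=True))
```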