-
Notifications
You must be signed in to change notification settings - Fork 2
Checkpointing on structure factor #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This refactor lets the lightning module compute the various ks metrics, which used to be computed in a callback. The callback now just records to disck / plots the results. The configuration has also been modified to reflect these changes.
|
||
def training_step(self, batch, batch_idx): | ||
"""Runs a prediction step for training, returning the loss.""" | ||
logger.info(f" - Starting training step with batch index {batch_idx}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as discussed - probably better to remove
return output | ||
|
||
if self.metrics_parameters.compute_energies: | ||
logger.info(" * registering reference energies") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as for training step logger.info
self.energy_ks_metric.register_reference_samples(reference_energies.cpu()) | ||
|
||
if self.metrics_parameters.compute_structure_factor: | ||
logger.info(" * registering reference distances") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove,debug or optional (i.e. remove)
reference_distances.cpu() | ||
) | ||
|
||
logger.info(f" Done validation step with batch index {batch_idx}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
I refactored how the Kolmogorov Smirnov (KS) values are computed. This is now done in the model lightning module instead of in a callback. This is necessary in order to log these values and use them for checkpoint monitoring and/or early stopping (the call order of callbacks / model hooks is indertermined, so we can't rely on a callback to log metrics). We now get more telemetry during training on just what the heck is going on (Kudos on the structure factor, btw, it's very nice).
What used to be the sampling callback has been refactored to visualize and write to disk what is already calculated in the model lightning module. This relies on internal properties (ie, "states"), which is not super clean, but it's the PL way.
I also refactor the configuration file in a way that breaks backward compatibility. I didn't not fix every example / sanity check / off-broadway code bits we have. See the following file for an example:
examples/config_files/diffusion/config_diffusion_mlp.yaml
Future possible improvements might be to compute the validation set distances and energies only once (write them in a tmp cache file perhaps) instead of wasting time redoing this at every epoch.