Skip to content

Conversation

@rousseab
Copy link
Collaborator

I refactored how the Kolmogorov Smirnov (KS) values are computed. This is now done in the model lightning module instead of in a callback. This is necessary in order to log these values and use them for checkpoint monitoring and/or early stopping (the call order of callbacks / model hooks is indertermined, so we can't rely on a callback to log metrics). We now get more telemetry during training on just what the heck is going on (Kudos on the structure factor, btw, it's very nice).

Screenshot 2024-09-23 at 8 03 45 AM

What used to be the sampling callback has been refactored to visualize and write to disk what is already calculated in the model lightning module. This relies on internal properties (ie, "states"), which is not super clean, but it's the PL way.

I also refactor the configuration file in a way that breaks backward compatibility. I didn't not fix every example / sanity check / off-broadway code bits we have. See the following file for an example:
examples/config_files/diffusion/config_diffusion_mlp.yaml

Future possible improvements might be to compute the validation set distances and energies only once (write them in a tmp cache file perhaps) instead of wasting time redoing this at every epoch.

@rousseab rousseab mentioned this pull request Sep 24, 2024

def training_step(self, batch, batch_idx):
"""Runs a prediction step for training, returning the loss."""
logger.info(f" - Starting training step with batch index {batch_idx}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as discussed - probably better to remove

return output

if self.metrics_parameters.compute_energies:
logger.info(" * registering reference energies")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as for training step logger.info

self.energy_ks_metric.register_reference_samples(reference_energies.cpu())

if self.metrics_parameters.compute_structure_factor:
logger.info(" * registering reference distances")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove,debug or optional (i.e. remove)

reference_distances.cpu()
)

logger.info(f" Done validation step with batch index {batch_idx}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

@rousseab rousseab merged commit e9f18cc into main Sep 26, 2024
1 check passed
@rousseab rousseab deleted the checkpointing_on_structure_factor branch September 26, 2024 12:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants