Spatial tile-based training for memory efficiency #37
base: hanaol/setup-ptlightning
Conversation
- Add patch_size parameter to RhoData for extracting random patches
- Use torch.roll for periodic boundary handling during patch extraction
- Train on patches (configurable size), validate on full volume
- Move validation to CPU to avoid OOM on full-volume inference

Reduces GPU memory usage proportionally to (patch_size/full_size)³, enabling training on larger volumes that previously caused OOM.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
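For reference, a minimal sketch of the patch-extraction logic described above, assembled from the diff hunks quoted below. The standalone function and its signature are illustrative; in the PR this logic lives inside RhoData.

```python
import numpy as np
import torch

def extract_random_patch(data, label, patch_size, rng):
    """Crop a random cubic patch with periodic boundary handling.

    Rolling by a random offset and then slicing from the origin is
    equivalent to cropping a patch that wraps around the cell edges.
    """
    D, H, W = data.shape[-3:]
    shift_d = int(rng.integers(0, D))
    shift_h = int(rng.integers(0, H))
    shift_w = int(rng.integers(0, W))
    # torch.roll wraps values pushed past one edge back in at the
    # other, so the origin crop below respects periodicity.
    data = torch.roll(data, shifts=(shift_d, shift_h, shift_w), dims=(-3, -2, -1))
    label = torch.roll(label, shifts=(shift_d, shift_h, shift_w), dims=(-3, -2, -1))
    return (
        data[..., :patch_size, :patch_size, :patch_size],
        label[..., :patch_size, :patch_size, :patch_size],
    )

# Example with a dummy single-channel 64³ volume.
rng = np.random.default_rng(0)
volume = torch.randn(1, 64, 64, 64)
patch, target = extract_random_patch(volume, volume.clone(), 32, rng)
```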
```
rho_type: chgcar or elfcar.
data_size: target size of data.
label_size: target size of label.
pyrho_uf: pyrho upsampling factor
```
These parameters were previously removed from the function; this simply updates the docstring to match the current signature.
```python
shift_w = int(self.rng.integers(0, W))

data = torch.roll(data, shifts=(shift_d, shift_h, shift_w), dims=(-3, -2, -1))
label = torch.roll(label, shifts=(shift_d, shift_h, shift_w), dims=(-3, -2, -1))
```
torch.roll creates a copy before slicing. If we continue to have memory issues, we can further optimize this with an in-place operation, but this should be sufficient for now.
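If that optimization ever becomes necessary, one option (not part of this PR, just a possible follow-up) is to gather only the patch voxels with modular indices, avoiding the full-volume copy that torch.roll makes:

```python
import torch

def crop_wrapped_patch(data, origin, size):
    """Gather a periodic patch without rolling the full volume.

    torch.roll materializes a full-size copy before the crop; here
    the largest intermediate is (..., size, H, W), since each
    index_select copies only the rows it keeps.
    """
    D, H, W = data.shape[-3:]
    d0, h0, w0 = origin
    # Modular indices implement the periodic wrap-around.
    idx_d = torch.arange(d0, d0 + size) % D
    idx_h = torch.arange(h0, h0 + size) % H
    idx_w = torch.arange(w0, w0 + size) % W
    return (
        data.index_select(-3, idx_d)
        .index_select(-2, idx_h)
        .index_select(-1, idx_w)
    )
```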
```python
# Extract tile from origin
data = data[..., :size, :size, :size]
label = label[..., :size, :size, :size]
```
Need to think further about how to handle unit cells that are smaller than the tile size. I'm currently leaning towards padding them to fill out the tile. We could also enforce that all unit cells must be larger than the tile size.
The other thing I need to consider is how the wrapping in the convolutional layers works, and to make sure we account for the boundaries still wrapping correctly.
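On the convolution side, one relevant mechanism (assuming that is what "wrapping" refers to here) is PyTorch's circular padding, which wraps activations across opposite faces:

```python
import torch
import torch.nn as nn

# Circular padding wraps activations across opposite faces, which
# matches periodic boundary conditions when the input spans the
# full unit cell. On a random tile, the two wrapped faces are
# generally not physically adjacent, so the wrap is only exact
# for full-cell inputs.
conv = nn.Conv3d(
    in_channels=1,
    out_channels=8,
    kernel_size=3,
    padding=1,
    padding_mode="circular",
)

x = torch.randn(1, 1, 32, 32, 32)  # one 32³ tile
y = conv(x)
print(y.shape)  # torch.Size([1, 8, 32, 32, 32])
```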
Maybe we internally create a supercell of any unit cell that is smaller than the tile size. We can then transform it back down when returning a density to the user.
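For illustration, building such a supercell could be as simple as repeating the periodic grid along each axis until it covers the tile (a hypothetical helper, not in the PR):

```python
import torch

def tile_to_minimum_size(data, min_size):
    """Repeat a periodic volume along each spatial axis until it is
    at least min_size voxels wide, forming a supercell.

    For a periodic density this is exact: the repeated grid is the
    same physical field sampled over a larger cell.
    """
    reps = [1] * (data.dim() - 3)          # leave batch/channel dims alone
    for extent in data.shape[-3:]:
        reps.append(-(-min_size // extent))  # ceil division
    return data.repeat(*reps)
```

Mapping a prediction back down would then amount to taking the first D×H×W block of the supercell output.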
Sounds reasonable. For validation and inference we're currently using the full volume rather than tiling, so we don't need to worry about transforming it back to the original dimensions quite yet. Good to keep in mind if/when we do need tiling for inference.
Force-pushed from d02539c to d11a763
Summary
To constrain memory usage, this PR introduces a tiling (patch) approach: instead of using the entire material structure, a random section of the structure is selected. Over multiple epochs the model sees multiple, overlapping sections of the crystal structure, similar to how data augmentation selects a random orientation.
For now, validation and inference run on the full-volume data. This makes it easier to build confidence that the model is learning correctly. If/when we run into memory issues with validation and inference, we will need to expand our tiling to stitch tiles back together into a full-grid prediction.
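If that becomes necessary, stitching could start as simply as writing non-overlapping tile predictions back into a full-size buffer. A minimal sketch, assuming each spatial extent divides evenly by the tile size and the model maps a tile to a same-shaped tile (neither is guaranteed by this PR):

```python
import torch

@torch.no_grad()
def predict_full_volume(model, data, tile):
    """Stitch non-overlapping tile predictions into a full grid.

    Assumes each spatial extent is divisible by `tile` and that the
    model maps a tile to a same-shaped tile. Overlap-and-blend
    schemes would be needed if tile seams prove visible.
    """
    out = torch.empty_like(data)
    D, H, W = data.shape[-3:]
    for d in range(0, D, tile):
        for h in range(0, H, tile):
            for w in range(0, W, tile):
                block = data[..., d:d + tile, h:h + tile, w:w + tile]
                out[..., d:d + tile, h:h + tile, w:w + tile] = model(block)
    return out
```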
Changes
Test runs
I ran a few test runs with and without tiling, using tiles of various sizes (16, 32, and 64) on the half-grid data. See the W&B report here.
As expected, the tiled runs take longer to learn, as can be seen in the 16 and 32 tile-size runs. Still, they show that the model is learning.
`patch_size` can be optimized as a hyperparameter based on the compute we use and our material set.
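If we later sweep this formally, a minimal W&B grid sweep might look like the following. The project name is hypothetical, and this assumes the training entrypoint reads `patch_size` from `wandb.config`:

```python
import wandb

# Hypothetical sweep over the tile size used during training.
sweep_config = {
    "method": "grid",
    "parameters": {
        "patch_size": {"values": [16, 32, 64]},
    },
}
sweep_id = wandb.sweep(sweep_config, project="rho-tiling")
```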