
Conversation

@forklady42 (Collaborator)

Summary

To constrain memory usage, this PR introduces a tiling (patch) approach: instead of using the entire material structure, a random section of the structure is selected. Over multiple epochs, the model sees multiple overlapping sections of the crystal structure, similar to how data augmentation selects a random orientation.

For now, validation and inference run on the full-volume data. This makes it easier to build confidence that the model is learning correctly. If/when we run into memory issues with validation and inference, we will need to expand our tiling to stitch tiles back together into a full-grid prediction.

Changes

  • Add patch_size parameter to RhoData for extracting random patches
  • Use torch.roll for periodic boundary handling during patch extraction
  • Train on patches (configurable size), validate on full volume
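The roll-then-slice extraction described in the bullets above can be sketched roughly as follows. The function name and signature are illustrative, not the actual `RhoData` API; only the `patch_size` parameter and the use of `torch.roll` come from the PR.

```python
import torch

def random_periodic_patch(data: torch.Tensor, patch_size: int,
                          generator=None) -> torch.Tensor:
    """Extract a random cubic patch with periodic boundary handling.

    Rolls the volume by a random offset along each spatial axis, then
    slices a cube from the origin, so a patch can wrap across the
    periodic cell boundary. (Illustrative sketch, not the PR's code.)
    """
    D, H, W = data.shape[-3:]
    shifts = tuple(int(torch.randint(0, n, (1,), generator=generator))
                   for n in (D, H, W))
    rolled = torch.roll(data, shifts=shifts, dims=(-3, -2, -1))
    return rolled[..., :patch_size, :patch_size, :patch_size]
```

Because the roll is periodic, every voxel of the unit cell is equally likely to appear in a patch, which is what makes this behave like an augmentation over cell translations.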

Test runs

I ran a few test runs with and without tiling, using tiles of various sizes (16, 32, and 64), on the half-grid data. See the W&B report here.

As expected, the runs with tiles take longer to learn, as can be seen in the 16 and 32 tile-size runs, but they do still learn. patch_size can be tuned as a hyperparameter based on the compute we use and our material set.

Hananeh Oliaei and others added 9 commits November 13, 2025 15:24
- Add patch_size parameter to RhoData for extracting random patches
- Use torch.roll for periodic boundary handling during patch extraction
- Train on patches (configurable size), validate on full volume
- Move validation to CPU to avoid OOM on full-volume inference

Reduces GPU memory usage proportionally to (patch_size/full_size)³,
enabling training on larger volumes that previously caused OOM.
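The cubic scaling claimed in the commit message is easy to sanity-check with a couple of lines. The `full_size = 128` below is an illustrative number, not a measurement from this PR:

```python
# Rough activation-memory ratio of a cubic patch vs. the full volume,
# per the (patch_size / full_size)**3 scaling in the commit message.
# full_size = 128 is an illustrative assumption, not measured here.
full_size = 128
for patch_size in (16, 32, 64):
    ratio = (patch_size / full_size) ** 3
    print(f"patch {patch_size}: ~{ratio:.4f} of full-volume memory")
```

So even the largest tested tile (64 on a hypothetical 128 grid) would touch only about an eighth of the full-volume memory.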

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
rho_type: chgcar or elfcar.
data_size: target size of data.
label_size: target size of label.
pyrho_uf: pyrho upsampling factor.
Collaborator Author

These were previously removed from the function; this simply updates the docstring to match the current parameters.

# Random per-axis shifts so patches can wrap across periodic boundaries
shift_d = int(self.rng.integers(0, D))
shift_h = int(self.rng.integers(0, H))
shift_w = int(self.rng.integers(0, W))

data = torch.roll(data, shifts=(shift_d, shift_h, shift_w), dims=(-3, -2, -1))
label = torch.roll(label, shifts=(shift_d, shift_h, shift_w), dims=(-3, -2, -1))
Collaborator Author

torch.roll creates a copy before slicing. If we continue to have memory issues, we can further optimize this with an in-place operation, but this should be sufficient for now.
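If that optimization is ever needed, one copy-light alternative (not what the PR implements) is to gather the wrapped patch directly with modular indexing, so only a patch-sized tensor is materialized instead of a rolled copy of the whole volume. A minimal sketch, with an illustrative function name:

```python
import torch

def patch_without_roll(data: torch.Tensor, size: int,
                       shift_d: int, shift_h: int, shift_w: int) -> torch.Tensor:
    """Gather the same patch that roll-then-slice would produce, without
    copying the full volume. torch.roll maps index i to (i + shift) % N,
    so slicing rolled[:size] reads original indices (i - shift) % N.
    (Sketch only; advanced indexing still copies, but only patch-sized.)
    """
    D, H, W = data.shape[-3:]
    idx_d = (torch.arange(size) - shift_d) % D
    idx_h = (torch.arange(size) - shift_h) % H
    idx_w = (torch.arange(size) - shift_w) % W
    return data[..., idx_d[:, None, None], idx_h[None, :, None], idx_w[None, None, :]]
```

This trades the full-volume copy for an index-tensor gather, which may or may not be a net win depending on volume size; worth benchmarking before committing to it.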


# Extract tile from origin
data = data[..., :size, :size, :size]
label = label[..., :size, :size, :size]
Collaborator Author

Need to think further about how to handle unit cells that are smaller than the tile size. I'm currently leaning towards padding them to fill out the tile. We could also enforce that all unit cells must be larger than the tile size.

The other thing I need to consider is how the wrapping in the convolutional layers works, to make sure we account for the boundaries still wrapping correctly.

Member

Maybe we internally create a supercell of any unit cell that is smaller than the tile size. We can then transform it back down when returning a density to the user.
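The supercell idea above could be sketched like this, assuming the density is stored as a 3D (or batched 3D) tensor; the function name is illustrative:

```python
import math
import torch

def tile_to_min_size(density: torch.Tensor, min_size: int) -> torch.Tensor:
    """Repeat a periodic unit cell along each spatial axis until every
    spatial dimension is at least min_size (i.e. build a supercell).
    Transforming back is just slicing the original cell out again.
    (Illustrative sketch, not code from this PR.)
    """
    reps = [1] * (density.dim() - 3)  # leave any batch/channel dims alone
    reps += [math.ceil(min_size / n) for n in density.shape[-3:]]
    return density.repeat(*reps)
```

Since the grid is periodic, repetition preserves the physics exactly; the main cost is that patches drawn from a supercell can contain duplicated regions.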

Collaborator Author

Sounds reasonable. For validation and inference we're currently using the full volume rather than tiling, so we don't need to worry about transforming it back to the original dimensions quite yet. Good to keep in mind if/when we do need tiling for inference.

@hanaol hanaol force-pushed the hanaol/setup-ptlightning branch from d02539c to d11a763 Compare December 11, 2025 17:38