feat: unify edge_degree + layer radial MLPs into single batched GEMM #1849

Draft
misko wants to merge 3 commits into main from unified_radial_edge_degree

@misko misko commented Mar 4, 2026

UnifiedRadialMLP consolidates edge_degree_embedding.rad_func and all layer rad_funcs into a single first-layer GEMM, reducing kernel launches and improving GPU utilization.

Key changes:

  • UnifiedRadialMLP: batches the edge_degree and layer first linear layers into a single GEMM, then processes each tail separately
  • get_unified_radial_emb: returns [edge_degree_out, layer_0_out, ...]
  • rad_func=None sentinel: signals to EdgeDegreeEmbedding that the radials are precomputed
  • Fast backends (UMASFastPytorchBackend, UMASFastGPUBackend) create and use UnifiedRadialMLP at prepare_model_for_inference time
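The batching idea in the list above can be sketched as follows. This is a minimal illustration, not the PR's implementation: `make_radial_mlp` and `UnifiedRadialSketch` are hypothetical names, and the three-module MLP (Linear, SiLU, Linear) is an assumed stand-in for the real rad_func structure. It concatenates the first-layer weights of several MLPs so one GEMM serves them all, then runs each tail on its own slice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_radial_mlp(in_dim: int, hidden: int, out_dim: int) -> nn.Sequential:
    # Hypothetical stand-in for one rad_func: a first Linear followed by a "tail".
    return nn.Sequential(
        nn.Linear(in_dim, hidden),
        nn.SiLU(),
        nn.Linear(hidden, out_dim),
    )


class UnifiedRadialSketch(nn.Module):
    """Fuses the first Linear of several MLPs into one GEMM; tails run separately."""

    def __init__(self, mlps):
        super().__init__()
        firsts = [m[0] for m in mlps]
        # Stack first-layer weights and biases along the output dimension.
        self.weight = nn.Parameter(torch.cat([f.weight.detach() for f in firsts], dim=0))
        self.bias = nn.Parameter(torch.cat([f.bias.detach() for f in firsts], dim=0))
        self.splits = [f.out_features for f in firsts]
        # Everything after the first Linear stays per-MLP.
        self.tails = nn.ModuleList([m[1:] for m in mlps])

    def forward(self, x_edge: torch.Tensor):
        # One GEMM covers every first layer; the result is split per MLP.
        hidden = F.linear(x_edge, self.weight, self.bias)
        chunks = torch.split(hidden, self.splits, dim=-1)
        return [tail(h) for tail, h in zip(self.tails, chunks)]


torch.manual_seed(0)
# Index 0 plays the role of edge_degree_embedding.rad_func; the rest are layer rad_funcs.
mlps = [make_radial_mlp(16, 32, 8) for _ in range(3)]
unified = UnifiedRadialSketch(mlps)
x_edge = torch.randn(5, 16)
outs = unified(x_edge)  # [edge_degree_out, layer_0_out, layer_1_out]
```

The returned list mirrors the ordering described for get_unified_radial_emb. Each entry should match what the corresponding standalone MLP produces on `x_edge`, up to floating-point rounding.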

Also includes a torch.compile compatibility fix:

  • ChgSpinEmbedding: replaced dict lookup with tensor arithmetic to avoid a graph break from x.tolist()
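A rough sketch of this kind of change is shown below. The value range, the `lookup` dict, and both function names are hypothetical; the actual ChgSpinEmbedding code is not shown in this PR description. The point is only that an index computed with tensor ops never leaves the tensor graph, whereas the dict version goes through `x.tolist()`.

```python
import torch

# Hypothetical value range, chosen only for illustration.
MIN_VAL, MAX_VAL = -3, 3
lookup = {v: i for i, v in enumerate(range(MIN_VAL, MAX_VAL + 1))}


def index_via_dict(x: torch.Tensor) -> torch.Tensor:
    # x.tolist() moves the values into Python objects before the dict lookup.
    return torch.tensor([lookup[v] for v in x.tolist()], dtype=torch.long)


def index_via_arithmetic(x: torch.Tensor) -> torch.Tensor:
    # Same mapping as pure tensor ops: shift so MIN_VAL lands on index 0.
    return (x - MIN_VAL).to(torch.long)


charges = torch.tensor([-3, 0, 2, 3])
idx_dict = index_via_dict(charges)
idx_arith = index_via_arithmetic(charges)
```

Both functions produce the same indices for in-range inputs; only the second avoids the round trip through Python lists.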

UnifiedRadialMLP consolidates edge_degree_embedding.rad_func and all layer
rad_funcs into a single first-layer GEMM, reducing kernel launches and
improving GPU utilization.

Key changes:
- UnifiedRadialMLP: batches edge_degree + layer first linear layers into
  single GEMM, processes tails separately
- get_unified_radial_emb: returns [edge_degree_out, layer_0_out, ...]
- rad_func=None sentinel: signals precomputed radials in EdgeDegreeEmbedding
- Fast backends create UnifiedRadialMLP at prepare_model_for_inference time

Also includes torch.compile compatibility fix:
- ChgSpinEmbedding: replaced dict lookup with tensor arithmetic to avoid
  graph break from x.tolist()

Performance: ~16 QPS on 2000 atoms (H200), forces match baseline.
@meta-cla meta-cla Bot added the cla signed label Mar 4, 2026
@misko misko added the enhancement and minor labels Mar 4, 2026
The rescale_factor position change (dividing wigner_inv before bmm
instead of after) introduces small floating-point variations.
Increase absolute tolerance from 1e-6 to 2e-6 to accommodate this.
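The effect described in this commit message can be reproduced in isolation. The sketch below uses assumed shapes and an assumed `rescale_factor` value, not the model's real ones. It compares dividing after `torch.bmm` with dividing `wigner_inv` beforehand; the two are mathematically equal but round differently in floating point.

```python
import torch

torch.manual_seed(0)
rescale_factor = 5.0  # hypothetical value, for illustration only
wigner_inv = torch.randn(4, 9, 9)  # assumed shape
x = torch.randn(4, 9, 16)  # assumed shape

# Original ordering: multiply first, then rescale.
after = torch.bmm(wigner_inv, x) / rescale_factor
# New ordering: rescale wigner_inv first, then multiply.
before = torch.bmm(wigner_inv / rescale_factor, x)

# Largest elementwise discrepancy between the two orderings.
max_abs_diff = (after - before).abs().max().item()
```

Inspecting `max_abs_diff` on representative inputs is one way to judge whether a tolerance such as 2e-6 leaves enough headroom.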
@misko misko force-pushed the unified_radial_edge_degree branch from 7c12230 to ebf7b78 March 4, 2026 19:11
@github-actions

This PR has been marked as stale because it has been open for 30 days with no activity.

@github-actions github-actions Bot added the stale label Apr 16, 2026