Skip to content

Commit 08f7b0b

Browse files
authored
Add UnifiedRadialMLP for batched radial computation (#1831)
Adds precomputed per-layer radial embeddings to the umas_fast_pytorch and umas_fast_gpu backends. Unfortunately, this only speeds up the first of the ~3 linear layers we have in each of the radial MLPs, since the layers diverge after the first. However, this does give us a 0.8 qps (15.5 -> 16.3 qps) boost on a 2000-carbon system for UMA-S 1.1 when using umas_fast_gpu or umas_fast_pytorch.
1 parent d5d954a commit 08f7b0b

7 files changed

Lines changed: 453 additions & 33 deletions

File tree

src/fairchem/core/models/uma/escn_md.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,11 +759,18 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
759759
###############################################################
760760
# Update spherical node embeddings
761761
###############################################################
762+
763+
# Get edge embeddings for each layer
764+
# General backend: raw x_edge (rad_func computed inside SO2_Convolution)
765+
# Fast backends: precomputed radials
766+
with record_function("layer_radial_emb"):
767+
x_edge_per_layer = self.backend.get_layer_radial_emb(x_edge, self)
768+
762769
for i in range(self.num_layers):
763770
with record_function(f"message passing {i}"):
764771
x_message = self.blocks[i](
765772
x_message,
766-
x_edge,
773+
x_edge_per_layer[i],
767774
graph_dict["edge_index"],
768775
wigner,
769776
wigner_inv_envelope,

src/fairchem/core/models/uma/escn_md_block.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def forward_chunk(
192192
)
193193
x_message, x_0_gating = self.so2_conv_1(x_message, x_edge)
194194
x_message = self.act(x_0_gating, x_message)
195-
x_message = self.so2_conv_2(x_message, x_edge)
195+
x_message = self.so2_conv_2(x_message)
196196
new_embedding = self.backend.permute_wigner_inv_edge_to_node(
197197
x_message,
198198
wigner_inv_envelope,

src/fairchem/core/models/uma/nn/execution_backends.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import torch
1414

15+
from fairchem.core.models.uma.nn.unified_radial import UnifiedRadialMLP
16+
1517
if TYPE_CHECKING:
1618
from fairchem.core.units.mlip_unit.api.inference import InferenceSettings
1719

@@ -85,6 +87,28 @@ def prepare_model_for_inference(model: torch.nn.Module) -> None:
8587
model: The backbone model to prepare.
8688
"""
8789

90+
@staticmethod
91+
def get_layer_radial_emb(
92+
x_edge: torch.Tensor,
93+
model: torch.nn.Module,
94+
) -> list[torch.Tensor]:
95+
"""
96+
Get edge embeddings for each layer.
97+
98+
Default implementation returns the same raw x_edge for all layers.
99+
SO2_Convolution will compute rad_func(x_edge) internally.
100+
101+
Override in fast backends to precompute radials.
102+
103+
Args:
104+
x_edge: Edge embeddings [E, edge_features]
105+
model: The backbone model
106+
107+
Returns:
108+
List of edge embeddings, one per layer
109+
"""
110+
return [x_edge] * len(model.blocks)
111+
88112
@staticmethod
89113
def prepare_wigner(
90114
wigner: torch.Tensor,
@@ -261,11 +285,13 @@ def validate(
261285
@staticmethod
262286
def prepare_model_for_inference(model: torch.nn.Module) -> None:
263287
"""
264-
Convert SO2_Convolution modules to block-diagonal GEMM variants.
288+
Convert SO2_Convolution modules to block-diagonal GEMM variants
289+
and create unified radial MLP for batched computation.
265290
266291
Replaces so2_conv_1 with SO2_Conv1_WithRadialBlock and
267292
so2_conv_2 with SO2_Conv2_InternalBlock in each block's
268-
Edgewise module.
293+
Edgewise module. Then creates a UnifiedRadialMLP from all
294+
radial functions for efficient batched computation.
269295
"""
270296
from fairchem.core.models.uma.nn.so2_layers import (
271297
convert_so2_conv1,
@@ -276,6 +302,27 @@ def prepare_model_for_inference(model: torch.nn.Module) -> None:
276302
block.edge_wise.so2_conv_1 = convert_so2_conv1(block.edge_wise.so2_conv_1)
277303
block.edge_wise.so2_conv_2 = convert_so2_conv2(block.edge_wise.so2_conv_2)
278304

305+
# Create unified radial MLP for batched computation
306+
rad_funcs = [block.edge_wise.so2_conv_1.rad_func for block in model.blocks]
307+
model._unified_radial_mlp = UnifiedRadialMLP(rad_funcs)
308+
309+
@staticmethod
310+
def get_layer_radial_emb(
311+
x_edge: torch.Tensor,
312+
model: torch.nn.Module,
313+
) -> list[torch.Tensor]:
314+
"""
315+
Compute radial embeddings for all layers using batched UnifiedRadialMLP.
316+
317+
Args:
318+
x_edge: Edge embeddings [E, edge_features]
319+
model: The backbone model with _unified_radial_mlp
320+
321+
Returns:
322+
List of radial embeddings, one per layer [E, radial_features]
323+
"""
324+
return model._unified_radial_mlp(x_edge)
325+
279326

280327
class UMASFastGPUBackend(UMASFastPytorchBackend):
281328
"""

src/fairchem/core/models/uma/nn/so2_layers.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -243,15 +243,15 @@ def forward(
243243
244244
Args:
245245
x: Input features [E, coeffs, channels]
246-
x_edge: Edge embeddings [E, edge_features]
246+
x_edge: Precomputed radial embeddings [E, radial_features]
247247
248248
Returns:
249249
(output, gating): output [E, coeffs, m_output_channels],
250250
gating [E, extra_m0_output_channels]
251251
"""
252-
x_edge_by_m = self.rad_func(x_edge).split(self.edge_split_sizes, dim=1)
252+
x_edge_by_m = x_edge.split(self.edge_split_sizes, dim=1)
253253
x_by_m = x.split(self.m_split_sizes, dim=1)
254-
num_edges = len(x_edge)
254+
num_edges = x.shape[0]
255255

256256
# m=0: apply radial, linear, split gating
257257
x_0 = x_by_m[0].view(num_edges, -1) * x_edge_by_m[0]
@@ -511,18 +511,22 @@ def __init__(
511511
def forward(
512512
self,
513513
x: torch.Tensor,
514-
x_edge: torch.Tensor,
514+
x_edge: torch.Tensor | None = None,
515515
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
516-
# radial function
516+
# Compute radial embedding from raw x_edge if we have external weights
517517
if self.rad_func is not None:
518-
x_edge_by_m = self.rad_func(x_edge).split(self.edge_split_sizes, dim=1)
518+
x_edge = self.rad_func(x_edge)
519519

520520
x_by_m = x.split(self.m_split_sizes, dim=1)
521521

522-
num_edges = len(x_edge)
522+
# Split radial embeddings if provided (external weights mode)
523+
if x_edge is not None:
524+
x_edge_by_m = x_edge.split(self.edge_split_sizes, dim=1)
525+
526+
num_edges = x.shape[0]
523527
# Compute m=0 coefficients separately since they only have real values (no imaginary)
524528
x_0 = x_by_m[0].view(num_edges, -1)
525-
if self.rad_func is not None:
529+
if x_edge is not None:
526530
x_0 = x_0 * x_edge_by_m[0]
527531
x_0 = self.fc_m0(x_0)
528532

@@ -541,7 +545,7 @@ def forward(
541545
# Compute the values for the m > 0 coefficients
542546
for m in range(1, self.mmax + 1):
543547
x_m = x_by_m[m].view(num_edges, 2, -1)
544-
if self.rad_func is not None:
548+
if x_edge is not None:
545549
x_m = x_m * x_edge_by_m[m].unsqueeze(1)
546550
x_m = self.so2_m_conv[m - 1](x_m)
547551
out.extend(x_m)
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""
2+
Copyright (c) Meta Platforms, Inc. and affiliates.
3+
4+
This source code is licensed under the MIT license found in the
5+
LICENSE file in the root directory of this source tree.
6+
7+
Unified Radial MLP: Computes all layers' radial functions in a single
8+
batched operation.
9+
10+
Instead of running N separate RadialMLP forward passes:
11+
for layer in layers:
12+
radial_out = layer.so2_conv_1.rad_func(x_edge) # Sequential
13+
14+
We run one batched first layer, then each tail:
15+
all_radial_outs = unified_radial_mlp(x_edge) # list of [E, out]
16+
"""
17+
18+
from __future__ import annotations
19+
20+
from typing import TYPE_CHECKING
21+
22+
import torch
23+
import torch.nn as nn
24+
25+
if TYPE_CHECKING:
26+
from .radial import RadialMLP
27+
28+
__all__ = ["UnifiedRadialMLP", "create_unified_radial_mlp"]
29+
30+
# Expected structure of RadialMLP.net Sequential
31+
_EXPECTED_NET_STRUCTURE = (
32+
nn.Linear, # 0: first linear
33+
nn.LayerNorm, # 1
34+
nn.SiLU, # 2
35+
nn.Linear, # 3: second linear
36+
nn.LayerNorm, # 4
37+
nn.SiLU, # 5
38+
nn.Linear, # 6: third linear
39+
)
40+
41+
42+
def _validate_radial_mlp(mlp: RadialMLP, idx: int, reference: RadialMLP | None) -> None:
43+
"""
44+
Validate a single RadialMLP has expected structure and matches reference.
45+
46+
Args:
47+
mlp: The RadialMLP to validate.
48+
idx: Index in the list (for error messages).
49+
reference: First RadialMLP to compare dimensions against (None for first).
50+
"""
51+
# Check layer count
52+
if len(mlp.net) != 7:
53+
raise ValueError(f"RadialMLP[{idx}]: expected 7 layers, got {len(mlp.net)}")
54+
55+
# Check layer types
56+
for j, expected_type in enumerate(_EXPECTED_NET_STRUCTURE):
57+
if not isinstance(mlp.net[j], expected_type):
58+
raise TypeError(
59+
f"RadialMLP[{idx}].net[{j}]: expected {expected_type.__name__}, "
60+
f"got {type(mlp.net[j]).__name__}"
61+
)
62+
63+
# Check feature dimensions match reference (all MLPs must be identical)
64+
if reference is not None:
65+
for j in (0, 3, 6): # Linear layers
66+
if mlp.net[j].in_features != reference.net[j].in_features:
67+
raise ValueError(
68+
f"RadialMLP[{idx}].net[{j}]: in_features mismatch "
69+
f"({mlp.net[j].in_features} vs {reference.net[j].in_features})"
70+
)
71+
if mlp.net[j].out_features != reference.net[j].out_features:
72+
raise ValueError(
73+
f"RadialMLP[{idx}].net[{j}]: out_features mismatch "
74+
f"({mlp.net[j].out_features} vs {reference.net[j].out_features})"
75+
)
76+
77+
78+
class UnifiedRadialMLP(nn.Module):
    """
    Unified radial MLP that batches the first linear layer across N RadialMLPs.

    All N radial MLPs consume the same input, so their first linear layers are
    concatenated into one weight matrix and run as a single GEMM. The remaining
    layers diverge per MLP; their parameters are stacked into [N, ...] buffers
    so each tail can be applied via indexed functional calls.
    """

    def __init__(self, radial_mlps: list[RadialMLP]) -> None:
        """
        Initialize from a list of RadialMLP modules.

        Args:
            radial_mlps: List of RadialMLP modules with identical architecture.
        """
        super().__init__()

        assert len(radial_mlps) > 0, "Need at least one RadialMLP"

        # Every MLP must have the expected 7-layer structure and identical
        # dimensions, otherwise their weights cannot be batched together.
        for i, mlp in enumerate(radial_mlps):
            _validate_radial_mlp(mlp, i, radial_mlps[0] if i > 0 else None)

        first = radial_mlps[0]
        self.num_layers = len(radial_mlps)
        self.hidden_features = first.net[0].out_features
        # All MLPs share one eps; taken from the first (assumes uniform eps —
        # dimensions are validated above, eps is not).
        self.ln_eps = first.net[1].eps

        # First layer: weights/biases concatenated along the output dimension
        # so a single GEMM produces every layer's hidden activation at once.
        self.register_buffer(
            "W1_cat",
            torch.cat([mlp.net[0].weight.data for mlp in radial_mlps], dim=0),
        )
        self.register_buffer(
            "b1_cat",
            torch.cat([mlp.net[0].bias.data for mlp in radial_mlps], dim=0),
        )

        # Remaining layers: stack each (weight, bias) pair into [N, ...]
        # buffers. Registration order matches the sequential net layout:
        # (buffer prefix, index of the source module inside RadialMLP.net).
        for prefix, net_idx in (("ln1", 1), ("fc2", 3), ("ln2", 4), ("fc3", 6)):
            self.register_buffer(
                f"{prefix}_weight",
                torch.stack(
                    [mlp.net[net_idx].weight.data for mlp in radial_mlps], dim=0
                ),
            )
            self.register_buffer(
                f"{prefix}_bias",
                torch.stack(
                    [mlp.net[net_idx].bias.data for mlp in radial_mlps], dim=0
                ),
            )

    def umas_radial_mlp(self, h: torch.Tensor, i: int) -> torch.Tensor:
        """Apply layers 2+ of MLP i (LN -> SiLU -> Linear -> LN -> SiLU -> Linear)."""
        fn = torch.nn.functional
        norm_shape = (self.hidden_features,)
        h = fn.silu(
            fn.layer_norm(h, norm_shape, self.ln1_weight[i], self.ln1_bias[i], self.ln_eps)
        )
        h = fn.linear(h, self.fc2_weight[i], self.fc2_bias[i])
        h = fn.silu(
            fn.layer_norm(h, norm_shape, self.ln2_weight[i], self.ln2_bias[i], self.ln_eps)
        )
        return fn.linear(h, self.fc3_weight[i], self.fc3_bias[i])

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Compute all N radial outputs.

        Args:
            x: Input tensor of shape [E, in_features]

        Returns:
            List of N tensors, each of shape [E, out_features]
        """
        # One batched GEMM covers every MLP's first linear layer; the result
        # is split into per-layer chunks and each per-layer tail is applied.
        hidden = torch.nn.functional.linear(x, self.W1_cat, self.b1_cat)
        chunks = hidden.split(self.hidden_features, dim=1)
        return [self.umas_radial_mlp(chunk, i) for i, chunk in enumerate(chunks)]
178+
179+
180+
def create_unified_radial_mlp(radial_mlps: list) -> UnifiedRadialMLP:
    """
    Build a UnifiedRadialMLP that batches the given RadialMLPs.

    Args:
        radial_mlps: List of RadialMLP modules

    Returns:
        UnifiedRadialMLP instance with shared first layer weights
    """
    unified = UnifiedRadialMLP(radial_mlps)
    return unified

0 commit comments

Comments
 (0)