
Commit 7c12230

feat: unify edge_degree + layer radial MLPs into single batched GEMM
UnifiedRadialMLP consolidates edge_degree_embedding.rad_func and all layer rad_funcs into a single first-layer GEMM, reducing kernel launches and improving GPU utilization.

Key changes:
- UnifiedRadialMLP: batches the first linear layer, processes tails separately
- get_unified_radial_emb: returns [edge_degree_out, layer_0_out, ...]
- rad_func=None sentinel: signals precomputed radials in EdgeDegreeEmbedding
- Fast backends (UMASFastPytorchBackend, UMASFastGPUBackend) create and use UnifiedRadialMLP at prepare_model_for_inference time

Also includes torch.compile compatibility fixes:
- ChgSpinEmbedding: replaced dict lookup with tensor arithmetic
- balance_channels: minor cleanup for compile compatibility

Performance: 17.4 QPS on 2000 atoms (H200), forces match baseline.
1 parent 331203d commit 7c12230

5 files changed

Lines changed: 355 additions & 201 deletions
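
Before the per-file diffs, a minimal sketch of the batching idea described in the commit message. The UnifiedRadialMLP implementation itself lives in one of the changed files not shown in this excerpt; the sketch assumes each rad_func is an nn.Sequential whose first layer is an nn.Linear applied to the same x_edge input, and all names here are illustrative rather than the actual fairchem API.

import torch
import torch.nn as nn

class BatchedFirstLayerMLP(nn.Module):
    """Illustrative sketch: fuse the first Linear of several MLPs into one GEMM."""

    def __init__(self, rad_funcs: list[nn.Sequential]):
        super().__init__()
        firsts = [rf[0] for rf in rad_funcs]
        self.split_sizes = [f.out_features for f in firsts]
        # One [sum(H_i), D] weight: a single GEMM replaces len(rad_funcs) launches.
        self.weight = nn.Parameter(torch.cat([f.weight for f in firsts], dim=0))
        self.bias = nn.Parameter(torch.cat([f.bias for f in firsts], dim=0))
        # The remaining layers of each MLP still run separately ("tails").
        self.tails = nn.ModuleList(nn.Sequential(*rf[1:]) for rf in rad_funcs)

    def forward(self, x_edge: torch.Tensor) -> list[torch.Tensor]:
        fused = nn.functional.linear(x_edge, self.weight, self.bias)  # one GEMM
        chunks = fused.split(self.split_sizes, dim=-1)
        return [tail(h) for tail, h in zip(self.tails, chunks)]

The fusion is exact because concatenation commutes with the shared matmul: torch.cat([x @ W1.T, x @ W2.T], dim=-1) equals x @ torch.cat([W1, W2], dim=0).T.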


src/fairchem/core/models/uma/escn_md.py

Lines changed: 16 additions & 11 deletions
@@ -727,6 +727,8 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
         # Initialize node embeddings
         ###############################################################

+        sys_node_embedding = csd_mixed_emb[data_dict["batch"]]
+
         # Init per node representations using an atomic number based embedding
         with record_function("atom embedding"):
             x_message = torch.zeros(
@@ -736,10 +738,9 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
                 device=data_dict["pos"].device,
                 dtype=data_dict["pos"].dtype,
             )
-            x_message[:, 0, :] = self.sphere_embedding(data_dict["atomic_numbers"])
-
-            sys_node_embedding = csd_mixed_emb[data_dict["batch"]]
-            x_message[:, 0, :] = x_message[:, 0, :] + sys_node_embedding
+            x_message[:, 0, :] = (
+                self.sphere_embedding(data_dict["atomic_numbers"]) + sys_node_embedding
+            )

         ###
         # Hook to allow MOLE
@@ -772,9 +773,19 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
         # Pre-fuse envelope into wigner_inv
         wigner_inv_envelope = wigner_inv * edge_envelope

+        # Get all radial embeddings: edge_degree + layer radials
+        # General backend: returns [x_edge] * (1 + N) - rad_func computed internally
+        # Fast backends: returns precomputed [edge_radial, layer_0_radial, ...]
+        all_radial_embeddings = self.backend.get_unified_radial_emb(x_edge, self)
+        edge_degree_input = all_radial_embeddings[0]
+        x_edge_per_layer = all_radial_embeddings[1:]
+
+        # Apply edge_degree_embedding
+        # General backend: rad_func computed internally
+        # Fast backends: rad_func=None, uses precomputed radial from edge_degree_input
         x_message = self.edge_degree_embedding(
             x_message,
-            x_edge,
+            edge_degree_input,
             graph_dict["edge_index"],
             wigner_inv_envelope,
             data_dict["gp_node_offset"],
@@ -784,12 +795,6 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
         # Update spherical node embeddings
         ###############################################################

-        # Get edge embeddings for each layer
-        # General backend: raw x_edge (rad_func computed inside SO2_Convolution)
-        # Fast backends: precomputed radials
-        with record_function("layer_radial_emb"):
-            x_edge_per_layer = self.backend.get_layer_radial_emb(x_edge, self)
-
         for i in range(self.num_layers):
             with record_function(f"message passing {i}"):
                 x_message = self.blocks[i](
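
As a reading aid, the list contract this forward pass now relies on, restated as a short sketch (the default general-backend behavior appears in the execution_backends.py diff further down):

# Element 0 feeds edge_degree_embedding; elements 1..N feed the N blocks.
# The general backend fills every slot with raw x_edge and lets each consumer
# apply its own rad_func; fast backends fill them with precomputed radials.
outs = self.backend.get_unified_radial_emb(x_edge, self)  # length 1 + num_layers
edge_degree_input, x_edge_per_layer = outs[0], outs[1:]
assert len(x_edge_per_layer) == len(self.blocks)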

src/fairchem/core/models/uma/nn/embedding.py

Lines changed: 18 additions & 16 deletions
@@ -82,7 +82,8 @@ def forward_chunk(
         wigner_inv_envelope,
         node_offset=0,
     ):
-        radial = self.rad_func(x_edge)
+        # rad_func is None when radials are precomputed by UnifiedRadialMLP
+        radial = self.rad_func(x_edge) if self.rad_func is not None else x_edge

         return self.backend.edge_degree_scatter(
             x,
@@ -150,13 +151,6 @@ def __init__(
         self.embedding_target = embedding_target
         assert embedding_size % 2 == 0, f"{embedding_size=} must be even"

-        if self.embedding_target == "charge":
-            # 100 is a conservative upper bound
-            self.target_dict = {str(x): x + 100 for x in range(-100, 101)}
-        elif self.embedding_target == "spin":
-            # 100 is a conservative upper bound
-            self.target_dict = {str(x): x for x in range(101)}
-
         if self.embedding_type == "pos_emb":
             # dividing by 2 because x_proj multiplies by 2
             if not grad:
@@ -173,7 +167,16 @@ def __init__(
             for param in self.lin_emb.parameters():
                 param.requires_grad = False
         elif self.embedding_type == "rand_emb":
-            self.rand_emb = nn.Embedding(len(self.target_dict), embedding_size)
+            # Embedding table sizes and index offset for tensor-based computation.
+            # Charge: x in [-100, 100] -> idx = x + 100 in [0, 200], so 201 embeddings
+            # Spin: x in [0, 100] -> idx = x in [0, 100], so 101 embeddings
+            if self.embedding_target == "charge":
+                self.idx_offset = 100
+                num_embeddings = 201
+            else:  # spin
+                self.idx_offset = 0
+                num_embeddings = 101
+            self.rand_emb = nn.Embedding(num_embeddings, embedding_size)
             if not grad:
                 for param in self.rand_emb.parameters():
                     param.requires_grad = False
@@ -199,13 +202,12 @@ def forward(self, x):
             x[x == 0] = -100
             return self.lin_emb(x.unsqueeze(-1).float())
         elif self.embedding_type == "rand_emb":
-            return self.rand_emb(
-                torch.tensor(
-                    [self.target_dict[str(i)] for i in x.tolist()],
-                    device=x.device,
-                    dtype=torch.long,
-                )
-            )
+            # Convert charge/spin values to embedding indices via tensor arithmetic.
+            # This avoids the graph break caused by x.tolist() and dict lookup.
+            # For charge: x in [-100, 100] -> idx = x + 100 in [0, 200]
+            # For spin: x in [0, 100] -> idx = x in [0, 100]
+            idx = (x + self.idx_offset).long()
+            return self.rand_emb(idx)
         raise ValueError(f"embedding type {self.embedding_type} not implemented")
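
Why the rand_emb rewrite helps torch.compile: x.tolist() moves data to the host and the Python dict lookup cannot be traced, so Dynamo inserts a graph break, whereas the offset arithmetic stays entirely in tensor ops. A self-contained toy illustration (not fairchem code), using the charge convention from the diff above:

import torch

target_dict = {str(v): v + 100 for v in range(-100, 101)}

def lookup_via_dict(x):
    # .tolist() leaves the traced graph -> graph break under torch.compile
    return torch.tensor(
        [target_dict[str(i)] for i in x.tolist()], device=x.device, dtype=torch.long
    )

def lookup_via_tensor(x):
    # Pure tensor arithmetic -> traceable end to end
    return (x + 100).long()

x = torch.tensor([-3, 0, 7])
assert torch.equal(lookup_via_dict(x), lookup_via_tensor(x))
# torch.compile(lookup_via_tensor, fullgraph=True)(x) compiles cleanly, while
# compiling lookup_via_dict with fullgraph=True errors out at .tolist().
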
src/fairchem/core/models/uma/nn/execution_backends.py

Lines changed: 30 additions & 18 deletions
@@ -88,26 +88,29 @@ def prepare_model_for_inference(model: torch.nn.Module) -> None:
         """

     @staticmethod
-    def get_layer_radial_emb(
+    def get_unified_radial_emb(
         x_edge: torch.Tensor,
         model: torch.nn.Module,
     ) -> list[torch.Tensor]:
         """
-        Get edge embeddings for each layer.
+        Get all radial embeddings: edge_degree + layer radials.

-        Default implementation returns the same raw x_edge for all layers.
-        SO2_Convolution will compute rad_func(x_edge) internally.
+        Default implementation returns [x_edge] * (1 + N).
+        x_edge is passed to edge_degree_embedding and layers, which
+        compute rad_func(x_edge) internally.

-        Override in fast backends to precompute radials.
+        Override in fast backends to precompute all radials in one GEMM.

         Args:
             x_edge: Edge embeddings [E, edge_features]
             model: The backbone model

         Returns:
-            List of edge embeddings, one per layer
+            List [edge_degree_input, layer_0_input, ..., layer_N-1_input]
+            For general backend: all are x_edge (raw edge embeddings).
+            For fast backends: all are precomputed radial outputs.
         """
-        return [x_edge] * len(model.blocks)
+        return [x_edge] * (1 + len(model.blocks))

     @staticmethod
     def prepare_wigner(
@@ -242,7 +245,7 @@ def edge_degree_scatter(

         # Slice wigner to m=0 columns and rotate:
         # [E, L, m0] @ [E, m0, C] -> [E, L, C]
-        wigner_inv_m0 = wigner_inv[:, :, :m_0_num_coefficients]
+        wigner_inv_m0 = wigner_inv[:, :, :m_0_num_coefficients] / rescale_factor
         x_edge_embedding = torch.bmm(wigner_inv_m0, radial)

         # Type cast if needed
@@ -252,7 +255,7 @@ def edge_degree_scatter(
         return x.index_add(
             0,
             edge_index[1] - node_offset,
-            x_edge_embedding / rescale_factor,
+            x_edge_embedding,
         )
@@ -291,7 +294,8 @@ def prepare_model_for_inference(model: torch.nn.Module) -> None:
         Replaces so2_conv_1 with SO2_Conv1_WithRadialBlock and
         so2_conv_2 with SO2_Conv2_InternalBlock in each block's
         Edgewise module. Then creates a UnifiedRadialMLP from all
-        radial functions for efficient batched computation.
+        radial functions (edge_degree + layer rad_funcs) for efficient
+        batched computation.
         """
         from fairchem.core.models.uma.nn.so2_layers import (
             convert_so2_conv1,
@@ -302,24 +306,32 @@ def prepare_model_for_inference(model: torch.nn.Module) -> None:
             block.edge_wise.so2_conv_1 = convert_so2_conv1(block.edge_wise.so2_conv_1)
             block.edge_wise.so2_conv_2 = convert_so2_conv2(block.edge_wise.so2_conv_2)

-        # Create unified radial MLP for batched computation
-        rad_funcs = [block.edge_wise.so2_conv_1.rad_func for block in model.blocks]
-        model._unified_radial_mlp = UnifiedRadialMLP(rad_funcs)
+        # Create unified radial MLP: edge_degree + layer rad_funcs in one GEMM
+        edge_degree_rad_func = model.edge_degree_embedding.rad_func
+        layer_rad_funcs = [
+            block.edge_wise.so2_conv_1.rad_func for block in model.blocks
+        ]
+        model._unified_radial_mlp = UnifiedRadialMLP(
+            edge_degree_rad_func, layer_rad_funcs
+        )
+
+        # Null out rad_func so forward_chunk knows radials are precomputed
+        model.edge_degree_embedding.rad_func = None

     @staticmethod
-    def get_layer_radial_emb(
+    def get_unified_radial_emb(
         x_edge: torch.Tensor,
         model: torch.nn.Module,
     ) -> list[torch.Tensor]:
         """
-        Compute radial embeddings for all layers using batched UnifiedRadialMLP.
+        Compute all radial embeddings using batched UnifiedRadialMLP.

         Args:
             x_edge: Edge embeddings [E, edge_features]
             model: The backbone model with _unified_radial_mlp

         Returns:
-            List of radial embeddings, one per layer [E, radial_features]
+            List [edge_degree_radial, layer_0_radial, ..., layer_N-1_radial]
         """
         return model._unified_radial_mlp(x_edge)
@@ -408,15 +420,15 @@ def edge_degree_scatter(
         radial = radial_output.reshape(-1, m_0_num_coefficients, sphere_channels)

         # Select m=0 columns from L-ordered wigner_inv
-        wigner_inv_m0 = wigner_inv[:, :, _M0_COL_INDICES_L_ORDER]
+        wigner_inv_m0 = wigner_inv[:, :, _M0_COL_INDICES_L_ORDER] / rescale_factor
         x_edge_embedding = torch.bmm(wigner_inv_m0, radial)

         x_edge_embedding = x_edge_embedding.to(x.dtype)

         return x.index_add(
             0,
             edge_index[1] - node_offset,
-            x_edge_embedding / rescale_factor,
+            x_edge_embedding,
         )
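
A note on the rescale_factor move in both edge_degree_scatter variants above: bmm and index_add are linear, so dividing wigner_inv_m0 before the matmul is algebraically identical to dividing the product afterwards. The division now touches the [E, L, m0] slice instead of the larger [E, L, C] product (m0 is typically much smaller than the channel count C) and stays off the scatter/accumulation path. A quick numerical check with made-up shapes:

import torch

E, L, m0, C = 4, 9, 3, 8  # hypothetical sizes
rescale_factor = 5.0
wigner_inv_m0 = torch.randn(E, L, m0)
radial = torch.randn(E, m0, C)

old = torch.bmm(wigner_inv_m0, radial) / rescale_factor  # divide the product
new = torch.bmm(wigner_inv_m0 / rescale_factor, radial)  # divide the operand
assert torch.allclose(old, new, atol=1e-6)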