Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/fairchem/core/models/uma/escn_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
# Initialize node embeddings
###############################################################

sys_node_embedding = csd_mixed_emb[data_dict["batch"]]

# Init per node representations using an atomic number based embedding
with record_function("atom embedding"):
x_message = torch.zeros(
Expand All @@ -746,10 +748,9 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
device=data_dict["pos"].device,
dtype=data_dict["pos"].dtype,
)
x_message[:, 0, :] = self.sphere_embedding(data_dict["atomic_numbers"])

sys_node_embedding = csd_mixed_emb[data_dict["batch"]]
x_message[:, 0, :] = x_message[:, 0, :] + sys_node_embedding
x_message[:, 0, :] = (
self.sphere_embedding(data_dict["atomic_numbers"]) + sys_node_embedding
)

###
# Hook to allow MOLE
Expand Down Expand Up @@ -782,9 +783,19 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
# Pre-fuse envelope into wigner_inv
wigner_inv_envelope = wigner_inv * edge_envelope

# Get all radial embeddings: edge_degree + layer radials
# General backend: returns [x_edge] * (1 + N) - rad_func computed internally
# Fast backends: returns precomputed [edge_radial, layer_0_radial, ...]
all_radial_embeddings = self.backend.get_unified_radial_emb(x_edge, self)
edge_degree_input = all_radial_embeddings[0]
x_edge_per_layer = all_radial_embeddings[1:]

# Apply edge_degree_embedding
# General backend: rad_func computed internally
# Fast backends: rad_func=None, uses precomputed radial from edge_degree_input
x_message = self.edge_degree_embedding(
x_message,
x_edge,
edge_degree_input,
graph_dict["edge_index"],
wigner_inv_envelope,
data_dict["gp_node_offset"],
Expand All @@ -794,12 +805,6 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
# Update spherical node embeddings
###############################################################

# Get edge embeddings for each layer
# General backend: raw x_edge (rad_func computed inside SO2_Convolution)
# Fast backends: precomputed radials
with record_function("layer_radial_emb"):
x_edge_per_layer = self.backend.get_layer_radial_emb(x_edge, self)

for i in range(self.num_layers):
with record_function(f"message passing {i}"):
x_message = self.blocks[i](
Expand Down
34 changes: 18 additions & 16 deletions src/fairchem/core/models/uma/nn/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def forward_chunk(
wigner_inv_envelope,
node_offset=0,
):
radial = self.rad_func(x_edge)
# rad_func is None when radials are precomputed by UnifiedRadialMLP
radial = self.rad_func(x_edge) if self.rad_func is not None else x_edge

return self.backend.edge_degree_scatter(
x,
Expand Down Expand Up @@ -150,13 +151,6 @@ def __init__(
self.embedding_target = embedding_target
assert embedding_size % 2 == 0, f"{embedding_size=} must be even"

if self.embedding_target == "charge":
# 100 is a conservative upper bound
self.target_dict = {str(x): x + 100 for x in range(-100, 101)}
elif self.embedding_target == "spin":
# 100 is a conservative upper bound
self.target_dict = {str(x): x for x in range(101)}

if self.embedding_type == "pos_emb":
# dividing by 2 because x_proj multiplies by 2
if not grad:
Expand All @@ -173,7 +167,16 @@ def __init__(
for param in self.lin_emb.parameters():
param.requires_grad = False
elif self.embedding_type == "rand_emb":
self.rand_emb = nn.Embedding(len(self.target_dict), embedding_size)
# Embedding table sizes and index offset for tensor-based computation.
# Charge: x in [-100, 100] -> idx = x + 100 in [0, 200], so 201 embeddings
# Spin: x in [0, 100] -> idx = x in [0, 100], so 101 embeddings
if self.embedding_target == "charge":
self.idx_offset = 100
num_embeddings = 201
else: # spin
self.idx_offset = 0
num_embeddings = 101
self.rand_emb = nn.Embedding(num_embeddings, embedding_size)
if not grad:
for param in self.rand_emb.parameters():
param.requires_grad = False
Expand All @@ -199,13 +202,12 @@ def forward(self, x):
x[x == 0] = -100
return self.lin_emb(x.unsqueeze(-1).float())
elif self.embedding_type == "rand_emb":
return self.rand_emb(
torch.tensor(
[self.target_dict[str(i)] for i in x.tolist()],
device=x.device,
dtype=torch.long,
)
)
# Convert charge/spin values to embedding indices via tensor arithmetic.
# This avoids the graph break caused by x.tolist() and dict lookup.
# For charge: x in [-100, 100] -> idx = x + 100 in [0, 200]
# For spin: x in [0, 100] -> idx = x in [0, 100]
idx = (x + self.idx_offset).long()
return self.rand_emb(idx)
raise ValueError(f"embedding type {self.embedding_type} not implemented")


Expand Down
48 changes: 30 additions & 18 deletions src/fairchem/core/models/uma/nn/execution_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,29 @@ def prepare_model_for_inference(model: torch.nn.Module) -> None:
"""

@staticmethod
def get_layer_radial_emb(
def get_unified_radial_emb(
x_edge: torch.Tensor,
model: torch.nn.Module,
) -> list[torch.Tensor]:
"""
Get edge embeddings for each layer.
Get all radial embeddings: edge_degree + layer radials.

Default implementation returns the same raw x_edge for all layers.
SO2_Convolution will compute rad_func(x_edge) internally.
Default implementation returns [x_edge] * (1 + N).
x_edge is passed to edge_degree_embedding and layers, which
compute rad_func(x_edge) internally.

Override in fast backends to precompute radials.
Override in fast backends to precompute all radials in one GEMM.

Args:
x_edge: Edge embeddings [E, edge_features]
model: The backbone model

Returns:
List of edge embeddings, one per layer
List [edge_degree_input, layer_0_input, ..., layer_N-1_input]
For general backend: all are x_edge (raw edge embeddings).
For fast backends: all are precomputed radial outputs.
"""
return [x_edge] * len(model.blocks)
return [x_edge] * (1 + len(model.blocks))

@staticmethod
def prepare_wigner(
Expand Down Expand Up @@ -247,7 +250,7 @@ def edge_degree_scatter(

# Slice wigner to m=0 columns and rotate:
# [E, L, m0] @ [E, m0, C] -> [E, L, C]
wigner_inv_m0 = wigner_inv[:, :, :m_0_num_coefficients]
wigner_inv_m0 = wigner_inv[:, :, :m_0_num_coefficients] / rescale_factor
x_edge_embedding = torch.bmm(wigner_inv_m0, radial)

# Type cast if needed
Expand All @@ -257,7 +260,7 @@ def edge_degree_scatter(
return x.index_add(
0,
edge_index[1] - node_offset,
x_edge_embedding / rescale_factor,
x_edge_embedding,
)


Expand Down Expand Up @@ -292,7 +295,8 @@ def prepare_model_for_inference(model: torch.nn.Module) -> None:
Replaces so2_conv_1 with SO2_Conv1_WithRadialBlock and
so2_conv_2 with SO2_Conv2_InternalBlock in each block's
Edgewise module. Then creates a UnifiedRadialMLP from all
radial functions for efficient batched computation.
radial functions (edge_degree + layer rad_funcs) for efficient
batched computation.
"""
from fairchem.core.models.uma.nn.so2_layers import (
convert_so2_conv1,
Expand All @@ -303,24 +307,32 @@ def prepare_model_for_inference(model: torch.nn.Module) -> None:
block.edge_wise.so2_conv_1 = convert_so2_conv1(block.edge_wise.so2_conv_1)
block.edge_wise.so2_conv_2 = convert_so2_conv2(block.edge_wise.so2_conv_2)

# Create unified radial MLP for batched computation
rad_funcs = [block.edge_wise.so2_conv_1.rad_func for block in model.blocks]
model._unified_radial_mlp = UnifiedRadialMLP(rad_funcs)
# Create unified radial MLP: edge_degree + layer rad_funcs in one GEMM
edge_degree_rad_func = model.edge_degree_embedding.rad_func
layer_rad_funcs = [
block.edge_wise.so2_conv_1.rad_func for block in model.blocks
]
model._unified_radial_mlp = UnifiedRadialMLP(
edge_degree_rad_func, layer_rad_funcs
)

# Null out rad_func so forward_chunk knows radials are precomputed
model.edge_degree_embedding.rad_func = None

@staticmethod
def get_layer_radial_emb(
def get_unified_radial_emb(
x_edge: torch.Tensor,
model: torch.nn.Module,
) -> list[torch.Tensor]:
"""
Compute radial embeddings for all layers using batched UnifiedRadialMLP.
Compute all radial embeddings using batched UnifiedRadialMLP.

Args:
x_edge: Edge embeddings [E, edge_features]
model: The backbone model with _unified_radial_mlp

Returns:
List of radial embeddings, one per layer [E, radial_features]
List [edge_degree_radial, layer_0_radial, ..., layer_N-1_radial]
"""
return model._unified_radial_mlp(x_edge)

Expand Down Expand Up @@ -410,15 +422,15 @@ def edge_degree_scatter(
radial = radial_output.reshape(-1, m_0_num_coefficients, sphere_channels)

# Select m=0 columns from L-ordered wigner_inv
wigner_inv_m0 = wigner_inv[:, :, _M0_COL_INDICES_L_ORDER]
wigner_inv_m0 = wigner_inv[:, :, _M0_COL_INDICES_L_ORDER] / rescale_factor
x_edge_embedding = torch.bmm(wigner_inv_m0, radial)

x_edge_embedding = x_edge_embedding.to(x.dtype)

return x.index_add(
0,
edge_index[1] - node_offset,
x_edge_embedding / rescale_factor,
x_edge_embedding,
)


Expand Down
Loading
Loading