Allow umas_fast_gpu Triton kernels with mmax=1 (eSEN-style models) #1996

@matr1x-1

What would you like to report?

Summary

The Triton-accelerated UMAS_FAST_GPU execution backend is currently hard-gated on:

lmax == 2 and mmax == 2

Models trained with mmax == 1 cannot use these kernels. Requesting the backend explicitly raises a ValueError, and on paths where that error is caught the model silently stays on the general backend.

This matters because mmax == 1 is a deliberate modeling choice in some smaller/faster SO(2)-based models, including eSEN-style configurations. These models therefore miss the documented ~30–40% speedup from the fast GPU backend.

Location

In:

fairchem/core/models/uma/nn/execution_backends.py

inside UMASFastGPUBackend.validate:

if lmax != 2 or mmax != 2:
    raise ValueError("umas_fast_gpu requires lmax==2 and mmax==2")

Separately, update_inference_settings_for_fast_gpu appears to catch this ValueError and silently leave the model on the slower general backend.
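
To make the concern about the silent fallback concrete, here is a minimal sketch of a fallback that logs a warning instead of swallowing the error. It is not fairchem code: `current_gate` only mirrors the check quoted above, and `select_backend` and the logger name are hypothetical names invented for this illustration. The real signature of update_inference_settings_for_fast_gpu is not assumed.

```python
import logging

logger = logging.getLogger("fast_gpu_fallback_sketch")  # hypothetical logger name


def current_gate(lmax: int, mmax: int) -> None:
    # Mirrors the check quoted above from UMASFastGPUBackend.validate.
    if lmax != 2 or mmax != 2:
        raise ValueError("umas_fast_gpu requires lmax==2 and mmax==2")


def select_backend(requested: str, lmax: int, mmax: int) -> str:
    """Hypothetical helper: return the name of the backend that will be used."""
    try:
        current_gate(lmax, mmax)
    except ValueError as err:
        # Surface the reason for the fallback instead of hiding it.
        logger.warning(
            "%s unavailable (%s); falling back to 'general'", requested, err
        )
        return "general"
    return requested
```

With a warning like this in the log, a user with an mmax=1 model would see at startup that the fast backend was not selected, rather than discovering it from throughput numbers.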

Why this matters

We are running long molecular dynamics trajectories with a custom-trained escn_md model using:

lmax = 2
mmax = 1

Because the system composition varies during the simulation, merge_mole=True / the "turbo" preset is not usable for this workload.

As a result, we lose both:

  1. the MoLE/turbo path, and
  2. the Triton fast GPU kernels.

On a single NVIDIA GH200, this leaves us at roughly 3 ns/day for a 216-atom system.

For campaigns measured in hundreds of nanoseconds, this is the difference between feasible and borderline.

Minimal reproduction

import warp as wp

if not hasattr(wp, "vec"):
    wp.vec = wp.types.vector  # workaround for wp.vec removal in recent warp

from fairchem.core.calculate.ase_calculator import FAIRChemCalculator
from fairchem.core.units.mlip_unit.api.inference import InferenceSettings

settings = InferenceSettings(
    tf32=True,
    activation_checkpointing=False,
    merge_mole=False,              # required: composition varies
    compile=True,
    internal_graph_gen_version=3,
    execution_mode="umas_fast_gpu", # rejected for mmax=1 (see error below)
)

calc = FAIRChemCalculator.from_model_checkpoint(
    "<path/to/escn_md_mmax1_checkpoint.pt>",
    task_name="<your_task>",
    device="cuda",
    inference_settings=settings,
)

This raises:

ValueError: umas_fast_gpu requires lmax==2 and mmax==2

To reproduce, use any escn_md checkpoint trained with mmax=1.

The relevant backbone configuration is logged in:

canonical_config.yaml

under:

backbone:
  lmax: 2
  mmax: 1

Request

Would it be possible to support mmax == 1 in umas_fast_gpu?

There seem to be two possible paths:

Option 1: Relax the validation gate

If the existing Triton kernels already work for mmax=1 with smaller block strides, allow:

mmax in {1, 2}

in UMASFastGPUBackend.validate.
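
As a rough sketch of what Option 1 could look like, the gate might accept a set of mmax values and report the offending configuration in the error message. This is a standalone illustration under the assumption that the existing kernels are actually correct for mmax=1, which has not been verified here. `validate_fast_gpu_config` and the two constants are hypothetical names, and the real `validate` method may take different arguments.

```python
SUPPORTED_LMAX = {2}
SUPPORTED_MMAX = {1, 2}  # assumption: relaxed from the current mmax == 2 only


def validate_fast_gpu_config(lmax: int, mmax: int) -> None:
    """Hypothetical standalone version of the gate with mmax=1 allowed."""
    if lmax not in SUPPORTED_LMAX or mmax not in SUPPORTED_MMAX:
        # Include the received values so the failure explains itself.
        raise ValueError(
            f"umas_fast_gpu requires lmax in {sorted(SUPPORTED_LMAX)} and "
            f"mmax in {sorted(SUPPORTED_MMAX)}; got lmax={lmax}, mmax={mmax}"
        )
```

If the kernels turn out to need changes for mmax=1, the same message format would still be useful for the "fail with a clear explanation" outcome.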

The relevant kernels appear to include:

node_to_edge_wigner_permute
permute_wigner_inv_edge_to_node
edge_degree_scatter

For l <= 2, the mmax=1 case should have fewer m-channels per node than the mmax=2 case.
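
To quantify the size difference, the sketch below counts the (l, m) coefficients kept per channel under the usual truncation |m| <= min(l, mmax). It is plain arithmetic and does not inspect fairchem's actual tensor layout. `num_coefficients` is a hypothetical helper name, and whether the Triton kernels hard-code the larger count is exactly the open question.

```python
def num_coefficients(lmax: int, mmax: int) -> int:
    """Count (l, m) pairs with 0 <= l <= lmax and |m| <= min(l, mmax)."""
    return sum(2 * min(l, mmax) + 1 for l in range(lmax + 1))


full = num_coefficients(lmax=2, mmax=2)     # l=0: 1, l=1: 3, l=2: 5
reduced = num_coefficients(lmax=2, mmax=1)  # l=0: 1, l=1: 3, l=2: 3
print(full, reduced)  # prints "9 7"
```

Under this truncation the mmax=1 case carries 7 coefficients per channel instead of 9, so any kernel that assumes a fixed width of 9 would need either a parameterized stride or a separate code path.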

Option 2: Add an mmax=1 fast path

If the current Triton kernels assume mmax=2 shapes, a sibling kernel path for mmax=1 would be useful.

The eSEN line of models commonly uses this style of configuration, so this would likely benefit more than just this single custom model.

Expected behavior

A model with:

lmax = 2
mmax = 1

should either:

  1. use the umas_fast_gpu backend when requested, or
  2. fail with a clear explanation that mmax=1 is not supported and why.

Actual behavior

The model cannot use umas_fast_gpu because of the hard gate:

umas_fast_gpu requires lmax==2 and mmax==2

In some paths, this failure is caught and the model silently remains on the slower general backend.

Environment

fairchem-core: 2.19.0
torch: 2.8.0+cu129
GPU: NVIDIA GH200 120GB
CUDA runtime: 13.1
NVIDIA driver: 590.48.01
Architecture: linux aarch64
Model: custom escn_md backbone, lmax=2, mmax=1, no MoLE

Additional context

A minimal answer such as “not planned, because the current kernels fundamentally assume mmax=2” would still be helpful. It would let us decide whether to invest in a downstream PyTorch-only fast path or retrain the model with mmax=2.
