diff --git a/CHANGELOG.md b/CHANGELOG.md index 95f0a58bb7..06e8e26c06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 blocks that preserve 2D and 3D rotational equivariance using a grid-based layout for efficient GPU parallelization, and an emphasis on compact `einsum` operations. +- Flare attention support for both Transolver and GeoTransolver models. ### Changed diff --git a/physicsnemo/experimental/models/geotransolver/gale.py b/physicsnemo/experimental/models/geotransolver/gale.py index 6990ece8fa..4e7286d715 100644 --- a/physicsnemo/experimental/models/geotransolver/gale.py +++ b/physicsnemo/experimental/models/geotransolver/gale.py @@ -29,16 +29,17 @@ from jaxtyping import Float import physicsnemo # noqa: F401 for docs -from physicsnemo.core.version_check import check_version_spec +from physicsnemo.core.version_check import check_version_spec, OptionalImport from physicsnemo.nn import Mlp from physicsnemo.nn.module.physics_attention import ( PhysicsAttentionIrregularMesh, ) +from physicsnemo.experimental.models.geotransolver.gale_fa import GALE_FA + # Check optional dependency availability TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) -if TE_AVAILABLE: - import transformer_engine.pytorch as te +te = OptionalImport("transformer_engine.pytorch", "0.1.0") class GALE(PhysicsAttentionIrregularMesh): @@ -317,6 +318,9 @@ class GALE_block(nn.Module): Whether to use Transolver++ features. Default is ``False``. context_dim : int, optional Dimension of the context vector for cross-attention. Default is 0. + attention_type : str, optional + attention_type is used to choose the attention type (GALE or GALE_FA). + Default is ``"GALE"``. 
Forward ------- @@ -369,6 +373,7 @@ def __init__( use_te: bool = True, plus: bool = False, context_dim: int = 0, + attention_type: str = "GALE", ) -> None: super().__init__() @@ -386,17 +391,34 @@ def __init__( else: self.ln_1 = nn.LayerNorm(hidden_dim) - # GALE attention layer - self.Attn = GALE( - hidden_dim, - heads=num_heads, - dim_head=hidden_dim // num_heads, - dropout=dropout, - slice_num=slice_num, - use_te=use_te, - plus=plus, - context_dim=context_dim, - ) + # Attention layer + match attention_type: + case 'GALE': + self.Attn = GALE( + hidden_dim, + heads=num_heads, + dim_head=hidden_dim // num_heads, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + case 'GALE_FA': + self.Attn = GALE_FA( + hidden_dim, + heads=num_heads, + dim_head=hidden_dim // num_heads, + dropout=dropout, + n_global_queries=slice_num, + use_te=use_te, + context_dim=context_dim, + ) + case _: + raise ValueError( + f"Invalid attention type: {attention_type}. " + f"Expected 'GALE' or 'GALE_FA'." + ) # Feed-forward network with layer normalization if use_te: diff --git a/physicsnemo/experimental/models/geotransolver/gale_fa.py b/physicsnemo/experimental/models/geotransolver/gale_fa.py new file mode 100644 index 0000000000..d6f2536788 --- /dev/null +++ b/physicsnemo/experimental/models/geotransolver/gale_fa.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GALE_FA (Geometry-Aware Latent Embeddings with FLARE self-Attention) attention layer. + +This module provides the GALE_FA attention mechanism, +an alternative to the GALE attention mechanism of the GeoTransolver. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Float + +from physicsnemo.core.version_check import check_version_spec, OptionalImport + +# Check optional dependency availability +TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) +te = OptionalImport("transformer_engine.pytorch", "0.1.0") + + +class GALE_FA(nn.Module): + r"""GALE_FA: Geometry-Aware Latent Embeddings with FLARE self-Attention attention layer. + Adopted: + - FLARE attention: Fast Low-rank Attention Routing Engine + paper: https://arxiv.org/abs/2508.12594 + - GeoTransolver context: + paper: https://arxiv.org/abs/2512.20399 + + GALE_FA is an alternative to the GALE attention mechanism of the GeoTransolver + It supports cross-attention with a context vector, built from geometry and global embeddings. + GALE_FA combines FLARE self-attention on learned physical state slices with cross-attention + to geometry-aware context, using a learnable mixing weight to blend the two. + + Parameters + ---------- + dim : int + Input dimension of the features. + heads : int, optional + Number of attention heads. Default is 8. + dim_head : int, optional + Dimension of each attention head. Default is 64. + dropout : float, optional + Dropout rate. 
Default is 0.0. + n_global_queries : int, optional + Number of learned global queries. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is False. + context_dim : int, optional + Dimension of the context vector for cross-attention. Default is 0. + + Forward + ------- + x : tuple[torch.Tensor, ...] + Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is + batch size, :math:`N` is number of tokens, and :math:`C` is number of channels. + context : tuple[torch.Tensor, ...] | None, optional + Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where + :math:`H` is number of heads, :math:`S_c` is number of context slices, and + :math:`D_c` is context dimension. If ``None``, only self-attention is applied. + Default is ``None``. + + Outputs + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. + + Notes + ----- + The mixing between self-attention and cross-attention is controlled by a learnable + parameter ``state_mixing`` which is passed through a sigmoid function to ensure + the mixing weight stays in :math:`[0, 1]`. + + See Also + -------- + :class:`GALE` : Original GeoTransolver GALE attention class. + :class:`GALE_block` : Transformer block that calls GALE or GALE_FA attention. + + Examples + -------- + >>> import torch + >>> gale_fa = GALE_FA(dim=256, heads=8, dim_head=32, context_dim=32) + >>> x = (torch.randn(2, 100, 256),) # Single input tensor in tuple + >>> context = torch.randn(2, 8, 64, 32) # Context for cross-attention + >>> outputs = gale_fa(x, context) + >>> len(outputs) + 1 + >>> outputs[0].shape + torch.Size([2, 100, 256]) + """ + + def __init__( + self, + dim, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + n_global_queries: int = 64, + use_te: bool = True, + context_dim: int = 0, + ): + if use_te: + raise ValueError( + "GALE_FA does not support Transformer Engine backend. 
" + "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention." + ) + super().__init__() + self.use_te = use_te + self.heads = heads + self.dim_head = dim_head + self.scale = 1.0 + # It is recommended by the FLARE authors to use self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5) + # but we use self.scale = 1.0 because the recommended scaling is not tested yet. + inner_dim = dim_head * heads + + linear_layer = te.Linear if self.use_te else nn.Linear + + # Global queries for FLARE self-attention + self.q_global = nn.Parameter(torch.randn(1, heads, n_global_queries, dim_head)) + + # Linear projections for self-attention + self.in_project_x = linear_layer(dim, inner_dim) + self.self_k = linear_layer(dim_head, dim_head) + self.self_v = linear_layer(dim_head, dim_head) + + if context_dim > 0: + # Linear projections for cross-attention + self.cross_q = linear_layer(dim_head, dim_head) + self.cross_k = linear_layer(context_dim, dim_head) + self.cross_v = linear_layer(context_dim, dim_head) + + # Learnable mixing weight between self and cross attention + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + + # te attention + if self.use_te: + self.attn_fn = te.DotProductAttention( + num_attention_heads=self.heads, + kv_channels=self.dim_head, + attention_dropout=dropout, + qkv_format="bshd", + softmax_scale=self.scale + ) + + # Linear projection for output + self.out_linear = linear_layer(inner_dim, dim) + self.out_dropout = nn.Dropout(dropout) + + + def forward( + self, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None = None, + ) -> list[Float[torch.Tensor, "batch tokens channels"]]: + r"""Forward pass of the GALE_FA module. + + Applies GALE_FA attention to the input features. + + Parameters + ---------- + x : tuple[torch.Tensor, ...] 
+ Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` + is batch size, :math:`N` is number of tokens, and :math:`C` is number + of channels. + context : torch.Tensor | None, optional + Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` + where :math:`H` is number of heads, :math:`S_c` is number of context + slices, and :math:`D_c` is context dimension. If ``None``, only + self-attention is applied. Default is ``None``. + + Returns + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)``, same shape + as inputs. + """ + + # with record_function("forward"): + x_mid = [self.in_project_x(_x) for _x in x] + x_mid = [rearrange( + _x_mid, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head + ) for _x_mid in x_mid] + x_mid = [_x_mid.permute(0, 2, 1, 3) for _x_mid in x_mid] # [B, H, N, D] + G = [self.q_global.to(dtype=x_mid[0].dtype).expand(x_mid[0].shape[0], -1, -1, -1)] * len(x) + k = [self.self_k(_x_mid) for _x_mid in x_mid] + v = [self.self_v(_x_mid) for _x_mid in x_mid] + + # FLARE: Self Attention + if self.use_te: + # Transformer Engine expects (B, S, H, D) format + G = [rearrange(_G, "b h s d -> b s h d") for _G in G] + k = [rearrange(_k, "b h s d -> b s h d") for _k in k] + v = [rearrange(_v, "b h s d -> b s h d") for _v in v] + z = [self.attn_fn(_G, _k, _v) for _G, _k, _v in zip(G, k, v)] + z = [rearrange( + _z, "b s (h d) -> b s h d", h=self.heads, d=self.dim_head + ) for _z in z] + self_attention = [self.attn_fn(_k, _G, _z) for _k, _G, _z in zip(k, G, z)] + self_attention = [rearrange( + _self_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head + ) for _self_attention in self_attention] + else: + # Use PyTorch's scaled dot-product attention + z = [F.scaled_dot_product_attention(_G, _k, _v, scale=self.scale) for _G, _k, _v in zip(G, k, v)] + self_attention = [F.scaled_dot_product_attention(_k, _G, _z, scale=self.scale) for _k, _G, _z in zip(k, G, z)] + + # apply cross-attention 
with physical states: + if context is not None: + q = [self.cross_q(_x_mid) for _x_mid in x_mid] + k = self.cross_k(context) + v = self.cross_v(context) + + if self.use_te: + q = [rearrange(_q, "b h s d -> b s h d") for _q in q] + k = rearrange(k, "b h s d -> b s h d") + v = rearrange(v, "b h s d -> b s h d") + cross_attention = [self.attn_fn(_q, k, v) for _q in q] + cross_attention = [rearrange( + _cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head + ) for _cross_attention in cross_attention] + else: + cross_attention = [F.scaled_dot_product_attention(_q, k, v, scale=self.scale) for _q in q] + + # Apply learnable mixing: + mixing_weight = torch.sigmoid(self.state_mixing) + outputs = [mixing_weight * _ys + (1 - mixing_weight) * _yc for _ys, _yc in zip(self_attention, cross_attention)] + else: + outputs = self_attention + + outputs = [_y.permute(0, 2, 1, 3) for _y in outputs] # [B, N, H, D] + outputs = [rearrange(_out, "b n h d -> b n (h d)") for _out in outputs] + outputs = [self.out_linear(_out) for _out in outputs] + return [self.out_dropout(_out) for _out in outputs] + diff --git a/physicsnemo/experimental/models/geotransolver/geotransolver.py b/physicsnemo/experimental/models/geotransolver/geotransolver.py index 5b106810df..3a49abdd21 100644 --- a/physicsnemo/experimental/models/geotransolver/geotransolver.py +++ b/physicsnemo/experimental/models/geotransolver/geotransolver.py @@ -204,6 +204,9 @@ class GeoTransolver(Module): Neighbors in radius for the local features. Default is ``[8, 32]``. n_hidden_local : int, optional Hidden dimension for the local features. Default is 32. + attention_type : str, optional + attention_type is used to choose the attention type (GALE or GALE_FA). + Default is ``"GALE"``. 
Forward ------- @@ -315,6 +318,7 @@ def __init__( radii: list[float] | None = None, neighbors_in_radius: list[int] | None = None, n_hidden_local: int = 32, + attention_type: str = "GALE", ) -> None: super().__init__(meta=GeoTransolverMetaData()) self.__name__ = "GeoTransolver" @@ -404,6 +408,7 @@ def __init__( use_te=use_te, plus=plus, context_dim=context_dim, + attention_type=attention_type, ) for layer_idx in range(n_layers) ] diff --git a/physicsnemo/experimental/models/transolver/__init__.py b/physicsnemo/experimental/models/transolver/__init__.py new file mode 100644 index 0000000000..61f4f1e147 --- /dev/null +++ b/physicsnemo/experimental/models/transolver/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Experimental Transolver with configurable attention type. + +This module provides Transolver with ``attention_type`` to select between +physics attention and FLARE. 
+""" + +from .transolver import Transolver + +__all__ = ["Transolver"] diff --git a/physicsnemo/experimental/models/transolver/transolver.py b/physicsnemo/experimental/models/transolver/transolver.py new file mode 100644 index 0000000000..ef9cb82288 --- /dev/null +++ b/physicsnemo/experimental/models/transolver/transolver.py @@ -0,0 +1,231 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Experimental Transolver with configurable attention type. + +Experimental Transolver that inherits from core Transolver and supports +multiple attention backends via ``attention_type``: physics attention or FLARE. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from jaxtyping import Float + +from physicsnemo.models.transolver import Transolver as CoreTransolver +from physicsnemo.models.transolver.transolver import _TransolverMlp + +from physicsnemo.experimental.nn import FLARE + + +class _FLAREBlock(nn.Module): + r"""Transformer block with FLARE attention instead of physics attention. + + Mirrors TransolverBlock structure but uses FLARE for the attention layer. + FLARE does not support Transformer Engine. 
+ """ + + def __init__( + self, + num_heads: int, + hidden_dim: int, + dropout: float, + act: str = "gelu", + mlp_ratio: int = 4, + last_layer: bool = False, + out_dim: int = 1, + n_global_queries: int = 32, + ) -> None: + super().__init__() + self.last_layer = last_layer + dim_head = hidden_dim // num_heads + + self.ln_1 = nn.LayerNorm(hidden_dim) + self.Attn = FLARE( + dim=hidden_dim, + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + n_global_queries=n_global_queries, + use_te=False, + ) + self.ln_mlp1 = nn.Sequential( + nn.LayerNorm(hidden_dim), + _TransolverMlp( + in_features=hidden_dim, + hidden_features=hidden_dim * mlp_ratio, + out_features=hidden_dim, + act_layer=act, + use_te=False, + ), + ) + if last_layer: + self.ln_mlp2 = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, out_dim), + ) + + def forward( + self, fx: Float[torch.Tensor, "B N C"] + ) -> Float[torch.Tensor, "B N C_out"]: + fx = self.Attn(self.ln_1(fx)) + fx + fx = self.ln_mlp1(fx) + fx + if self.last_layer: + return self.ln_mlp2(fx) + return fx + + +class Transolver(CoreTransolver): + r"""Experimental Transolver with configurable attention type. + + Inherits from the core Transolver and adds ``attention_type`` to select + between physics attention (default) and FLARE attention. When + ``attention_type="flare"``, FLARE is used and Transformer Engine is + disabled (use_te forced to False). + + Parameters + ---------- + functional_dim : int + Dimension of input values, not including embeddings. + out_dim : int + Dimension of model output. + embedding_dim : int | None, optional + Dimension of input embeddings. Required if ``unified_pos=False``. + n_layers : int, optional + Number of transformer blocks. Default is 4. + n_hidden : int, optional + Hidden dimension. Default is 256. + dropout : float, optional + Dropout rate. Default is 0.0. + n_head : int, optional + Number of attention heads. Default is 8. + act : str, optional + Activation function name. 
Default is ``"gelu"``. + mlp_ratio : int, optional + MLP hidden ratio. Default is 4. + slice_num : int, optional + Number of physics slices (physics) or global queries (flare). + Default is 32. + unified_pos : bool, optional + Whether to use unified positional embeddings. Default is ``False``. + ref : int, optional + Reference grid size for unified position. Default is 8. + structured_shape : None | tuple[int, ...], optional + Shape of structured data. ``None`` for unstructured. Default is ``None``. + use_te : bool, optional + Whether to use Transformer Engine. Ignored when ``attention_type="flare"``. + Default is ``True``. + time_input : bool, optional + Whether to include time embeddings. Default is ``False``. + plus : bool, optional + Whether to use Transolver++ variant (physics only). Default is ``False``. + attention_type : str, optional + Attention backend: ``"physics"`` (default) or ``"flare"``. + + Forward + ------- + Same as :class:`~physicsnemo.models.transolver.Transolver`. + + Outputs + ------- + Same as :class:`~physicsnemo.models.transolver.Transolver`. + + See Also + -------- + :class:`~physicsnemo.models.transolver.Transolver` : Core Transolver model. + :class:`~physicsnemo.experimental.nn.FLARE` : FLARE attention layer. + """ + + def __init__( + self, + functional_dim: int, + out_dim: int, + embedding_dim: int | None = None, + n_layers: int = 4, + n_hidden: int = 256, + dropout: float = 0.0, + n_head: int = 8, + act: str = "gelu", + mlp_ratio: int = 4, + slice_num: int = 32, + unified_pos: bool = False, + ref: int = 8, + structured_shape: None | tuple[int, ...] 
= None, + use_te: bool = True, + time_input: bool = False, + plus: bool = False, + attention_type: str = "physics", + ) -> None: + if attention_type not in ("physics", "flare"): + raise ValueError( + f"attention_type must be 'physics' or 'flare', got {attention_type!r}" + ) + + # FLARE does not support TE + effective_use_te = use_te if attention_type == "physics" else False + if attention_type == "flare" and use_te: + import warnings + + from physicsnemo.core.warnings import ExperimentalFeatureWarning + + warnings.warn( + "attention_type='flare' requires use_te=False; Transformer Engine " + "is incompatible with FLARE. Forcing use_te=False.", + ExperimentalFeatureWarning, + stacklevel=2, + ) + + super().__init__( + functional_dim=functional_dim, + out_dim=out_dim, + embedding_dim=embedding_dim, + n_layers=n_layers, + n_hidden=n_hidden, + dropout=dropout, + n_head=n_head, + act=act, + mlp_ratio=mlp_ratio, + slice_num=slice_num, + unified_pos=unified_pos, + ref=ref, + structured_shape=structured_shape, + use_te=effective_use_te, + time_input=time_input, + plus=plus, + ) + + self.attention_type = attention_type + + if attention_type == "flare": + # Replace physics attention blocks with FLARE blocks + self.blocks = nn.ModuleList( + [ + _FLAREBlock( + num_heads=n_head, + hidden_dim=n_hidden, + dropout=dropout, + act=act, + mlp_ratio=mlp_ratio, + last_layer=(i == n_layers - 1), + out_dim=out_dim, + n_global_queries=slice_num, + ) + for i in range(n_layers) + ] + ) + self.initialize_weights() diff --git a/physicsnemo/experimental/nn/__init__.py b/physicsnemo/experimental/nn/__init__.py index 6a1305a710..0e9cc95dee 100644 --- a/physicsnemo/experimental/nn/__init__.py +++ b/physicsnemo/experimental/nn/__init__.py @@ -20,3 +20,7 @@ that are under active development. These components may have breaking API changes between releases. 
""" + +from .flare_attention import FLARE + +__all__ = ["FLARE"] diff --git a/physicsnemo/experimental/nn/flare_attention.py b/physicsnemo/experimental/nn/flare_attention.py new file mode 100644 index 0000000000..70a6c14f57 --- /dev/null +++ b/physicsnemo/experimental/nn/flare_attention.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FLARE (Fast Low-rank Attention Routing Engine) attention layer. + +This module provides the FLARE attention mechanism, +an alternative to the PhysicsAttention attention mechanism of the Transolver. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Float + +from physicsnemo.core.version_check import check_version_spec, OptionalImport + +# Check optional dependency availability +TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) +te = OptionalImport("transformer_engine.pytorch", "0.1.0") + + +class FLARE(nn.Module): + r"""FLARE: Fast Low-rank Attention Routing Engine attention layer. + Adopted: + - FLARE attention: Fast Low-rank Attention Routing Engine + paper: https://arxiv.org/abs/2508.12594 + + Parameters + ---------- + dim : int + Input dimension of the features. 
+ heads : int, optional + Number of attention heads. Default is 8. + dim_head : int, optional + Dimension of each attention head. Default is 64. + dropout : float, optional + Dropout rate. Default is 0.0. + n_global_queries : int, optional + Number of learned global queries. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is False. + + Forward + ------- + x : torch.Tensor[Batch, N_points, N_Channels] ([B, N, C]) + Outputs + ------- + torch.Tensor[Batch, N_points, N_Channels] ([B, N, C]) + + Examples + -------- + >>> import torch + >>> flare = FLARE(dim=256, heads=8, dim_head=32) + >>> x = torch.randn(2, 100, 256) + >>> outputs = flare(x) + >>> outputs.shape + torch.Size([2, 100, 256]) + """ + + def __init__( + self, + dim, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + n_global_queries: int = 64, + use_te: bool = True, + ): + if use_te: + raise ValueError( + "FLARE does not support Transformer Engine backend. " + "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention." + ) + super().__init__() + self.use_te = use_te + self.heads = heads + self.dim_head = dim_head + self.scale = 1.0 + # It is recommended by the FLARE authors to use self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5) + # but we use self.scale = 1.0 because the recommended scaling is not tested yet. 
+ inner_dim = dim_head * heads + + linear_layer = te.Linear if self.use_te else nn.Linear + + # Global queries for FLARE self-attention + self.q_global = nn.Parameter(torch.randn(1, heads, n_global_queries, dim_head)) + + # Linear projections for self-attention + self.in_project_x = linear_layer(dim, inner_dim) + self.self_k = linear_layer(dim_head, dim_head) + self.self_v = linear_layer(dim_head, dim_head) + + # te attention + if self.use_te: + self.attn_fn = te.DotProductAttention( + num_attention_heads=self.heads, + kv_channels=self.dim_head, + attention_dropout=dropout, + qkv_format="bshd", + softmax_scale=self.scale + ) + + # Linear projection for output + self.out_linear = linear_layer(inner_dim, dim) + self.out_dropout = nn.Dropout(dropout) + + def forward(self, x: Float[torch.Tensor, "B N C"]) -> Float[torch.Tensor, "B N C"]: + r"""Forward pass of the FLARE module. + + Applies FLARE attention to the input features. + + Parameters + ---------- + x : torch.Tensor[Batch, N_points, N_Channels] ([B, N, C]) + Input tensor of shape :math:`(B, N, C)` where :math:`B` is batch size, + :math:`N` is number of points, and :math:`C` is number of channels. + + Returns + ------- + torch.Tensor[Batch, N_points, N_Channels] ([B, N, C]) + Output tensor of shape :math:`(B, N, C)`, same shape as inputs. 
+ """ + + x_mid = self.in_project_x(x) + x_mid = rearrange( + x_mid, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head + ) + x_mid = x_mid.permute(0, 2, 1, 3) # [B, H, N, D] + G = self.q_global.to(dtype=x_mid.dtype).expand(x_mid.shape[0], -1, -1, -1) + k = self.self_k(x_mid) + v = self.self_v(x_mid) + + # FLARE: Fast Low-rank Attention Routing Engine + if self.use_te: + # Transformer Engine expects (B, S, H, D) format + G = rearrange(G, "b h s d -> b s h d") + k = rearrange(k, "b h s d -> b s h d") + v = rearrange(v, "b h s d -> b s h d") + z = self.attn_fn(G, k, v) + z = rearrange( + z, "b s (h d) -> b s h d", h=self.heads, d=self.dim_head + ) + self_attention = self.attn_fn(k, G, z) + y = rearrange( + self_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head + ) + else: + # Use PyTorch's scaled dot-product attention + z = F.scaled_dot_product_attention(G, k, v, scale=self.scale) + y = F.scaled_dot_product_attention(k, G, z, scale=self.scale) + + out_x = y.permute(0, 2, 1, 3) # [B, N, H, D] + out_x = rearrange(out_x, "b n h d -> b n (h d)") + out_x = self.out_linear(out_x) + return self.out_dropout(out_x) diff --git a/test/experimental/models/__init__.py b/test/experimental/models/__init__.py new file mode 100644 index 0000000000..af85283aa4 --- /dev/null +++ b/test/experimental/models/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/experimental/models/transolver/data/transolver2d_flare_output.pth b/test/experimental/models/transolver/data/transolver2d_flare_output.pth new file mode 100644 index 0000000000..107fbc5779 Binary files /dev/null and b/test/experimental/models/transolver/data/transolver2d_flare_output.pth differ diff --git a/test/experimental/models/transolver/data/transolver_irregular_flare_output.pth b/test/experimental/models/transolver/data/transolver_irregular_flare_output.pth new file mode 100644 index 0000000000..90fcb6f7fa Binary files /dev/null and b/test/experimental/models/transolver/data/transolver_irregular_flare_output.pth differ diff --git a/test/experimental/models/transolver/test_experimental_transolver_with_flare_attention.py b/test/experimental/models/transolver/test_experimental_transolver_with_flare_attention.py new file mode 100644 index 0000000000..7ec9de4e37 --- /dev/null +++ b/test/experimental/models/transolver/test_experimental_transolver_with_flare_attention.py @@ -0,0 +1,409 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random + +import pytest +import torch + +from physicsnemo.core.module import Module +from physicsnemo.experimental.models.transolver import Transolver +from test.common import ( + check_ort_version, + validate_amp, + validate_checkpoint, + validate_combo_optims, + validate_cuda_graphs, + validate_forward_accuracy, + validate_jit, + validate_onnx_export, + validate_onnx_runtime, +) +from test.conftest import requires_module + + +@pytest.mark.parametrize("attention_type", ["physics", "flare"]) +@pytest.mark.parametrize( + "config", + ["default_structured", "custom_irregular"], + ids=["with_defaults_structured", "with_custom_irregular"], +) +def test_transolver_constructor(attention_type, config): + """Test Transolver model constructor and attributes per MOD-008a.""" + if config == "default_structured": + # Test with structured 2D data and default parameters + model = Transolver( + functional_dim=3, + out_dim=1, + structured_shape=(64, 64), + unified_pos=True, + use_te=False, + attention_type=attention_type, + ) + # Verify default attribute values + assert model.n_hidden == 256, "Default n_hidden should be 256" + assert model.time_input is False, "Default time_input should be False" + assert model.unified_pos is True + assert model.structured_shape == (64, 64) + assert model.embedding_dim == 64 # ref * ref = 8 * 8 = 64 + assert len(model.blocks) == 4, "Default n_layers should be 4" + else: + # Test with irregular mesh data and custom parameters + model = Transolver( + functional_dim=2, + out_dim=4, + embedding_dim=3, + n_layers=8, + n_hidden=64, + dropout=0.1, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=16, + unified_pos=False, + structured_shape=None, + use_te=False, + time_input=True, + plus=True, + attention_type=attention_type, + ) + # Verify custom attribute values + assert model.n_hidden == 64 + assert model.time_input is True + assert model.unified_pos is False + assert model.structured_shape is None + assert model.embedding_dim == 3 + assert 
len(model.blocks) == 8 + + # Common assertions for all configurations + assert isinstance(model, Module), ( + "Transolver should inherit from physicsnemo.Module" + ) + assert hasattr(model, "preprocess"), "Model should have preprocess MLP" + assert hasattr(model, "blocks"), "Model should have transformer blocks" + assert hasattr(model, "meta"), "Model should have metadata" + + +@pytest.mark.parametrize("attention_type", ["physics", "flare"]) +def test_transolver2d_forward(device, attention_type): + """Test Transolver2D forward pass""" + torch.manual_seed(0) + # Construct Transolver model + file_name = ( + "models/transolver/data/transolver2d_output.pth" + if attention_type == "physics" + else "experimental/models/transolver/data/transolver2d_flare_output.pth" + ) + model = Transolver( + structured_shape=(85, 85), + n_layers=8, + n_hidden=64, + dropout=0, + n_head=4, + time_input=False, + act="gelu", + mlp_ratio=1, + functional_dim=1, + out_dim=1, + slice_num=32, + ref=1, + unified_pos=True, + use_te=False, + attention_type=attention_type, + ).to(device) + + bsize = 4 + + fx = torch.randn(bsize, 85 * 85, 1).to(device) + embedding = torch.randn(bsize, 85, 85).to(device) + + assert validate_forward_accuracy( + model, + ( + fx, + embedding, + ), + file_name=file_name, + atol=2e-3, + ) + + +@pytest.mark.parametrize("attention_type", ["physics", "flare"]) +def test_transolver_irregular_forward(device, attention_type): + """Test Transolver Irregular forward pass""" + torch.manual_seed(0) + # Construct Transolver model + file_name = ( + "models/transolver/data/transolver_irregular_output.pth" + if attention_type == "physics" + else "experimental/models/transolver/data/transolver_irregular_flare_output.pth" + ) + model = Transolver( + structured_shape=None, + n_layers=8, + n_hidden=64, + dropout=0, + n_head=4, + time_input=False, + act="gelu", + mlp_ratio=1, + functional_dim=2, + embedding_dim=3, + out_dim=1, + slice_num=32, + ref=1, + unified_pos=False, + use_te=False, + 
        attention_type=attention_type,
+    ).to(device)
+
+    bsize = 4
+
+    embedding = torch.randn(bsize, 12345, 3).to(device)
+    functional_input = torch.randn(bsize, 12345, 2).to(device)
+
+    assert validate_forward_accuracy(
+        model,
+        (
+            embedding,
+            functional_input,
+        ),
+        file_name=file_name,
+        atol=1e-3,
+    )
+
+
+@pytest.mark.parametrize("attention_type", ["physics", "flare"])
+def test_transolver_optims(device, attention_type):
+    """Test transolver optimizations"""
+
+    def setup_model():
+        """Sets up fresh transolver model and inputs for each optim test"""
+
+        model = Transolver(
+            structured_shape=None,
+            n_layers=8,
+            n_hidden=64,
+            dropout=0,
+            n_head=4,
+            time_input=False,
+            act="gelu",
+            mlp_ratio=1,
+            functional_dim=2,
+            embedding_dim=3,
+            out_dim=1,
+            slice_num=32,
+            ref=1,
+            unified_pos=False,
+            use_te=False,
+            attention_type=attention_type,
+        ).to(device)
+
+        if device == "cuda:0":
+            bsize = 4
+            n_points = 12345
+        else:
+            bsize = 1
+            n_points = 123
+
+        embedding = torch.randn(bsize, n_points, 3).to(device)
+        functional_input = torch.randn(bsize, n_points, 2).to(device)
+
+        return model, embedding, functional_input
+
+    # Ideally always check graphs first
+    model, pos, invar = setup_model()
+    assert validate_cuda_graphs(
+        model,
+        (
+            pos,
+            invar,
+        ),
+    )
+
+    # Check JIT
+    model, pos, invar = setup_model()
+    assert validate_jit(
+        model,
+        (
+            pos,
+            invar,
+        ),
+    )
+    # Check AMP
+    model, pos, invar = setup_model()
+    assert validate_amp(
+        model,
+        (
+            pos,
+            invar,
+        ),
+    )
+    # Check Combo
+    model, pos, invar = setup_model()
+    assert validate_combo_optims(
+        model,
+        (
+            pos,
+            invar,
+        ),
+    )
+
+
+@requires_module("transformer_engine")
+@pytest.mark.parametrize("attention_type", ["physics"])
+def test_transolver_te(pytestconfig, attention_type):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
+    torch.manual_seed(0)
+
+    model = Transolver(
+        structured_shape=None,
+        n_layers=8,
+        n_hidden=64,
+        dropout=0,
+        n_head=4,
+
        time_input=False,
+        act="gelu",
+        mlp_ratio=1,
+        functional_dim=2,
+        embedding_dim=3,
+        out_dim=1,
+        slice_num=32,
+        ref=1,
+        unified_pos=False,
+        use_te=True,
+        attention_type=attention_type,
+    ).to("cuda")
+
+    bsize = 4
+
+    embedding = torch.randn(bsize, 12345, 3).to("cuda")
+    functional_input = torch.randn(bsize, 12345, 2).to("cuda")
+
+    assert validate_forward_accuracy(
+        model,
+        (
+            embedding,
+            functional_input,
+        ),
+        file_name="models/transolver/data/transolver_irregular_te_output.pth",
+        atol=1e-3,
+    )
+
+
+@pytest.mark.parametrize("attention_type", ["physics", "flare"])
+def test_transolver_checkpoint(device, attention_type):
+    """Test transolver checkpoint save/load"""
+    # Construct transolver models
+    model_1 = Transolver(
+        structured_shape=None,
+        n_layers=8,
+        n_hidden=64,
+        dropout=0,
+        n_head=4,
+        time_input=False,
+        act="gelu",
+        mlp_ratio=1,
+        functional_dim=2,
+        embedding_dim=3,
+        out_dim=1,
+        slice_num=32,
+        ref=1,
+        unified_pos=False,
+        use_te=False,
+        attention_type=attention_type,
+    ).to(device)
+
+    model_2 = Transolver(
+        structured_shape=None,
+        n_layers=8,
+        n_hidden=64,
+        dropout=0,
+        n_head=4,
+        time_input=False,
+        act="gelu",
+        mlp_ratio=1,
+        functional_dim=2,
+        embedding_dim=3,
+        out_dim=1,
+        slice_num=32,
+        ref=1,
+        unified_pos=False,
+        use_te=False,
+        attention_type=attention_type,
+    ).to(device)
+
+    bsize = random.randint(1, 2)
+
+    embedding = torch.randn(bsize, 12345, 3).to(device)
+    functional_input = torch.randn(bsize, 12345, 2).to(device)
+
+    assert validate_checkpoint(
+        model_1,
+        model_2,
+        (
+            embedding,
+            functional_input,
+        ),
+    )
+
+
+@check_ort_version()
+@pytest.mark.parametrize("attention_type", ["physics", "flare"])
+def test_transolver_deploy(device, attention_type):
+    """Test transolver deployment support"""
+    # Construct transolver model
+    model = Transolver(
+        structured_shape=(85, 85),
+        n_layers=8,
+        n_hidden=64,
+        dropout=0,
+        n_head=4,
+        time_input=False,
+        act="gelu",
+        mlp_ratio=1,
+        functional_dim=1,
+
        out_dim=1,
+        slice_num=32,
+        ref=1,
+        unified_pos=True,
+        use_te=False,
+        attention_type=attention_type,
+    ).to(device)
+
+    bsize = 4
+
+    pos = torch.randn(bsize, 85 * 85, 1).to(device)
+    invar = torch.randn(bsize, 85, 85).to(device)
+
+    assert validate_onnx_export(
+        model,
+        (
+            pos,
+            invar,
+        ),
+    )
+    assert validate_onnx_runtime(
+        model,
+        (
+            pos,
+            invar,
+        ),
+        1e-2,
+        1e-2,
+    )
diff --git a/test/experimental/nn/test_flare_attention.py b/test/experimental/nn/test_flare_attention.py
new file mode 100644
index 0000000000..6ae2a5ae74
--- /dev/null
+++ b/test/experimental/nn/test_flare_attention.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for FLARE attention layer.""" + +import pytest +import torch + +from physicsnemo.experimental.nn import FLARE + + +def test_flare_forward(device): + """Test FLARE forward pass and output shape.""" + torch.manual_seed(42) + flare = FLARE(dim=64, heads=4, dim_head=16, n_global_queries=32, use_te=False).to( + device + ) + x = torch.randn(2, 100, 64).to(device) + out = flare(x) + assert out.shape == (2, 100, 64) + assert not torch.isnan(out).any() + + +@pytest.mark.parametrize("heads,dim_head", [(2, 32), (8, 8), (4, 16)]) +def test_flare_configs(device, heads, dim_head): + """Test FLARE with different head configurations.""" + torch.manual_seed(42) + dim = heads * dim_head + flare = FLARE( + dim=dim, heads=heads, dim_head=dim_head, n_global_queries=16, use_te=False + ).to(device) + x = torch.randn(2, 50, dim).to(device) + out = flare(x) + assert out.shape == x.shape + + +def test_flare_use_te_raises(): + """Test that use_te=True raises ValueError.""" + with pytest.raises(ValueError, match="does not support Transformer Engine"): + FLARE(dim=64, heads=4, dim_head=16, use_te=True) + + +def test_flare_gradient_flow(device): + """Test gradient flow through FLARE.""" + torch.manual_seed(42) + flare = FLARE(dim=32, heads=4, dim_head=8, use_te=False).to(device) + x = torch.randn(2, 20, 32, device=device, requires_grad=True) + out = flare(x) + loss = out.sum() + loss.backward() + assert x.grad is not None + assert not torch.isnan(x.grad).any() diff --git a/test/models/geotransolver/test_gale.py b/test/models/geotransolver/test_gale.py index 9167b1f39e..af334f7447 100644 --- a/test/models/geotransolver/test_gale.py +++ b/test/models/geotransolver/test_gale.py @@ -14,12 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest import torch from physicsnemo.experimental.models.geotransolver.gale import ( GALE, GALE_block, ) +from physicsnemo.experimental.models.geotransolver.gale_fa import GALE_FA # ============================================================================= # GALE (Geometry-Aware Latent Embeddings) Attention Tests @@ -128,13 +130,118 @@ def test_gale_forward_multiple_inputs(device): assert not torch.isnan(outputs[1]).any() +# ============================================================================= +# GALE_FA Attention Tests +# ============================================================================= + + +def test_gale_fa_forward_basic(device): + """Test GALE_FA attention layer pass without context.""" + torch.manual_seed(42) + + dim = 64 + heads = 4 + dim_head = 16 + n_global_queries = 8 + batch_size = 2 + n_tokens = 100 + + gale_fa = GALE_FA( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=0.0, + n_global_queries=n_global_queries, + use_te=False, + context_dim=dim_head, # Must match dim_head for cross attention + ).to(device) + + # Single input tensor wrapped in tuple + x = torch.randn(batch_size, n_tokens, dim).to(device) + + outputs = gale_fa((x,), context=None) + + assert len(outputs) == 1 + assert outputs[0].shape == (batch_size, n_tokens, dim) + assert not torch.isnan(outputs[0]).any() + + +def test_gale_fa_forward_with_context(device): + """Test GALE_FA attention layer with cross-attention context.""" + torch.manual_seed(42) + + dim = 64 + heads = 4 + dim_head = 16 + n_global_queries = 8 + batch_size = 2 + n_tokens = 100 + context_tokens = 32 + context_dim = dim_head + + gale_fa = GALE_FA( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=0.0, + n_global_queries=n_global_queries, + use_te=False, + context_dim=context_dim, + ).to(device) + + x = torch.randn(batch_size, n_tokens, dim).to(device) + context = torch.randn(batch_size, heads, context_tokens, context_dim).to(device) + + outputs = gale_fa((x,), context=context) + + 
assert len(outputs) == 1 + assert outputs[0].shape == (batch_size, n_tokens, dim) + assert not torch.isnan(outputs[0]).any() + + +def test_gale_fa_forward_multiple_inputs(device): + """Test GALE_FA attention layer with multiple input tensors.""" + torch.manual_seed(42) + + dim = 64 + heads = 4 + dim_head = 16 + n_global_queries = 8 + batch_size = 2 + n_tokens_1 = 100 + n_tokens_2 = 150 + context_dim = dim_head + + gale_fa = GALE_FA( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=0.0, + n_global_queries=n_global_queries, + use_te=False, + context_dim=context_dim, + ).to(device) + + x1 = torch.randn(batch_size, n_tokens_1, dim).to(device) + x2 = torch.randn(batch_size, n_tokens_2, dim).to(device) + + outputs = gale_fa((x1, x2), context=None) + + assert len(outputs) == 2 + assert outputs[0].shape == (batch_size, n_tokens_1, dim) + assert outputs[1].shape == (batch_size, n_tokens_2, dim) + assert not torch.isnan(outputs[0]).any() + assert not torch.isnan(outputs[1]).any() + + # ============================================================================= # GALE_block Tests # ============================================================================= -def test_gale_block_forward(device): - """Test GALE_block transformer block forward pass.""" +@pytest.mark.parametrize("attention_type", ["GALE", "GALE_FA"]) +def test_gale_block_forward(device, attention_type): + """Test GALE_block transformer block forward pass (GALE and GALE_FA).""" torch.manual_seed(42) hidden_dim = 64 @@ -156,6 +263,7 @@ def test_gale_block_forward(device): use_te=False, plus=False, context_dim=context_dim, + attention_type=attention_type, ).to(device) x = torch.randn(batch_size, n_tokens, hidden_dim).to(device) @@ -168,8 +276,9 @@ def test_gale_block_forward(device): assert not torch.isnan(outputs[0]).any() -def test_gale_block_multiple_inputs(device): - """Test GALE_block with multiple input tensors.""" +@pytest.mark.parametrize("attention_type", ["GALE", "GALE_FA"]) +def 
test_gale_block_multiple_inputs(device, attention_type): + """Test GALE_block with multiple input tensors and attention type (GALE and GALE_FA).""" torch.manual_seed(42) hidden_dim = 64 @@ -192,6 +301,7 @@ def test_gale_block_multiple_inputs(device): use_te=False, plus=False, context_dim=context_dim, + attention_type=attention_type, ).to(device) x1 = torch.randn(batch_size, n_tokens_1, hidden_dim).to(device) diff --git a/test/models/geotransolver/test_geotransolver.py b/test/models/geotransolver/test_geotransolver.py index c0a8d33968..124827a208 100644 --- a/test/models/geotransolver/test_geotransolver.py +++ b/test/models/geotransolver/test_geotransolver.py @@ -36,9 +36,10 @@ # ============================================================================= +@pytest.mark.parametrize("attention_type", ["GALE", "GALE_FA"]) @pytest.mark.parametrize("use_geometry", [False, True]) @pytest.mark.parametrize("use_global", [False, True]) -def test_geotransolver_forward(device, use_geometry, use_global): +def test_geotransolver_forward(device, attention_type, use_geometry, use_global): """Test GeoTransolver model forward pass with optional geometry and global context.""" torch.manual_seed(42) @@ -65,6 +66,7 @@ def test_geotransolver_forward(device, use_geometry, use_global): time_input=False, plus=False, include_local_features=False, + attention_type=attention_type, ).to(device) local_emb = torch.randn(batch_size, n_tokens, 32).to(device)