GeoFLARE: added GALE_FA, an alternate attention to GALE, for GeoTransolver #1405
@@ -0,0 +1,253 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GAFLARE (Geometry-Aware FLARE) attention layer and transformer block.

This module provides the GAFLARE attention mechanism,
an alternative to the GALE attention mechanism of the GeoTransolver.
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float

import physicsnemo  # noqa: F401 for docs
from physicsnemo.core.version_check import check_version_spec

# Check optional dependency availability
TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False)
if TE_AVAILABLE:
    import transformer_engine.pytorch as te
class GAFLARE(nn.Module):
| r"""GAFLARE: Geometry-Aware FLARE attention layer. | ||||||||||||||
| Adopted: | ||||||||||||||
| - FLARE attention: Fast Low-rank Attention Routing Engine | ||||||||||||||
| paper: https://arxiv.org/abs/2508.12594 | ||||||||||||||
| - GeoTransolver context: | ||||||||||||||
| paper: https://arxiv.org/abs/2512.20399 | ||||||||||||||
|
|
||||||||||||||
| GAFLARE is an alternative to the GALE attention mechanism of the GeoTransolver | ||||||||||||||
| It support cross-attention with a context vector, built from geometry and global embeddings. | ||||||||||||||
mnabian marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||
| GAFLARE combines self-attention on learned physical state slices with cross-attention | ||||||||||||||
| to geometry-aware context, using a learnable mixing weight to blend the two. | ||||||||||||||
|
|
||||||||||||||
| Parameters | ||||||||||||||
| ---------- | ||||||||||||||
| dim : int | ||||||||||||||
| Input dimension of the features. | ||||||||||||||
| heads : int, optional | ||||||||||||||
| Number of attention heads. Default is 8. | ||||||||||||||
| dim_head : int, optional | ||||||||||||||
| Dimension of each attention head. Default is 64. | ||||||||||||||
| dropout : float, optional | ||||||||||||||
| Dropout rate. Default is 0.0. | ||||||||||||||
| slice_num : int, optional | ||||||||||||||
| Number of learned queries. Default is 64. | ||||||||||||||
| use_te : bool, optional | ||||||||||||||
| Whether to use Transformer Engine backend when available. Default is False. | ||||||||||||||
| context_dim : int, optional | ||||||||||||||
| Dimension of the context vector for cross-attention. Default is 0. | ||||||||||||||
|
|
||||||||||||||
| Forward | ||||||||||||||
| ------- | ||||||||||||||
| x : tuple[torch.Tensor, ...] | ||||||||||||||
| Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is | ||||||||||||||
| batch size, :math:`N` is number of tokens, and :math:`C` is number of channels. | ||||||||||||||
| context : tuple[torch.Tensor, ...] | None, optional | ||||||||||||||
| Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where | ||||||||||||||
| :math:`H` is number of heads, :math:`S_c` is number of context slices, and | ||||||||||||||
| :math:`D_c` is context dimension. If ``None``, only self-attention is applied. | ||||||||||||||
| Default is ``None``. | ||||||||||||||
|
|
||||||||||||||
| Outputs | ||||||||||||||
| ------- | ||||||||||||||
| list[torch.Tensor] | ||||||||||||||
| List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. | ||||||||||||||
|
|
||||||||||||||
| Notes | ||||||||||||||
| ----- | ||||||||||||||
| The mixing between self-attention and cross-attention is controlled by a learnable | ||||||||||||||
| parameter ``state_mixing`` which is passed through a sigmoid function to ensure | ||||||||||||||
| the mixing weight stays in :math:`[0, 1]`. | ||||||||||||||
|
|
||||||||||||||
| See Also | ||||||||||||||
| -------- | ||||||||||||||
| :class:`GALE` : Origional GeoTransolver GALE attention class. | ||||||||||||||
| :class:`GALE_block` : Transformer block using GAFLARE attention. | ||||||||||||||
|
||||||||||||||
| :class:`GALE` : Origional GeoTransolver GALE attention class. | |
| :class:`GALE_block` : Transformer block using GAFLARE attention. | |
| :class:`GALE` : Original GeoTransolver GALE attention class. | |
| :class:`GAFLARE_block` : Transformer block using GAFLARE attention. |
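The sigmoid-gated blend described in the docstring's Notes section can be sketched in plain Python. This is a minimal, hypothetical stand-in for the module internals (`state_mixing` and the per-tensor outputs are illustrative, not the actual GAFLARE implementation):

```python
import math

def blend(self_attn_out, cross_attn_out, state_mixing):
    """Blend self- and cross-attention outputs with a sigmoid-gated weight.

    w = sigmoid(state_mixing) lies strictly in (0, 1), so the result is
    always a convex combination of the two attention branches.
    """
    w = 1.0 / (1.0 + math.exp(-state_mixing))
    return [w * s + (1.0 - w) * c for s, c in zip(self_attn_out, cross_attn_out)]

# state_mixing = 0.0 gives w = 0.5: an even blend of both branches.
print(blend([2.0, 4.0], [0.0, 0.0], 0.0))  # [1.0, 2.0]
```

Because the gate is learned, training can push the blend toward either branch without ever leaving the [0, 1] interval.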
Origional -> Original
Incorrect cross-reference
GALE_block calls GAFLARE, and GALE is the alternative attention mechanism. Should we change it to:

See Also
--------
:class:`GALE` : Original GeoTransolver GALE attention class.
:class:`GALE_block` : Transformer block **that calls GALE or GAFLARE** attention.
I suggest we explicitly raise an error when use_te=True, instead of silently setting it to False.
Added:

if use_te:
    raise ValueError(
        "GAFLARE does not support Transformer Engine backend. "
        "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention."
    )
use_te parameter ignored - hardcoded to False on line 121, making the use_te constructor parameter ineffective.
Suggested change:
-    self.use_te = False  # te will disable FlashAttention for different size of q and k
-    self.scale = 1.  # dim_head**-0.5
+    super().__init__()
+    self.use_te = use_te and TE_AVAILABLE
raise ValueError() added.
self.scale set to 1.0 but commented code suggests it should be dim_head**-0.5 for proper attention scaling. Current implementation may affect attention quality.
Suggested change:
-    self.scale = 1.  # dim_head**-0.5
+    self.scale = dim_head**-0.5
Updated to `self.scale = 1.  # FLARE scale is 1.0`; the FLARE reference code uses a scale of 1.0.
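For context on this scale discussion, here is a toy, pure-Python illustration (hypothetical numbers, not project code) of how the classic `dim_head**-0.5` factor tempers dot-product logits before softmax, compared with FLARE's scale of 1.0:

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

dim_head = 64
logits = [8.0, 0.0]  # raw q.k dot products; their magnitude grows with dim_head

unscaled = softmax([l * 1.0 for l in logits])             # FLARE: scale = 1.0
scaled = softmax([l * dim_head ** -0.5 for l in logits])  # classic: dim_head**-0.5

# The unscaled distribution is much sharper (closer to one-hot) than the scaled one.
print(round(unscaled[0], 4), round(scaled[0], 4))
```

Whether the sharper, unscaled attention is desirable depends on the method; the reply above indicates FLARE deliberately uses 1.0.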
If context_dim=0, these layers are silently created, and will fail later (or produce garbage) if a non-empty context tensor is passed. A better design would skip creating the cross-attention layers entirely when context_dim=0. Or you can explicitly raise an error:

if context is not None and self.context_dim == 0:
    raise ValueError(...)
Added if context_dim > 0: to skip creating the cross-attention layers.
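The guard pattern agreed on here can be sketched with a simplified stand-in class (not the actual GAFLARE internals; the attribute names are illustrative):

```python
class CrossAttnSketch:
    """Only build cross-attention weights when context_dim > 0, and refuse a
    context tensor when those layers were never created."""

    def __init__(self, dim, context_dim=0):
        self.context_dim = context_dim
        if context_dim > 0:
            # Placeholder for e.g. nn.Linear(context_dim, dim) projections.
            self.to_kv_context = ("linear", context_dim, dim)

    def forward(self, x, context=None):
        if context is not None and self.context_dim == 0:
            raise ValueError(
                "Received a context tensor but context_dim=0; "
                "cross-attention layers were not created."
            )
        return x  # attention math elided

layer = CrossAttnSketch(dim=128, context_dim=0)
print(layer.forward([1, 2, 3]))  # self-attention-only path works
try:
    layer.forward([1, 2, 3], context=[4, 5])
except ValueError as e:
    print("rejected:", e)
```

This combines both suggestions from the thread: skip layer creation when `context_dim=0`, and fail loudly rather than producing garbage if a context is passed anyway.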
pass in self.scale.
Done
Missing newline at end of file
Suggested change (adds a trailing newline):
-print(outputs[0].shape)
+print(outputs[0].shape)
@@ -35,6 +35,8 @@
     PhysicsAttentionIrregularMesh,
 )

+from physicsnemo.experimental.models.geotransolver.gaflare import GAFLARE
+
 # Check optional dependency availability
 TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False)
 if TE_AVAILABLE:

@@ -369,6 +371,7 @@ def __init__(
     use_te: bool = True,
     plus: bool = False,
     context_dim: int = 0,
+    attention_type: str = "GALE",
 ) -> None:
     super().__init__()

@@ -387,7 +390,8 @@ def __init__(
     self.ln_1 = nn.LayerNorm(hidden_dim)

-    # GALE attention layer
-    self.Attn = GALE(
+    if attention_type in globals():
+        self.Attn = globals()[attention_type](
         hidden_dim,
         heads=num_heads,
         dim_head=hidden_dim // num_heads,
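As an aside on the `globals()` lookup in the diff above: an explicit registry makes the valid choices visible and fails loudly on typos, whereas `globals()` silently matches any name in the module. A hypothetical sketch (the class bodies are stand-ins, not the real attention layers):

```python
class GALE: ...
class GAFLARE: ...

# Explicit registry of supported attention types; globals() would also pick up
# unrelated module-level names and hide typos until an AttributeError downstream.
ATTENTION_REGISTRY = {"GALE": GALE, "GAFLARE": GAFLARE}

def build_attention(attention_type: str, **kwargs):
    try:
        cls = ATTENTION_REGISTRY[attention_type]
    except KeyError:
        raise ValueError(
            f"Unknown attention_type {attention_type!r}; "
            f"expected one of {sorted(ATTENTION_REGISTRY)}"
        ) from None
    return cls(**kwargs)

print(type(build_attention("GAFLARE")).__name__)  # GAFLARE
```

A misspelled `attention_type` then produces a clear `ValueError` listing the supported options instead of a missing-attribute failure later.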
There is a newer, simpler syntax we can use here with OptionalImport.
Added the following to gaflare.py and gale.py:

from physicsnemo.core.version_check import OptionalImport

te = OptionalImport("transformer_engine.pytorch", "0.1.0")
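The `OptionalImport` helper referenced above follows the common optional-dependency pattern. A generic sketch of the underlying idea using only the standard library (this is not the physicsnemo implementation, which also checks a minimum version):

```python
import importlib

def optional_import(module_name):
    """Return (module, True) if the module is importable, else (None, False)."""
    try:
        return importlib.import_module(module_name), True
    except ImportError:
        return None, False

math_mod, have_math = optional_import("math")
te_mod, have_te = optional_import("transformer_engine.pytorch")

print(have_math)  # True: math is in the standard library
print(te_mod.__name__ if have_te else "transformer_engine not installed")
```

The guarded flag then lets callers branch on availability (as `TE_AVAILABLE` does earlier in this diff) instead of crashing at import time.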