Skip to content

Add AeroJEPA model + SuperWing tutorial recipe (experimental)#1690

Open
fgiral000 wants to merge 74 commits into
NVIDIA:mainfrom
fgiral000:aerojepa-integration
Open

Add AeroJEPA model + SuperWing tutorial recipe (experimental)#1690
fgiral000 wants to merge 74 commits into
NVIDIA:mainfrom
fgiral000:aerojepa-integration

Conversation

@fgiral000

Copy link
Copy Markdown

PhysicsNeMo Pull Request

Description

Adds the AeroJEPA model and a SuperWing tutorial recipe under
physicsnemo.experimental and examples/cfd/external_aerodynamics/.
AeroJEPA is a Joint-Embedding Predictive Architecture for 3D
aerodynamic surrogate modeling: instead of mapping geometry directly to
a flow field, it predicts a latent representation of the flow from a
latent representation of the geometry and operating conditions, and
reconstructs the field through a continuous implicit decoder when
needed (Giral et al., arXiv:2605.05586).

What this PR delivers:

  • Model at physicsnemo.experimental.models.aerojepa.
    AeroJEPA composes a context encoder, a target encoder, a query-token
    field decoder (collectively AeroJEPATrunk), and a JEPA predictor
    head (PrototypeTokenJEPAHead) into a single
    physicsnemo.core.module.Module. The training path takes context
    positions/features, independent target encoder surface/volume inputs,
    and operating conditions; the predictor predicts target tokens, and
    the decoder evaluates the field at user-supplied query points.
    predict is a no-grad inference wrapper; decode_field_chunked
    supports memory-bounded evaluation over very large query sets.
    Concrete encoders (ContextTransformer, TargetTransformer,
    PointTransformer), the QueryTokenDecoder, and the encoder ABCs
    are all exposed as composable components.
  • Building blocks at
    physicsnemo.experimental.models.aerojepa.layers. TokenSet and
    EncoderOutput token dataclasses, a deterministic
    FourierPositionalEncoding, ResidualMLP, the
    LocalPointTransformerBlock / LocalTokenCrossAttentionBlock
    attention blocks (with optional AdaLN / AdaLN-Zero conditioning), the
    PointCloudTokenizer (seven center-selection strategies with k-NN
    cluster pooling), token batching / mask / k-NN helpers, and prototype
    anchor build / load utilities. TokenSet and EncoderOutput are
    re-exported from the model package for convenience.
  • Losses at physicsnemo.experimental.models.aerojepa.losses.
    SIGReg and TokenLatentSIGReg (a sketch isotropic-Gaussian
    regularizer for latent-token distributions, with a padding-aware
    wrapper), the flatten_valid_token_features /
    reshape_token_features_for_sigreg masking helpers, and the
    reconstruction loss family (MSELoss / RelativeL2Loss /
    RelativeMSELoss / RelativeL2MSELoss, each with functional and
    nn.Module forms, optional per-channel weights stored as a
    persistent buffer, optional per-point weights, and an optional
    validity mask).
  • Tutorial recipe at
    examples/cfd/external_aerodynamics/aerojepa. End-to-end Hydra-driven
    workflow on the public SuperWing dataset (Yang et al.,
    arXiv:2512.14397): dataset download via the Hugging Face Hub
    (yunplus/SuperWing), automatic split-by-geometry manifest and
    per-channel normalization stats, JEPA training (reconstruction +
    latent + SIGReg with linear warmups; AdamW +
    warmup-cosine; optional EMA), checkpointed inference with chunked
    decoding, three-panel GT | Pred | |Error| field plots for the three
    surface channels (Cp, Cf_tau, Cf_z), per-channel relative-L2 /
    RMSE / MAE metrics on the test split, and a pressure-only CL/CD
    post-processor that integrates the surface field and emits a per-case
    CSV plus a parity scatter.

Checklist

Tests

  • 193 unit tests under test/experimental/models/aerojepa/
    (constructor + attribute checks, non-regression shape checks on the
    encoders, decoder, predictor, trunk, top-level model, layers, and
    losses). pytest test/experimental/models/aerojepa/ -q passes
    locally on CPU (~20 s).
  • Full SuperWing end-to-end smoke-tested on a single GPU:
    train.py -> inference.py -> superwing_metrics -> superwing_forces.
    Training losses decrease monotonically; inference produces field
    plots, per-case field-error metrics, and a force-coefficient parity
    scatter.

Dependencies

No new core dependencies. The example recipe adds optional
example-side dependencies in
examples/cfd/external_aerodynamics/aerojepa/requirements.txt
(Hugging Face Hub for the dataset download, plotting and
post-processing utilities). Pre-commit hooks, ruff, interrogate,
markdownlint, and the SPDX license check pass on every file in the
PR.

@copy-pr-bot

copy-pr-bot Bot commented Jun 1, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented Jun 1, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces the AeroJEPA model — a Joint-Embedding Predictive Architecture for 3D aerodynamic surrogate modeling — along with all its building blocks (encoders, decoder, predictor, layers, losses) under physicsnemo.experimental, plus a full Hydra-driven SuperWing tutorial recipe and 193 unit tests.

  • Model core (aerojepa.py, trunk.py, predictor.py, decoder.py): The context/target encoder–predictor–decoder pipeline is well-structured; batched and single-sample forward paths handle edge cases correctly.
  • Training recipe (train.py, runtime.py): The validation path in _run_epoch builds full autograd graphs without a torch.no_grad() guard, and get_autocast_context exposes fp16 without a paired GradScaler.
  • masked_mean: Returns (B, F) for rank-3 no-mask input but documents (B, 1, F).

Important Files Changed

Filename Overview
physicsnemo/experimental/models/aerojepa/aerojepa.py Top-level AeroJEPA Module composing trunk + predictor; forward/predict/decode_field_chunked paths look correct; build_target_token_coords uses a private _tokenize_single method (noqa-suppressed).
physicsnemo/experimental/models/aerojepa/trunk.py AeroJEPATrunk wiring encoder/decoder; encode_context, decode_queries, forward_single/forward_batch all look correct.
physicsnemo/experimental/models/aerojepa/decoder.py QueryTokenDecoder with chunked cross-attention, SIREN options, wall-velocity gate, and batched forward; logic appears sound.
physicsnemo/experimental/models/aerojepa/predictor.py PrototypeTokenJEPAHead with interleaved self/cross attention; batch handling and conditioning logic look correct.
physicsnemo/experimental/models/aerojepa/layers/token_utils.py Batch flattening, k-NN, and TokenSet utilities; masked_mean has a docstring/implementation shape inconsistency for rank-3 no-mask input (returns (B,F) not (B,1,F) as documented).
examples/cfd/external_aerodynamics/aerojepa/train.py Hydra training entry point; validation forward pass in _run_epoch builds unnecessary autograd graphs because there is no torch.no_grad() guard when is_train=False, wasting GPU memory.
examples/cfd/external_aerodynamics/aerojepa/src/training/runtime.py get_autocast_context enables fp16 autocast without a paired GradScaler; safe with the default bf16 config but could silently corrupt training if users set precision: fp16.

Comments Outside Diff (1)

  1. physicsnemo/experimental/models/aerojepa/layers/token_utils.py, line 1350-1356 (link)

    P2 masked_mean return shape mismatch between mask=None and mask≠None paths for rank-3 input

    The docstring states the function returns (B, 1, F) for rank-3 input, but the mask is None branch uses features.mean(dim=1) (no keepdim) and actually returns (B, F). The masked branch correctly uses keepdim=True and returns (B, 1, F). This inconsistency could cause silent shape mismatches if a caller passes rank-3 features without a mask and expects the documented (B, 1, F) layout.

Reviews (1): Last reviewed commit: "changelog: move SuperWing recipe bullet ..." | Re-trigger Greptile

Comment on lines +95 to +107
if precision_l == "fp16":
return torch.autocast(device_type="cuda", dtype=torch.float16)
return contextlib.nullcontext()


def build_lr_scheduler(
optimizer: torch.optim.Optimizer,
*,
name: str,
epochs: int,
steps_per_epoch: int,
warmup_epochs: float = 5.0,
warmup_steps: int | None = None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 fp16 + autocast without GradScaler

get_autocast_context enables torch.autocast with torch.float16 when precision="fp16", but the training loop in train.py contains no torch.cuda.amp.GradScaler. Without the scaler, fp16 gradients can overflow to inf/nan and the optimizer step silently corrupts parameters. The default config uses bf16 (which shares fp32's dynamic range and doesn't need a scaler), so normal runs are unaffected, but any user who changes precision: fp16 in the training YAML will experience silent training failures.

Comment on lines +312 to +345
loss_cfg: DictConfig,
epoch: int,
max_batches: int | None,
) -> dict[str, float]:
is_train = optimizer is not None
model.train(is_train)

totals: dict[str, float] = {
"loss": 0.0,
"recon": 0.0,
"latent": 0.0,
"sigreg": 0.0,
}
n_samples = 0

for batch_idx, batch in enumerate(loader):
if max_batches is not None and batch_idx >= int(max_batches):
break
batch = move_batch_to_device(batch, device)
if is_train:
optimizer.zero_grad(set_to_none=True)

sample_losses: list[torch.Tensor] = []
for sample_idx in range(int(batch["context_pos"].shape[0])):
sample = _slice_batch_sample(batch, sample_idx)
with get_autocast_context(device, precision):
pred_field, pred_features, target_tokens, _, _ = _forward_sample(
model, sample
)
loss, parts = _compute_total_loss(
pred_field=pred_field,
query_target=sample["query_target"],
pred_features=pred_features,
target_tokens=target_tokens,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Validation forward pass builds unnecessary computation graphs

_run_epoch runs _forward_sample and _compute_total_loss without a torch.no_grad() guard when is_train=False. PyTorch therefore builds and retains the full autograd graph for every validation sample, but backward() is never called. The graph is held until sample_losses is reset at the next batch boundary, so peak extra memory is one batch's computation graph. With eval_batch_size=1 this is modest, but larger evaluation batch sizes or models could trigger OOM. Wrapping the inner loop body (or at minimum the _forward_sample call) with torch.no_grad() or torch.inference_mode() when not is_train would eliminate this overhead.

@peterdsharpe

Copy link
Copy Markdown
Collaborator

Hi @fgiral000, thanks for the PR! To keep PR size reviewable, would it be possible to:

a) split this PR up into two separate PRs, one of which adds the model ("PR 1"), and a later follow-on that adds the example ("PR 2").

b) In PR 1, please re-use shared PhysicsNeMo tooling where possible. (E.g., _gpu_knn.py should re-use existing KNN implementations in physicsnemo.nn.functional; conditioning MLPs should use FullyConnected, many losses duplicate existing code)

c) All functions should use jaxtyping annotations for tensor shapes. Please use Literal types for enumerations rather than str, etc.

d) In PR 2, if possible, please add AeroJEPA as an example within ./examples/external_aerodynamics/unified_external_aero_recipe/, rather than as a standalone aerojepa folder.

@peterdsharpe

Copy link
Copy Markdown
Collaborator

Actually, it might be worth splitting out a third PR as well for addition of the SuperWing dataset utils.

@mnabian mnabian self-requested a review June 1, 2026 18:40
Comment thread CHANGELOG.md Outdated
Comment thread CHANGELOG.md Outdated
Comment thread CHANGELOG.md Outdated
Comment thread CHANGELOG.md Outdated

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used at all? I suggest deleting this file and the corresponding test_gpu_knn.py test if it's not used anywhere.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 2a64458 -- _gpu_knn.py and test_gpu_knn.py are deleted; nothing referenced them after the k-NN consolidation.

return flat_coords, flat_offset_coords, batch_ids


def chunked_knn_indices(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can reduce this to a thin wrapper over physicsnemo.nn.functional.knn. Call knn(points=key_coords, queries=query_coords, k=k·dilation), drop the returned distances, then apply the dilation stride + k_eff clamp in the wrapper. This deletes _chunked_knn_indices_cpu, _chunked_knn_indices_gpu, the AE_KNN_BACKEND env var, and the os/scipy imports from token_utils.py, while the 4 call sites keep their current signature. With this, AeroJEPA gains the cuML path and compile-safety for free.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 2a64458. chunked_knn_indices is now a thin wrapper: it calls knn(points=key_coords, queries=query_coords, k=k·dilation), drops the distances, and applies the dilation stride + k_eff clamp. _chunked_knn_indices_cpu/_gpu, the AE_KNN_BACKEND env var, and the os/scipy imports are gone, and the 4 call sites keep their signature. (The helper now lives in experimental/nn/point_utils.py after the layer move below.)

from jaxtyping import Float


class FourierPositionalEncoding(nn.Module):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the new FourierPositionalEncoding class implemented here: #1695

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 5aba97f. Now uses physicsnemo.nn.FourierPositionalEmbedding from #1695 (merged). The local layer and its test are removed.

return x + out


class LocalPointTransformerBlock(nn.Module):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any of the layer that can be general-purpose and are not AeroJEPA-specific should live outside of the models folder in experimental/nn/. For example, it seems ResidualMLP, LocalPointTransformerBlock, LocalTokenCrossAttentionBlock are among the general-purpose layers and need to be moved.

@mnabian mnabian Jun 4, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the general-purpose layers — ResidualMLP, the two Point-Transformer attention blocks, PointCloudTokenizer, and the generic batch/gather helpers in token_utils — to experimental/nn/, splitting token_utils so the TokenSet-coupled helpers stay. TokenSet/EncoderOutput/prototype_anchors stay as the AeroJEPA-specific contract.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in d1e883c. Moved ResidualMLP, LocalPointTransformerBlock, LocalTokenCrossAttentionBlock, and PointCloudTokenizer to physicsnemo.experimental.nn, and split token_utils so the generic batch/gather/mask/k-NN helpers go to experimental/nn/point_utils.py while the TokenSet-coupled ones stay. TokenSet / EncoderOutput / prototype_anchors remain as the AeroJEPA-specific contract.

from .token_utils import chunked_knn_indices


def _farthest_point_sampling(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use this new implementation in physicsnemo: #1696
Using warp implementation gives pretty nice perf improvements:

case torch (ms) warp (ms) speedup match
small-p1024-d3-k128 5.384 0.221 24.31× ok
medium-p4096-d3-k512 21.491 1.771 12.13× ok
large-p16384-d3-k1024 43.121 15.273 2.82× ok
batched-b4-p4096-d3-k512 21.407 1.772 12.08× ok

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 45b98c3. The tokenizer now calls physicsnemo.nn.functional.farthest_point_sampling (auto Warp/torch dispatch).

return positions, features


def _farthest_point_sampling(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a duplicate implementation of FPS, without random start, correct? See https://github.com/fgiral000/physicsnemo/blob/3a01a2cde7a9678249a128b3d5b575c1b3298488/physicsnemo/experimental/models/aerojepa/layers/point_tokenizer.py#L37. Consider removing this.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the duplicate in 82f4a4d. After the upstream FPS swap (45b98c3), a small seeded helper remains here only for offline anchor generation: it must be reproducible from a seed (k-means init + empty-cluster refill), and physicsnemo.nn.functional.farthest_point_sampling doesn't expose a seed for its random start. The runtime tokenizer uses the shared primitive.

else _make_conditioning_mlp(int(conditioning_dim), 3 * int(dim))
)
self.adaln_zero = bool(adaln_zero)
self.net = nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's MLP layer:

self.net = Mlp(
    in_features=int(dim),
    hidden_features=hidden,      # single int -> one hidden layer dim->hidden
    out_features=int(dim),
    act_layer=nn.GELU,
    drop=float(dropout),
    final_dropout=True,
)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3 — converted to physicsnemo.nn.Mlp.

self.context_in = nn.Linear(self.token_dim, self.hidden_dim)
self.cond_proj = None
if self.cond_dim > 0:
self.cond_proj = nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp: Mlp(cond_dim, hidden, hidden, act_layer=nn.SiLU)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

raise ValueError(
"gen_conditioning_dim must be provided when use_gen_conditioning=True."
)
self.gen_proj = nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp: Mlp(gen_dim, hidden, token_dim, act_layer=nn.SiLU)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

+ (cond_dim if self.mask_head_use_cond else 0)
+ (1 if getattr(self.decoder, "use_sdf", True) else 0)
)
self.mask_head = nn.Sequential(

@mnabian mnabian Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp: Mlp(mask_in, [hidden, hidden], 1, act_layer=nn.SiLU)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

omega_0=float(pressure_head_siren_omega0),
)
else:
self.p_head = nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp: Mlp(hidden, p_hidden, 1, act_layer=nn.SiLU)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

self.update_mlps = nn.ModuleList()
for _ in range(self.num_layers):
self.msg_mlps.append(
nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

)
)
self.gate_mlps.append(
nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

)
)
self.update_mlps.append(
nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

nn.GELU(),
nn.Linear(self.dim, self.dim),
)
self.attn_proj = nn.Sequential(

@mnabian mnabian Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp: Mlp(dim, dim, num_heads, act_layer=nn.GELU)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

self.q_proj = nn.Linear(self.dim, self.dim)
self.k_proj = nn.Linear(self.dim, self.dim)
self.v_proj = nn.Linear(self.dim, self.dim)
self.pos_proj = nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp: Mlp(3, dim, dim, act_layer=nn.GELU)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

nn.GELU(),
nn.Linear(self.dim, self.dim),
)
self.attn_proj = nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp: Mlp(dim, dim, num_heads, act_layer=nn.GELU)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

self.q_proj = nn.Linear(self.dim, self.dim)
self.k_proj = nn.Linear(self.dim, self.dim)
self.v_proj = nn.Linear(self.dim, self.dim)
self.pos_proj = nn.Sequential(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use physicsnemo's Mlp: Mlp(3, dim, dim, act_layer=nn.GELU)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a531ff3.

)


class SineLayer(nn.Module):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop SineLayer, use physicsnemo.nn.SirenLayer; keep SirenHead as a thin local composition but build it from the library's SirenLayer instead of the local one.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in bb18fc8SineLayer is removed; SirenHead is kept as a thin local composition built from physicsnemo.nn.module.siren_layers.SirenLayer.

]
)
self.trunk = nn.Sequential(
nn.LayerNorm(int(token_dim)),

@mnabian mnabian Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
LayerNorm(int(token_dim))

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 0ed9438 — now uses physicsnemo.nn.module.layer_norm.LayerNorm.

):
super().__init__()
hidden = max(1, int(mlp_ratio)) * int(dim)
self.norm = nn.LayerNorm(int(dim))

@mnabian mnabian Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm = LayerNorm(int(dim))

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 0ed9438.

self.neighbor_k = int(neighbor_k)
self.dilation = int(max(1, dilation))
self.knn_chunk_size = int(knn_chunk_size)
self.norm = nn.LayerNorm(self.dim)

@mnabian mnabian Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm = LayerNorm(self.dim)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 0ed9438.

self.head_dim = self.dim // self.num_heads
self.neighbor_k = int(neighbor_k)
self.knn_chunk_size = int(knn_chunk_size)
self.norm_q = nn.LayerNorm(self.dim)

@mnabian mnabian Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm_q = LayerNorm(self.dim)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 0ed9438.

fgiral000 and others added 29 commits June 10, 2026 18:14
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…recipe to CHANGELOG

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…est split

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…sion bump

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…IDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
… strategies (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…ype_cluster (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…f torch.nn.LayerNorm (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…branch (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…ranch (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…A#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
… rows (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…oint_tokenizer FPS (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…drop local SineLayer (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…ential MLPs (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…knn, drop local backends (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…e (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…de_geometry (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…im=0 (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…k stacks (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…r save/load round-trip (PR NVIDIA#1690 review)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
The attention blocks, point-cloud tokenizer, Fourier positional encoding,
and the batch/gather/mask/k-NN helpers are domain-agnostic and reusable, so
they move from the model-specific layers package up to
physicsnemo.experimental.nn. The two TokenSet-coupled helpers
(pad_token_sets, trim_batched_tokens) stay in the layers package. Moved
symbols are re-exported from the layers package __init__ for
backward-compatible imports, and the tests are split to match.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Collapse the four verbose AeroJEPA bullets (model, losses, building blocks,
recipe) into two concise entries - one for the library additions.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…sampling

Replace the local farthest-point-sampling implementation in the point-cloud
tokenizer with the shared physicsnemo.nn.functional primitive, which gains the
Warp-accelerated CUDA backend and torch.compile-safety for free and
auto-dispatches between Warp and a torch baseline.

Offline prototype-anchor generation keeps a small self-contained seeded FPS
helper: it must be reproducible from a seed for the k-means initialization and
empty-cluster refill, and the shared primitive does not expose a seed for its
random start.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…bedding

Replace the local Fourier positional-encoding layer with the shared
physicsnemo.nn.FourierPositionalEmbedding, which provides the same
deterministic axis-wise (NeRF-style) embedding with a matching constructor
surface (in_dim, num_bands, include_input) and out_dim. The query decoder,
the JEPA predictor head, and the point encoder all construct it directly from
physicsnemo.nn; the local layer and its dedicated test are removed (the shared
layer carries its own tests).

The shared layer lays its features out axis-major rather than band-major; the
feature set is identical up to a fixed permutation that the immediately
following linear projection absorbs, so model outputs are unaffected beyond a
re-init of those projections. Model shape / property / checkpoint round-trip
tests are unchanged and pass.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Updated README.md for AeroJEPA tutorial for title case. And a few edits for clarity

(cherry picked from commit b0c7e19)
@fgiral000 fgiral000 force-pushed the aerojepa-integration branch from 4391d7d to 3637750 Compare June 10, 2026 23:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants