Add AeroJEPA model + SuperWing tutorial recipe (experimental) #1690
fgiral000 wants to merge 74 commits into
Greptile Summary
This PR introduces the AeroJEPA model, a Joint-Embedding Predictive Architecture for 3D aerodynamic surrogate modeling, along with all its building blocks (encoders, decoder, predictor, layers, losses) under
Important Files Changed
    if precision_l == "fp16":
        return torch.autocast(device_type="cuda", dtype=torch.float16)
    return contextlib.nullcontext()


def build_lr_scheduler(
    optimizer: torch.optim.Optimizer,
    *,
    name: str,
    epochs: int,
    steps_per_epoch: int,
    warmup_epochs: float = 5.0,
    warmup_steps: int | None = None,
There was a problem hiding this comment.
fp16 + autocast without GradScaler
get_autocast_context enables torch.autocast with torch.float16 when precision="fp16", but the training loop in train.py contains no torch.cuda.amp.GradScaler. Without the scaler, fp16 gradients can overflow to inf/nan, and the optimizer step then silently corrupts the parameters. The default config uses bf16 (which shares fp32's dynamic range and doesn't need a scaler), so normal runs are unaffected, but any user who sets precision: fp16 in the training YAML risks silent training failures.
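One way to address the comment above is sketched here, assuming a plain PyTorch training step. The helper names `make_grad_scaler` and `train_step` are hypothetical and do not come from train.py; the scaler is only enabled for fp16 on CUDA, so bf16 and fp32 runs behave as before.

```python
import contextlib

import torch
from torch import nn


def get_autocast_context(device: torch.device, precision: str):
    """Illustrative autocast selector; only active on CUDA devices."""
    precision_l = precision.lower()
    if device.type == "cuda" and precision_l == "bf16":
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    if device.type == "cuda" and precision_l == "fp16":
        return torch.autocast(device_type="cuda", dtype=torch.float16)
    return contextlib.nullcontext()


def make_grad_scaler(device: torch.device, precision: str):
    """Hypothetical helper: loss scaling for fp16 on CUDA, a no-op otherwise."""
    enabled = device.type == "cuda" and precision.lower() == "fp16"
    if hasattr(torch.amp, "GradScaler"):  # newer PyTorch releases
        return torch.amp.GradScaler("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)  # older releases


def train_step(model, optimizer, scaler, x, y, device, precision) -> float:
    """One optimizer step with (optional) dynamic loss scaling."""
    optimizer.zero_grad(set_to_none=True)
    with get_autocast_context(device, precision):
        loss = nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # identity when the scaler is disabled
    scaler.step(optimizer)         # skips the step on inf/nan grads when enabled
    scaler.update()                # adjusts the scale factor when enabled
    return float(loss.detach())
```

With the scaler disabled, `scale`, `step`, and `update` fall through to the ordinary backward/step path, so the same loop can serve every precision setting.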
    loss_cfg: DictConfig,
    epoch: int,
    max_batches: int | None,
) -> dict[str, float]:
    is_train = optimizer is not None
    model.train(is_train)

    totals: dict[str, float] = {
        "loss": 0.0,
        "recon": 0.0,
        "latent": 0.0,
        "sigreg": 0.0,
    }
    n_samples = 0

    for batch_idx, batch in enumerate(loader):
        if max_batches is not None and batch_idx >= int(max_batches):
            break
        batch = move_batch_to_device(batch, device)
        if is_train:
            optimizer.zero_grad(set_to_none=True)

        sample_losses: list[torch.Tensor] = []
        for sample_idx in range(int(batch["context_pos"].shape[0])):
            sample = _slice_batch_sample(batch, sample_idx)
            with get_autocast_context(device, precision):
                pred_field, pred_features, target_tokens, _, _ = _forward_sample(
                    model, sample
                )
                loss, parts = _compute_total_loss(
                    pred_field=pred_field,
                    query_target=sample["query_target"],
                    pred_features=pred_features,
                    target_tokens=target_tokens,
There was a problem hiding this comment.
Validation forward pass builds unnecessary computation graphs
_run_epoch runs _forward_sample and _compute_total_loss without a torch.no_grad() guard when is_train=False. PyTorch therefore builds and retains the full autograd graph for every validation sample, but backward() is never called. The graph is held until sample_losses is reset at the next batch boundary, so peak extra memory is one batch's computation graph. With eval_batch_size=1 this is modest, but larger evaluation batch sizes or models could trigger OOM. Wrapping the inner loop body (or at minimum the _forward_sample call) with torch.no_grad() or torch.inference_mode() when not is_train would eliminate this overhead.
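The guard suggested above could look roughly like the following minimal sketch. `grad_context` and `run_batch` are hypothetical names standing in for the body of `_run_epoch`; the real loop also handles autocast and the per-sample loss parts.

```python
import contextlib

import torch
from torch import nn


def grad_context(is_train: bool):
    """Track gradients while training; disable autograd for validation."""
    return contextlib.nullcontext() if is_train else torch.no_grad()


def run_batch(
    model: nn.Module, x: torch.Tensor, y: torch.Tensor, is_train: bool
) -> torch.Tensor:
    """Forward pass plus loss, building a graph only when it will be used."""
    model.train(is_train)
    with grad_context(is_train):
        loss = nn.functional.mse_loss(model(x), y)
    return loss
```

In validation mode the returned loss has no `grad_fn`, so no intermediate activations are retained between samples.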
Hi @fgiral000, thanks for the PR! To keep PR size reviewable, would it be possible to:
a) split this PR up into two separate PRs, one of which adds the model ("PR 1"), and a later follow-on that adds the example ("PR 2").
b) In PR 1, please re-use shared PhysicsNeMo tooling where possible. (E.g.,
c) All functions should use
d) In PR 2, if possible, please add AeroJEPA as an example within

Actually, it might be worth splitting out a third PR as well for addition of the SuperWing dataset utils.
There was a problem hiding this comment.
Is this used at all? I suggest deleting this file and the corresponding test_gpu_knn.py test if it's not used anywhere.
There was a problem hiding this comment.
Done in 2a64458 -- _gpu_knn.py and test_gpu_knn.py are deleted; nothing referenced them after the k-NN consolidation.
    return flat_coords, flat_offset_coords, batch_ids


def chunked_knn_indices(
There was a problem hiding this comment.
You can reduce this to a thin wrapper over physicsnemo.nn.functional.knn. Call knn(points=key_coords, queries=query_coords, k=k·dilation), drop the returned distances, then apply the dilation stride + k_eff clamp in the wrapper. This deletes _chunked_knn_indices_cpu, _chunked_knn_indices_gpu, the AE_KNN_BACKEND env var, and the os/scipy imports from token_utils.py, while the 4 call sites keep their current signature. With this, AeroJEPA gains the cuML path and compile-safety for free.
There was a problem hiding this comment.
Done in 2a64458. chunked_knn_indices is now a thin wrapper: it calls knn(points=key_coords, queries=query_coords, k=k·dilation), drops the distances, and applies the dilation stride + k_eff clamp. _chunked_knn_indices_cpu/_gpu, the AE_KNN_BACKEND env var, and the os/scipy imports are gone, and the 4 call sites keep their signature. (The helper now lives in experimental/nn/point_utils.py after the layer move below.)
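The thin-wrapper idea discussed in this thread can be sketched as follows. Since the exact return convention of physicsnemo.nn.functional.knn is not shown here, the sketch uses a hypothetical brute-force stand-in, `_brute_force_knn`, that returns `(indices, distances)`; the dilation stride and the `k_eff` clamp are written as one plausible reading of the description, not as the merged implementation.

```python
import torch


def _brute_force_knn(points: torch.Tensor, queries: torch.Tensor, k: int):
    """Hypothetical stand-in for a shared k-NN primitive (not the library call)."""
    dists = torch.cdist(queries, points)                  # (Q, N) pairwise distances
    values, indices = torch.topk(dists, k, dim=1, largest=False)
    return indices, values                                # nearest first


def chunked_knn_indices(
    query_coords: torch.Tensor,
    key_coords: torch.Tensor,
    k: int,
    dilation: int = 1,
) -> torch.Tensor:
    """Dilated k-NN: fetch k * dilation neighbours, then keep every dilation-th."""
    dilation = max(1, int(dilation))
    # Clamp so we never request more neighbours than there are keys.
    k_eff = min(int(k) * dilation, key_coords.shape[0])
    indices, _ = _brute_force_knn(points=key_coords, queries=query_coords, k=k_eff)
    return indices[:, ::dilation][:, : int(k)]
```

Keeping the dilation logic in the wrapper means the call sites do not need to know which backend produced the neighbour list.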
from jaxtyping import Float


class FourierPositionalEncoding(nn.Module):
There was a problem hiding this comment.
Please use the new FourierPositionalEncoding class implemented here: #1695
        return x + out


class LocalPointTransformerBlock(nn.Module):
There was a problem hiding this comment.
Any layers that are general-purpose and not AeroJEPA-specific should live outside the models folder, in experimental/nn/. For example, ResidualMLP, LocalPointTransformerBlock, and LocalTokenCrossAttentionBlock appear to be general-purpose and should be moved.
There was a problem hiding this comment.
Move the general-purpose layers — ResidualMLP, the two Point-Transformer attention blocks, PointCloudTokenizer, and the generic batch/gather helpers in token_utils — to experimental/nn/, splitting token_utils so the TokenSet-coupled helpers stay. TokenSet/EncoderOutput/prototype_anchors stay as the AeroJEPA-specific contract.
There was a problem hiding this comment.
Done in d1e883c. Moved ResidualMLP, LocalPointTransformerBlock, LocalTokenCrossAttentionBlock, and PointCloudTokenizer to physicsnemo.experimental.nn, and split token_utils so the generic batch/gather/mask/k-NN helpers go to experimental/nn/point_utils.py while the TokenSet-coupled ones stay. TokenSet / EncoderOutput / prototype_anchors remain as the AeroJEPA-specific contract.
from .token_utils import chunked_knn_indices


def _farthest_point_sampling(
There was a problem hiding this comment.
Please use this new implementation in physicsnemo: #1696
Using the Warp implementation gives pretty nice perf improvements:
| case | torch (ms) | warp (ms) | speedup | match |
|---|---|---|---|---|
| small-p1024-d3-k128 | 5.384 | 0.221 | 24.31× | ok |
| medium-p4096-d3-k512 | 21.491 | 1.771 | 12.13× | ok |
| large-p16384-d3-k1024 | 43.121 | 15.273 | 2.82× | ok |
| batched-b4-p4096-d3-k512 | 21.407 | 1.772 | 12.08× | ok |
There was a problem hiding this comment.
Done in 45b98c3. The tokenizer now calls physicsnemo.nn.functional.farthest_point_sampling (auto Warp/torch dispatch).
    return positions, features


def _farthest_point_sampling(
There was a problem hiding this comment.
This is a duplicate implementation of FPS, without random start, correct? See https://github.com/fgiral000/physicsnemo/blob/3a01a2cde7a9678249a128b3d5b575c1b3298488/physicsnemo/experimental/models/aerojepa/layers/point_tokenizer.py#L37. Consider removing this.
There was a problem hiding this comment.
Removed the duplicate in 82f4a4d. After the upstream FPS swap (45b98c3), a small seeded helper remains here only for offline anchor generation: it must be reproducible from a seed (k-means init + empty-cluster refill), and physicsnemo.nn.functional.farthest_point_sampling doesn't expose a seed for its random start. The runtime tokenizer uses the shared primitive.
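For readers unfamiliar with why a seed matters here, the following is a minimal sketch of a seeded farthest-point-sampling helper. The function name and signature are hypothetical (the actual helper in the PR is not shown), and it assumes a non-empty (N, D) point set with `num_samples >= 1`.

```python
import torch


def seeded_farthest_point_sampling(
    points: torch.Tensor, num_samples: int, seed: int
) -> torch.Tensor:
    """Greedy FPS over (N, D) points with a seed-controlled random start."""
    n = points.shape[0]
    num_samples = min(int(num_samples), n)
    # A private generator keeps the start index reproducible and leaves
    # the global RNG state untouched.
    gen = torch.Generator(device="cpu").manual_seed(int(seed))
    start = int(torch.randint(n, (1,), generator=gen))

    selected = torch.empty(num_samples, dtype=torch.long)
    selected[0] = start
    # Squared distance from every point to its nearest selected point.
    min_dist = torch.full((n,), float("inf"), dtype=points.dtype, device=points.device)
    for i in range(1, num_samples):
        last = points[selected[i - 1]]
        dist = ((points - last) ** 2).sum(dim=-1)
        min_dist = torch.minimum(min_dist, dist)
        selected[i] = int(torch.argmax(min_dist))  # farthest remaining point
    return selected
```

Calling it twice with the same seed returns the same indices, which is the property the offline anchor generation relies on.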
            else _make_conditioning_mlp(int(conditioning_dim), 3 * int(dim))
        )
        self.adaln_zero = bool(adaln_zero)
        self.net = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's MLP layer:
self.net = Mlp(
    in_features=int(dim),
    hidden_features=hidden,  # single int -> one hidden layer dim->hidden
    out_features=int(dim),
    act_layer=nn.GELU,
    drop=float(dropout),
    final_dropout=True,
)
        self.context_in = nn.Linear(self.token_dim, self.hidden_dim)
        self.cond_proj = None
        if self.cond_dim > 0:
            self.cond_proj = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(cond_dim, hidden, hidden, act_layer=nn.SiLU)
                raise ValueError(
                    "gen_conditioning_dim must be provided when use_gen_conditioning=True."
                )
            self.gen_proj = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(gen_dim, hidden, token_dim, act_layer=nn.SiLU)
            + (cond_dim if self.mask_head_use_cond else 0)
            + (1 if getattr(self.decoder, "use_sdf", True) else 0)
        )
        self.mask_head = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(mask_in, [hidden, hidden], 1, act_layer=nn.SiLU)
                omega_0=float(pressure_head_siren_omega0),
            )
        else:
            self.p_head = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(hidden, p_hidden, 1, act_layer=nn.SiLU)
        self.update_mlps = nn.ModuleList()
        for _ in range(self.num_layers):
            self.msg_mlps.append(
                nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp.
                )
            )
            self.gate_mlps.append(
                nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp.
                )
            )
            self.update_mlps.append(
                nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp.
            nn.GELU(),
            nn.Linear(self.dim, self.dim),
        )
        self.attn_proj = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(dim, dim, num_heads, act_layer=nn.GELU)
        self.q_proj = nn.Linear(self.dim, self.dim)
        self.k_proj = nn.Linear(self.dim, self.dim)
        self.v_proj = nn.Linear(self.dim, self.dim)
        self.pos_proj = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(3, dim, dim, act_layer=nn.GELU)
            nn.GELU(),
            nn.Linear(self.dim, self.dim),
        )
        self.attn_proj = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(dim, dim, num_heads, act_layer=nn.GELU)
        self.q_proj = nn.Linear(self.dim, self.dim)
        self.k_proj = nn.Linear(self.dim, self.dim)
        self.v_proj = nn.Linear(self.dim, self.dim)
        self.pos_proj = nn.Sequential(
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(3, dim, dim, act_layer=nn.GELU)
        )


class SineLayer(nn.Module):
There was a problem hiding this comment.
Drop SineLayer, use physicsnemo.nn.SirenLayer; keep SirenHead as a thin local composition but build it from the library's SirenLayer instead of the local one.
There was a problem hiding this comment.
Done in bb18fc8 — SineLayer is removed; SirenHead is kept as a thin local composition built from physicsnemo.nn.module.siren_layers.SirenLayer.
            ]
        )
        self.trunk = nn.Sequential(
            nn.LayerNorm(int(token_dim)),
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
LayerNorm(int(token_dim))
There was a problem hiding this comment.
Done in 0ed9438 — now uses physicsnemo.nn.module.layer_norm.LayerNorm.
    ):
        super().__init__()
        hidden = max(1, int(mlp_ratio)) * int(dim)
        self.norm = nn.LayerNorm(int(dim))
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm = LayerNorm(int(dim))
        self.neighbor_k = int(neighbor_k)
        self.dilation = int(max(1, dilation))
        self.knn_chunk_size = int(knn_chunk_size)
        self.norm = nn.LayerNorm(self.dim)
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm = LayerNorm(self.dim)
        self.head_dim = self.dim // self.num_heads
        self.neighbor_k = int(neighbor_k)
        self.knn_chunk_size = int(knn_chunk_size)
        self.norm_q = nn.LayerNorm(self.dim)
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm_q = LayerNorm(self.dim)
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…recipe to CHANGELOG Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…est split Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…sion bump Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…IDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
… strategies (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…ype_cluster (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…f torch.nn.LayerNorm (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…branch (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…ranch (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…A#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
… rows (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…oint_tokenizer FPS (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…drop local SineLayer (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…ential MLPs (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…knn, drop local backends (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…e (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…de_geometry (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…im=0 (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…k stacks (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…r save/load round-trip (PR NVIDIA#1690 review) Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
The attention blocks, point-cloud tokenizer, Fourier positional encoding, and the batch/gather/mask/k-NN helpers are domain-agnostic and reusable, so they move from the model-specific layers package up to physicsnemo.experimental.nn. The two TokenSet-coupled helpers (pad_token_sets, trim_batched_tokens) stay in the layers package. Moved symbols are re-exported from the layers package __init__ for backward-compatible imports, and the tests are split to match. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Collapse the four verbose AeroJEPA bullets (model, losses, building blocks, recipe) into two concise entries - one for the library additions. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…sampling Replace the local farthest-point-sampling implementation in the point-cloud tokenizer with the shared physicsnemo.nn.functional primitive, which gains the Warp-accelerated CUDA backend and torch.compile-safety for free and auto-dispatches between Warp and a torch baseline. Offline prototype-anchor generation keeps a small self-contained seeded FPS helper: it must be reproducible from a seed for the k-means initialization and empty-cluster refill, and the shared primitive does not expose a seed for its random start. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…bedding Replace the local Fourier positional-encoding layer with the shared physicsnemo.nn.FourierPositionalEmbedding, which provides the same deterministic axis-wise (NeRF-style) embedding with a matching constructor surface (in_dim, num_bands, include_input) and out_dim. The query decoder, the JEPA predictor head, and the point encoder all construct it directly from physicsnemo.nn; the local layer and its dedicated test are removed (the shared layer carries its own tests). The shared layer lays its features out axis-major rather than band-major; the feature set is identical up to a fixed permutation that the immediately following linear projection absorbs, so model outputs are unaffected beyond a re-init of those projections. Model shape / property / checkpoint round-trip tests are unchanged and pass. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
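The claim that a fixed feature permutation is absorbed by the following linear projection can be checked numerically. The sketch below builds a toy sin/cos embedding in both layouts; the frequency choice and the exact ordering are assumptions for illustration and are not taken from the shared layer's source.

```python
import torch
from torch import nn

torch.manual_seed(0)
in_dim, num_bands, hidden, n_pts = 3, 4, 8, 5
x = torch.randn(n_pts, in_dim)
freqs = 2.0 ** torch.arange(num_bands)                      # assumed band frequencies

angles = x[:, :, None] * freqs[None, None, :]               # (N, in_dim, bands)
sincos = torch.stack([angles.sin(), angles.cos()], dim=-1)  # (N, in_dim, bands, 2)
axis_major = sincos.reshape(n_pts, -1)                      # axis index varies slowest
band_major = sincos.permute(0, 2, 1, 3).reshape(n_pts, -1)  # band index varies slowest

# perm[j] is the axis-major column that holds band-major column j.
idx = torch.arange(in_dim * num_bands * 2).reshape(in_dim, num_bands, 2)
perm = idx.permute(1, 0, 2).reshape(-1)

proj_band = nn.Linear(band_major.shape[1], hidden)
proj_axis = nn.Linear(axis_major.shape[1], hidden)
with torch.no_grad():
    # Reorder the weight columns so both projections compute the same map.
    proj_axis.bias.copy_(proj_band.bias)
    proj_axis.weight[:, perm] = proj_band.weight

out_band = proj_band(band_major)
out_axis = proj_axis(axis_major)
```

Here `out_band` and `out_axis` agree up to floating-point summation order. A freshly initialised projection simply learns the permuted weights, which is why only a re-init of those projections is needed.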
Updated README.md for the AeroJEPA tutorial to use title case, plus a few edits for clarity (cherry picked from commit b0c7e19)
4391d7d to
3637750
PhysicsNeMo Pull Request
Description
Adds the AeroJEPA model and a SuperWing tutorial recipe under
physicsnemo.experimental and examples/cfd/external_aerodynamics/.
AeroJEPA is a Joint-Embedding Predictive Architecture for 3D
aerodynamic surrogate modeling: instead of mapping geometry directly to
a flow field, it predicts a latent representation of the flow from a
latent representation of the geometry and operating conditions, and
reconstructs the field through a continuous implicit decoder when
needed (Giral et al., arXiv:2605.05586).
What this PR delivers:
- physicsnemo.experimental.models.aerojepa. AeroJEPA composes a context
  encoder, a target encoder, a query-token field decoder (collectively
  AeroJEPATrunk), and a JEPA predictor head (PrototypeTokenJEPAHead) into a
  single physicsnemo.core.module.Module. The training path takes context
  positions/features, independent target encoder surface/volume inputs,
  and operating conditions; the predictor predicts target tokens, and
  the decoder evaluates the field at user-supplied query points.
  predict is a no-grad inference wrapper; decode_field_chunked
  supports memory-bounded evaluation over very large query sets.
  Concrete encoders (ContextTransformer, TargetTransformer,
  PointTransformer), the QueryTokenDecoder, and the encoder ABCs
  are all exposed as composable components.
- physicsnemo.experimental.models.aerojepa.layers. TokenSet and
  EncoderOutput token dataclasses, a deterministic
  FourierPositionalEncoding, ResidualMLP, the
  LocalPointTransformerBlock / LocalTokenCrossAttentionBlock
  attention blocks (with optional AdaLN / AdaLN-Zero conditioning), the
  PointCloudTokenizer (seven center-selection strategies with k-NN
  cluster pooling), token batching / mask / k-NN helpers, and prototype
  anchor build / load utilities. TokenSet and EncoderOutput are
  re-exported from the model package for convenience.
- physicsnemo.experimental.models.aerojepa.losses. SIGReg and
  TokenLatentSIGReg (a sketch isotropic-Gaussian
  regularizer for latent-token distributions, with a padding-aware
  wrapper), the flatten_valid_token_features /
  reshape_token_features_for_sigreg masking helpers, and the
  reconstruction loss family (MSELoss / RelativeL2Loss /
  RelativeMSELoss / RelativeL2MSELoss, each with functional and
  nn.Module forms, optional per-channel weights stored as a
  persistent buffer, optional per-point weights, and an optional
  validity mask).
- examples/cfd/external_aerodynamics/aerojepa. End-to-end Hydra-driven
  workflow on the public SuperWing dataset (Yang et al.,
  arXiv:2512.14397): dataset download via the Hugging Face Hub
  (yunplus/SuperWing), automatic split-by-geometry manifest and
  per-channel normalization stats, JEPA training (reconstruction +
  latent + SIGReg with linear warmups; AdamW +
  warmup-cosine; optional EMA), checkpointed inference with chunked
  decoding, three-panel GT | Pred | |Error| field plots for the three
  surface channels (Cp, Cf_tau, Cf_z), per-channel relative-L2 /
  RMSE / MAE metrics on the test split, and a pressure-only CL/CD
  post-processor that integrates the surface field and emits a per-case
  CSV plus a parity scatter.
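The memory-bounded evaluation mentioned for decode_field_chunked can be illustrated with a generic chunking loop. Everything in this sketch is hypothetical: `ToyDecoder` and `decode_in_chunks` are illustrative stand-ins, not the AeroJEPA API, and the sketch assumes at least one query point.

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Illustrative decoder: field value from query position + pooled tokens."""

    def __init__(self, token_dim: int, out_channels: int = 3):
        super().__init__()
        self.net = nn.Linear(3 + token_dim, out_channels)

    def forward(self, tokens: torch.Tensor, query_pos: torch.Tensor) -> torch.Tensor:
        # Pool the latent tokens and broadcast them to every query point.
        ctx = tokens.mean(dim=0, keepdim=True).expand(query_pos.shape[0], -1)
        return self.net(torch.cat([query_pos, ctx], dim=-1))


@torch.no_grad()
def decode_in_chunks(
    decoder: nn.Module,
    tokens: torch.Tensor,
    query_pos: torch.Tensor,
    chunk_size: int = 65536,
) -> torch.Tensor:
    """Evaluate the decoder over query points in fixed-size chunks."""
    outputs = []
    for start in range(0, query_pos.shape[0], chunk_size):
        chunk = query_pos[start : start + chunk_size]
        outputs.append(decoder(tokens, chunk))
    return torch.cat(outputs, dim=0)
```

Because each query point is decoded independently given the tokens, peak activation memory scales with `chunk_size` rather than with the total number of query points.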
Checklist
Tests
- test/experimental/models/aerojepa/ (constructor + attribute checks,
  non-regression shape checks on the encoders, decoder, predictor,
  trunk, top-level model, layers, and losses).
- pytest test/experimental/models/aerojepa/ -q passes
  locally on CPU (~20 s).
- train.py -> inference.py -> superwing_metrics -> superwing_forces.
  Training losses decrease monotonically; inference produces field
  plots, per-case field-error metrics, and a force-coefficient parity
  scatter.
Dependencies
No new core dependencies. The example recipe adds optional
example-side dependencies in
examples/cfd/external_aerodynamics/aerojepa/requirements.txt
(Hugging Face Hub for the dataset download, plotting and
post-processing utilities). Pre-commit hooks, ruff, interrogate,
markdownlint, and the SPDX license check pass on every file in the
PR.