Skip to content

Commit 7ff2a2a

Browse files
coreyjadamsAlexey-Kamenevjeis4wpi
authored
Refactor (#1208)
* Move filesystems and version_check to core * Fix version check tests * Reorganize distributed, domain_parallel, and begin nn / utils cleanup. * Move modules and meta to core. Move registry to core. No tests fixed yet. * Add missing init files * Update build system and specify some deps. * Reorganize tests. * Update init files * Clean up neighbor tools. * Update testing * Fix compat tests * Move core model tests to tests/core/ * Add import lint config * Relocate layers * Move graphcast utils into model directory * Relocating util functionalities. * Add FIGConvNet to crash example (#1207) * Add FIGConvNet to crash example. * Add FIGConvNet to crash example * Update model config * propose fix some typos (#1209) Signed-off-by: John E <[email protected]> Co-authored-by: Corey adams <[email protected]> * Further clean up and organize tests. * utils tests are passing now * Cleaning up distributed tests * Patching tests working again in nn * Fix sdf test * Fix zenith angle tests * Some organization of tests. Checkpoints is moved into utils. * Remove launch.utils and launch.config. Checkpointing is moved to phsyicsnemo.utils, launch.config is just gone. It was empty. * Most nn tests are passing * Further cleanup. Getting there! * Remove constants file * Add import linting to pre-commit. --------- Signed-off-by: John E <[email protected]> Co-authored-by: Alexey Kamenev <[email protected]> Co-authored-by: John Eismeier <[email protected]>
1 parent 04d5fe9 commit 7ff2a2a

File tree

207 files changed

+1387
-1113
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

207 files changed

+1387
-1113
lines changed

.importlinter

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
[importlinter]
2+
root_package = physicsnemo
3+
include_external_packages = True
4+
5+
[importlinter:contract:physicsnemo-modules]
6+
name = Prevent Upward Imports in the PhysicsNemo Structure
7+
type = layers
8+
containers=
9+
physicsnemo
10+
layers =
11+
experimental
12+
models : registry : datapipes : metrics : domain_parallel
13+
utils
14+
distributed | nn
15+
core
16+
17+
[importlinter:contract:physicsnemo-core]
18+
name = Control Dependencies in PhysicsNeMo core
19+
type = layers
20+
containers=
21+
physicsnemo.core
22+
layers =
23+
module : registry
24+
meta
25+
warnings | version_check | filesystem
26+
27+
28+
[importlinter:contract:physicsnemo-distributed]
29+
name = Control Dependencies in PhysicsNeMo distributed
30+
type = layers
31+
containers=
32+
physicsnemo.distributed
33+
layers =
34+
fft | autograd
35+
mappings
36+
utils
37+
manager
38+
config
39+
40+
[importlinter:contract:physicsnemo-utils]
41+
name = Control Dependencies in PhysicsNeMo utils
42+
type = layers
43+
containers=
44+
physicsnemo.utils
45+
layers =
46+
mesh
47+
profiling
48+
checkpoint
49+
capture
50+
logging | memory
51+
52+
[importlinter:contract:physicsnemo-nn]
53+
name = Control Dependencies in PhysicsNeMo nn
54+
type = layers
55+
containers=
56+
physicsnemo.nn
57+
layers =
58+
fourier_layers | transformer_layers
59+
dgm_layers | mlp_layers | fully_connected_layers
60+
activations | attention_layers | ball_query | conv_layers | drop | fft | fused_silu | insolation | interpolation | kan_layers | patching | resample_layers | sdf | siren_layers | spectral_layers | transformer_decoder | weight_fact | weight_norm | zenith_angle
61+
neighbors
62+
utils
63+
64+
65+
[importlinter:contract:physicsnemo-models]
66+
name = Prevent Imports between physicsnemo models
67+
type = layers
68+
containers=
69+
physicsnemo.models
70+
layers =
71+
mesh_reduced
72+
afno | dlwp | dlwp_healpix | domino | dpot | fengwu | figconvnet | fno | graphcast | meshgraphnet | pangu | pix2pix | rnn | srrn | swinvrnn | topodiff | transolver | vfgn
73+
unet | diffusion | gnn_layers | dlwp_healpix_layers
74+
utils
75+

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,8 @@ repos:
5353
hooks:
5454
- id: check-added-large-files
5555
args: [--maxkb=5000]
56+
57+
- repo: https://github.com/seddonym/import-linter
58+
rev: v2.5.2
59+
hooks:
60+
- id: import-linter

FAQ.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
## What is the recommended hardware for training using PhysicsNeMo framework?
1212

1313
Please refer to the recommended hardware section:
14-
[System Requirments](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html#system-requirements)
14+
[System Requirements](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html#system-requirements)
1515

1616
## What model architectures are in PhysicsNeMo?
1717

1818
Nvidia PhysicsNeMo is built on top of PyTorch and you can build and train any model
1919
architecture you want in PhysicsNeMo. PhysicsNeMo however has a catalog of models that
2020
have been packaged in a configurable form to make it easy to retrain with new data or certain
2121
config parameters. Examples include GNNs like MeshGraphNet or Neural Operators like FNO.
22-
PhysicsNeMo samples have more models that illustrate how a specific approach with a specifc
22+
PhysicsNeMo samples have more models that illustrate how a specific approach with a specific
2323
model architecture can be applied to a specific problem.
2424
These are reference starting points for users to get started.
2525

@@ -47,7 +47,7 @@ that illustrates the concept.
4747

4848
## What can I do if I dont see a PDE in PhysicsNeMo?
4949

50-
PhysicsNeMo Symbolic provides a well documeted
50+
PhysicsNeMo Symbolic provides a well documented
5151
[example](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/foundational/1d_wave_equation.html#writing-custom-pdes-and-boundary-initial-conditions)
5252
that walks you through how to define a custom PDE. Please see the source [here](https://github.com/NVIDIA/physicsnemo-sym/tree/main/physicsnemo/sym/eq/pdes)
5353
to see the built-in PDE implementation as an additional reference for your own implementation.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
_target_: rollout.FIGConvUNetTimeConditionalRollout
18+
_convert_: all
19+
20+
# Input/output channels
21+
in_channels: 2 # thickness + time
22+
out_channels: 3 # displacement offset (xyz)
23+
24+
# Architecture
25+
kernel_size: 3
26+
hidden_channels: [16, 16, 16] # channels at each level
27+
num_levels: 2 # number of down/up levels
28+
num_down_blocks: 1
29+
num_up_blocks: 1
30+
mlp_channels: [256, 256]
31+
32+
# Spatial domain
33+
aabb_max: [2.0, 2.0, 2.0]
34+
aabb_min: [-2.0, -2.0, -2.0]
35+
voxel_size: null
36+
37+
# Grid resolutions (factorized implicit grids)
38+
# Format: Uses res_mem_pair resolver (memory_format_enum, resolution_tuple)
39+
resolution_memory_format_pairs:
40+
- [b_xc_y_z, [2, 64, 64]]
41+
- [b_yc_x_z, [64, 2, 64]]
42+
- [b_zc_x_y, [64, 64, 2]]
43+
44+
# Position encoding
45+
use_rel_pos: true
46+
use_rel_pos_embed: true
47+
pos_encode_dim: 16
48+
49+
# Communication and sampling
50+
communication_types: ["sum"]
51+
to_point_sample_method: "graphconv"
52+
neighbor_search_type: "knn"
53+
knn_k: 16
54+
reductions: ["mean"]
55+
56+
# Pooling (for global features if needed)
57+
pooling_type: "max"
58+
pooling_layers: [2]
59+
60+
# Rollout parameters
61+
num_time_steps: ${training.num_time_steps}

examples/structural_mechanics/crash/datapipe.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,6 @@ def _normalize_node_tensor(
350350
)
351351
return (invar - mu.view(1, 1, -1)) / (std.view(1, 1, -1) + EPS)
352352

353-
@staticmethod
354-
def _normalize_thickness_tensor(
355-
thickness: torch.Tensor, mu: torch.Tensor, std: torch.Tensor
356-
):
357-
# thickness: [N], mu/std: scalar tensors
358-
return (thickness - mu) / (std + EPS)
359-
360353

361354
class CrashGraphDataset(CrashBaseDataset):
362355
"""

examples/structural_mechanics/crash/rollout.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
# limitations under the License.
1616

1717
import torch
18-
import torch.nn as nn
1918
from torch.utils.checkpoint import checkpoint as ckpt
2019

2120
from physicsnemo.models.transolver import Transolver
2221
from physicsnemo.models.meshgraphnet import MeshGraphNet
22+
from physicsnemo.models.figconvnet.figconvunet import FIGConvUNet
2323

2424
from datapipe import SimSample
2525

@@ -406,3 +406,64 @@ def step_fn(nf, ef, g):
406406
y_t0, y_t1 = y_t1, y_t2_pred
407407

408408
return torch.stack(outputs, dim=0) # [T,N,3]
409+
410+
411+
class FIGConvUNetTimeConditionalRollout(FIGConvUNet):
412+
"""
413+
FIGConvUNet with time-conditional rollout for crash simulation.
414+
415+
Predicts each time step independently, conditioned on normalized time.
416+
"""
417+
418+
def __init__(self, *args, **kwargs):
419+
self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
420+
super().__init__(*args, **kwargs)
421+
422+
def forward(
423+
self,
424+
sample: SimSample,
425+
data_stats: dict,
426+
) -> torch.Tensor:
427+
"""
428+
Args:
429+
Sample: SimSample containing node_features and node_target
430+
data_stats: dict containing normalization stats
431+
Returns:
432+
[T, N, 3] rollout of predicted positions
433+
"""
434+
inputs = sample.node_features
435+
x = inputs["coords"] # initial pos [N, 3]
436+
features = inputs.get("features", x.new_zeros((x.size(0), 0))) # [N, F]
437+
438+
outputs: list[torch.Tensor] = []
439+
time_seq = torch.linspace(0.0, 1.0, self.rollout_steps, device=x.device)
440+
441+
for time_t in time_seq:
442+
# Prepare vertices for FIGConvUNet: [1, N, 3]
443+
vertices = x.unsqueeze(0) # [1, N, 3]
444+
445+
# Prepare features: features + time [N, F+1]
446+
time_expanded = time_t.expand(x.size(0), 1) # [N, 1]
447+
features_t = torch.cat([features, time_expanded], dim=-1) # [N, F+1]
448+
features_t = features_t.unsqueeze(0) # [1, N, F+1]
449+
450+
def step_fn(verts, feats):
451+
out, _ = super(FIGConvUNetTimeConditionalRollout, self).forward(
452+
vertices=verts, features=feats
453+
)
454+
return out
455+
456+
if self.training:
457+
outf = ckpt(
458+
step_fn,
459+
vertices,
460+
features_t,
461+
use_reentrant=False,
462+
).squeeze(0) # [N, 3]
463+
else:
464+
outf = step_fn(vertices, features_t).squeeze(0) # [N, 3]
465+
466+
y_t = x + outf
467+
outputs.append(y_t)
468+
469+
return torch.stack(outputs, dim=0) # [T, N, 3]

examples/structural_mechanics/crash/train.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from torch.nn.parallel import DistributedDataParallel
3131
from torch.utils.data.distributed import DistributedSampler
3232
from torch.utils.tensorboard import SummaryWriter
33-
from tqdm import tqdm
3433

3534
from physicsnemo.distributed.manager import DistributedManager
3635
from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
@@ -64,6 +63,11 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper):
6463
f"Model {model_name} requires a point-cloud datapipe, "
6564
f"but you selected {datapipe_name}."
6665
)
66+
if "FIGConvUNet" in model_name and "PointCloudDataset" not in datapipe_name:
67+
raise ValueError(
68+
f"Model {model_name} requires a point-cloud datapipe, "
69+
f"but you selected {datapipe_name}."
70+
)
6771

6872
# Dataset
6973
reader = instantiate(cfg.reader)
@@ -223,7 +227,7 @@ def main(cfg: DictConfig) -> None:
223227
for sample in trainer.dataloader:
224228
sample = sample[0].to(dist.device) # SimSample .to()
225229
loss = trainer.train(sample)
226-
total_loss += loss.item()
230+
total_loss += loss.detach().item()
227231
num_batches += 1
228232

229233
trainer.scheduler.step()

examples/structural_mechanics/crash/vtp_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def process_vtp_data(data_dir, num_samples=2, write_vtp=False, logger=None):
169169

170170
if not vtp_files:
171171
if logger:
172-
logger.error("No .vtp files found in:", base_data_dir)
172+
logger.error(f"No .vtp files found in: {base_data_dir}")
173173
exit(1)
174174

175175
for vtp_path in vtp_files:

physicsnemo/__init__.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,26 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import os
1617

17-
from .datapipes.datapipe import Datapipe
18-
from .datapipes.meta import DatapipeMetaData
19-
from .models.meta import ModelMetaData
20-
from .models.module import Module
18+
# Backwards-compatibility is opt-in. Enable with env var or via enable_compat().
19+
if os.getenv("PHYSICSNEMO_ENABLE_COMPAT") in {
20+
"1",
21+
"true",
22+
"True",
23+
"YES",
24+
"yes",
25+
"on",
26+
"ON",
27+
}:
28+
from .compat import install as _compat_install
29+
30+
_compat_install()
31+
32+
33+
# from .datapipes.datapipe import Datapipe # noqa E402
34+
# from .datapipes.meta import DatapipeMetaData # noqa E402
35+
# from .core.meta import ModelMetaData # noqa E402
36+
# from .core.module import Module # noqa E402
2137

2238
__version__ = "1.3.0a0"

0 commit comments

Comments
 (0)