@@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

_target_: rollout.FIGConvUNetTimeConditionalRollout
_convert_: all

# Input/output channels
in_channels: 2 # thickness + time
out_channels: 3 # displacement offset (xyz)

# Architecture
kernel_size: 3
hidden_channels: [16, 16, 16] # channels at each level
num_levels: 2 # number of down/up levels
num_down_blocks: 1
num_up_blocks: 1
mlp_channels: [256, 256]

# Spatial domain
aabb_max: [2.0, 2.0, 2.0]
aabb_min: [-2.0, -2.0, -2.0]
voxel_size: null

# Grid resolutions (factorized implicit grids)
# Format: each entry is a (memory_format_enum, resolution_tuple) pair, parsed by the res_mem_pair resolver
resolution_memory_format_pairs:
- [b_xc_y_z, [2, 64, 64]]
- [b_yc_x_z, [64, 2, 64]]
- [b_zc_x_y, [64, 64, 2]]

# Position encoding
use_rel_pos: true
use_rel_pos_embed: true
pos_encode_dim: 16

# Communication and sampling
communication_types: ["sum"]
to_point_sample_method: "graphconv"
neighbor_search_type: "knn"
knn_k: 16
reductions: ["mean"]

# Pooling (for global features if needed)
pooling_type: "max"
pooling_layers: [2]

# Rollout parameters
num_time_steps: ${training.num_time_steps}
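
For context, a config like this is consumed by Hydra's instantiate() (as train.py does for cfg.reader). A minimal sketch, assuming a hypothetical file path and a stand-in value for the ${training.num_time_steps} interpolation, which Hydra would normally resolve at composition time:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical path; in the example, Hydra composes the config groups at startup.
cfg = OmegaConf.load("conf/model/figconvunet.yaml")
# Stand-in for the ${training.num_time_steps} interpolation in this sketch.
cfg.num_time_steps = 30
# _target_ and _convert_: all direct instantiate() to build the rollout model.
model = instantiate(cfg)  # -> rollout.FIGConvUNetTimeConditionalRollout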
7 changes: 0 additions & 7 deletions examples/structural_mechanics/crash/datapipe.py
@@ -350,13 +350,6 @@ def _normalize_node_tensor(
)
return (invar - mu.view(1, 1, -1)) / (std.view(1, 1, -1) + EPS)

@staticmethod
def _normalize_thickness_tensor(
thickness: torch.Tensor, mu: torch.Tensor, std: torch.Tensor
):
# thickness: [N], mu/std: scalar tensors
return (thickness - mu) / (std + EPS)


class CrashGraphDataset(CrashBaseDataset):
"""
63 changes: 62 additions & 1 deletion examples/structural_mechanics/crash/rollout.py
@@ -15,11 +15,11 @@
# limitations under the License.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as ckpt

from physicsnemo.models.transolver import Transolver
from physicsnemo.models.meshgraphnet import MeshGraphNet
from physicsnemo.models.figconvnet.figconvunet import FIGConvUNet

from datapipe import SimSample

@@ -406,3 +406,64 @@ def step_fn(nf, ef, g):
y_t0, y_t1 = y_t1, y_t2_pred

return torch.stack(outputs, dim=0) # [T,N,3]


class FIGConvUNetTimeConditionalRollout(FIGConvUNet):
"""
FIGConvUNet with time-conditional rollout for crash simulation.

Predicts each time step independently, conditioned on normalized time.
"""

def __init__(self, *args, **kwargs):
self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
super().__init__(*args, **kwargs)

def forward(
self,
sample: SimSample,
data_stats: dict,
) -> torch.Tensor:
"""
Args:
sample: SimSample containing node_features and node_target
data_stats: dict containing normalization stats
Returns:
[T, N, 3] rollout of predicted positions
"""
inputs = sample.node_features
x = inputs["coords"] # initial pos [N, 3]
features = inputs.get("features", x.new_zeros((x.size(0), 0))) # [N, F]

outputs: list[torch.Tensor] = []
time_seq = torch.linspace(0.0, 1.0, self.rollout_steps, device=x.device)

for time_t in time_seq:
# Prepare vertices for FIGConvUNet: [1, N, 3]
vertices = x.unsqueeze(0) # [1, N, 3]

# Prepare features: features + time [N, F+1]
time_expanded = time_t.expand(x.size(0), 1) # [N, 1]
features_t = torch.cat([features, time_expanded], dim=-1) # [N, F+1]
features_t = features_t.unsqueeze(0) # [1, N, F+1]

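# Wrap the parent forward for checkpointing; the explicit two-arg super() is
# needed because zero-arg super() would misbind inside this nested function
# (it reads the frame's first argument, verts, as the instance).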
def step_fn(verts, feats):
out, _ = super(FIGConvUNetTimeConditionalRollout, self).forward(
vertices=verts, features=feats
)
return out

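# During training, checkpoint the step so its activations are recomputed
# in backward rather than stored, reducing memory across the rollout.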
if self.training:
outf = ckpt(
step_fn,
vertices,
features_t,
use_reentrant=False,
).squeeze(0) # [N, 3]
else:
outf = step_fn(vertices, features_t).squeeze(0) # [N, 3]

y_t = x + outf
outputs.append(y_t)

return torch.stack(outputs, dim=0) # [T, N, 3]
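
A shape sketch of the per-step conditioning implemented above (sizes and the single thickness feature are assumptions; F + 1 input channels matches in_channels: 2 in the config):

import torch

N, F, T = 1000, 1, 31          # assumed node count, feature dim, time steps
x = torch.rand(N, 3)           # initial positions
features = torch.rand(N, F)    # e.g. per-node thickness

for time_t in torch.linspace(0.0, 1.0, T - 1):
    feats_t = torch.cat([features, time_t.expand(N, 1)], dim=-1)  # [N, F + 1]
    # Each step sees the same geometry plus a normalized-time channel;
    # the model output is a displacement offset, so y_t = x + out.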
8 changes: 6 additions & 2 deletions examples/structural_mechanics/crash/train.py
@@ -30,7 +30,6 @@
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from physicsnemo.distributed.manager import DistributedManager
from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
@@ -64,6 +63,11 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper):
f"Model {model_name} requires a point-cloud datapipe, "
f"but you selected {datapipe_name}."
)
if "FIGConvUNet" in model_name and "PointCloudDataset" not in datapipe_name:
raise ValueError(
f"Model {model_name} requires a point-cloud datapipe, "
f"but you selected {datapipe_name}."
)

# Dataset
reader = instantiate(cfg.reader)
@@ -223,7 +227,7 @@ def main(cfg: DictConfig) -> None:
for sample in trainer.dataloader:
sample = sample[0].to(dist.device) # SimSample .to()
loss = trainer.train(sample)
total_loss += loss.item()
total_loss += loss.detach().item()
num_batches += 1

trainer.scheduler.step()
2 changes: 1 addition & 1 deletion examples/structural_mechanics/crash/vtp_reader.py
@@ -169,7 +169,7 @@ def process_vtp_data(data_dir, num_samples=2, write_vtp=False, logger=None):

if not vtp_files:
if logger:
logger.error("No .vtp files found in:", base_data_dir)
logger.error(f"No .vtp files found in: {base_data_dir}")
exit(1)

for vtp_path in vtp_files: