|
15 | 15 | # limitations under the License. |
16 | 16 |
|
17 | 17 | import torch |
18 | | -import torch.nn as nn |
19 | 18 | from torch.utils.checkpoint import checkpoint as ckpt |
20 | 19 |
|
21 | 20 | from physicsnemo.models.transolver import Transolver |
22 | 21 | from physicsnemo.models.meshgraphnet import MeshGraphNet |
| 22 | +from physicsnemo.models.figconvnet.figconvunet import FIGConvUNet |
23 | 23 |
|
24 | 24 | from datapipe import SimSample |
25 | 25 |
|
@@ -406,3 +406,64 @@ def step_fn(nf, ef, g): |
406 | 406 | y_t0, y_t1 = y_t1, y_t2_pred |
407 | 407 |
|
408 | 408 | return torch.stack(outputs, dim=0) # [T,N,3] |
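
The hunk above closes the existing autoregressive rollout: each predicted state is fed back as input for the next step via the `y_t0, y_t1 = y_t1, y_t2_pred` shift. A minimal sketch of that two-step state-shifting pattern (`fake_step` is a hypothetical stand-in for the model's learned one-step prediction):

```python
import torch

def fake_step(y_prev: torch.Tensor, y_curr: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the learned one-step predictor:
    # constant-velocity extrapolation from the last two states.
    return 2 * y_curr - y_prev

y_t0, y_t1 = torch.zeros(4, 3), torch.ones(4, 3)  # two seed states [N, 3]
outputs = []
for _ in range(5):
    y_t2_pred = fake_step(y_t0, y_t1)
    outputs.append(y_t2_pred)
    y_t0, y_t1 = y_t1, y_t2_pred  # shift the two-step window forward

rollout = torch.stack(outputs, dim=0)  # [T, N, 3], as in the return above
```

Because each prediction is consumed as input to the next, errors can compound over long rollouts; the class added below sidesteps this by conditioning on time instead.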
| 409 | + |
| 410 | + |
| 411 | +class FIGConvUNetTimeConditionalRollout(FIGConvUNet): |
| 412 | + """ |
| 413 | + FIGConvUNet with time-conditional rollout for crash simulation. |
| 414 | +
|
| 415 | + Predicts each time step independently from the initial geometry, conditioned on a normalized time scalar appended to the node features. |
| 416 | + """ |
| 417 | + |
| 418 | + def __init__(self, *args, **kwargs): |
| 419 | + self.rollout_steps: int = kwargs.pop("num_time_steps") - 1  # pop custom kwarg before super(); -1 excludes the initial state |
| 420 | + super().__init__(*args, **kwargs) |
| 421 | + |
| 422 | + def forward( |
| 423 | + self, |
| 424 | + sample: SimSample, |
| 425 | + data_stats: dict, |
| 426 | + ) -> torch.Tensor: |
| 427 | + """ |
| 428 | + Args: |
| 429 | + sample: SimSample containing node_features and node_target |
| 430 | + data_stats: dict containing normalization stats |
| 431 | + Returns: |
| 432 | + [T, N, 3] rollout of predicted positions |
| 433 | + """ |
| 434 | + inputs = sample.node_features |
| 435 | + x = inputs["coords"] # initial pos [N, 3] |
| 436 | + features = inputs.get("features", x.new_zeros((x.size(0), 0))) # [N, F] |
| 437 | + |
| 438 | + outputs: list[torch.Tensor] = [] |
| 439 | + time_seq = torch.linspace(0.0, 1.0, self.rollout_steps, device=x.device)  # normalized times in [0, 1], one per predicted step |
| 440 | + |
| 441 | + for time_t in time_seq: |
| 442 | + # Prepare vertices for FIGConvUNet: [1, N, 3] |
| 443 | + vertices = x.unsqueeze(0) # [1, N, 3] |
| 444 | + |
| 445 | + # Prepare features: features + time [N, F+1] |
| 446 | + time_expanded = time_t.expand(x.size(0), 1) # [N, 1] |
| 447 | + features_t = torch.cat([features, time_expanded], dim=-1) # [N, F+1] |
| 448 | + features_t = features_t.unsqueeze(0) # [1, N, F+1] |
| 449 | + |
| 450 | + def step_fn(verts, feats):  # one-step forward; explicit super() below since the zero-arg form fails inside closures |
| 451 | + out, _ = super(FIGConvUNetTimeConditionalRollout, self).forward( |
| 452 | + vertices=verts, features=feats |
| 453 | + ) |
| 454 | + return out |
| 455 | + |
| 456 | + if self.training:  # checkpoint activations to save memory over the rollout |
| 457 | + outf = ckpt( |
| 458 | + step_fn, |
| 459 | + vertices, |
| 460 | + features_t, |
| 461 | + use_reentrant=False, |
| 462 | + ).squeeze(0) # [N, 3] |
| 463 | + else: |
| 464 | + outf = step_fn(vertices, features_t).squeeze(0) # [N, 3] |
| 465 | + |
| 466 | + y_t = x + outf  # predicted displacement added to the initial coords |
| 467 | + outputs.append(y_t) |
| 468 | + |
| 469 | + return torch.stack(outputs, dim=0) # [T, N, 3] |
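
The time-conditioning itself is just a scalar broadcast concatenated onto the node features; a self-contained sketch of that mechanism (tensor shapes are illustrative, not tied to any dataset):

```python
import torch

N, F = 5, 4
features = torch.randn(N, F)  # per-node features [N, F]

rollout_steps = 3
time_seq = torch.linspace(0.0, 1.0, rollout_steps)  # normalized times

for time_t in time_seq:
    time_expanded = time_t.expand(N, 1)  # 0-dim scalar broadcast to [N, 1]
    features_t = torch.cat([features, time_expanded], dim=-1)  # [N, F+1]
    assert features_t.shape == (N, F + 1)
```

Since every step sees the same initial geometry, the `rollout_steps` predictions are independent of one another and could in principle be batched rather than looped.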