
Add 2D wave equation FNO training recipe #1489

Open
gpartin wants to merge 1 commit into NVIDIA:main from gpartin:feature/add-wave-equation-example

Conversation


@gpartin gpartin commented Mar 11, 2026

Summary

This PR adds a complete Fourier Neural Operator training example for the 2D wave equation, filling the hyperbolic PDE gap in PhysicsNeMo's example suite.

Motivation

Current PhysicsNeMo examples focus on elliptic (Darcy flow) and parabolic (Navier-Stokes, heat) PDEs. The wave equation is a fundamental hyperbolic PDE critical for acoustics, seismology, and electromagnetic applications. This example provides a ready-to-use training recipe following the same patterns as darcy_fno.

What's Included

File Purpose
train_fno_wave.py Training script (Hydra config, distributed, checkpointing)
wave_data.py On-the-fly data generator using leapfrog finite differences
validator.py Side-by-side visualization (initial/truth/prediction/error)
config.yaml 128x128 grid, 4-layer FNO, standard hyperparameters
README.md Problem description and usage instructions
requirements.txt Minimal dependencies

Problem Setup

  • PDE: $\partial^2 u / \partial t^2 = c^2 \nabla^2 u$
  • Domain: $[0,1]^2$ with periodic boundaries
  • Input: Random Fourier initial wavefield $u(x,y,0)$
  • Output: Solution at target time $u(x,y,T)$
  • Solver: Leapfrog finite-difference with Taylor expansion first step
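For readers unfamiliar with the scheme, a single leapfrog update on a periodic grid can be sketched as follows (a minimal illustration; the function name and signature are ours, not the PR's — it uses the same 5-point `np.roll` Laplacian as the example's generator):

```python
import numpy as np

def leapfrog_step(u_curr, u_prev, dt, dx, wave_speed):
    """One leapfrog update u(t+dt) = 2*u(t) - u(t-dt) + dt^2*c^2*lap(u(t)).

    Illustrative sketch only (names are hypothetical, not from the PR).
    Periodic boundaries come for free from np.roll.
    """
    lap = (
        np.roll(u_curr, 1, axis=0) + np.roll(u_curr, -1, axis=0)
        + np.roll(u_curr, 1, axis=1) + np.roll(u_curr, -1, axis=1)
        - 4.0 * u_curr
    ) / dx**2
    return 2.0 * u_curr - u_prev + dt**2 * wave_speed**2 * lap
```

A spatially constant field is a fixed point of this update (its Laplacian vanishes), which makes for a quick sanity check.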

Design Decisions

  • On-the-fly generation: Follows the Darcy2D pattern — unlimited unique training samples
  • No external dependencies: Pure NumPy/PyTorch data generation, no Warp requirement
  • Standard FNO architecture: Same physicsnemo.models.fno.FNO as other examples
  • Configurable physics: Wave speed, target time, and CFL number via Hydra config

Usage

```bash
cd examples/wave/wave_fno
pip install -r requirements.txt
python train_fno_wave.py
```

Testing

Verified data generator produces correct wave evolution (energy-conserving leapfrog with CFL-stable parameters).
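The PR does not include the verification script itself; one way such an energy check could look is sketched below (hedged: all names, the single-mode initial condition, and the 10% tolerance are our illustrative assumptions, not the PR's test). The discrete energy pairs the half-step leapfrog velocity with periodic forward-difference gradients; for a CFL-stable leapfrog run it oscillates slightly but shows no secular drift.

```python
import numpy as np

def wave_energy(u_curr, u_prev, dt, dx, c):
    # Hypothetical energy functional: kinetic term from the half-step
    # velocity, potential term from periodic forward differences.
    v = (u_curr - u_prev) / dt                          # velocity at t - dt/2
    gx = (np.roll(u_curr, -1, axis=0) - u_curr) / dx    # periodic d/dx
    gy = (np.roll(u_curr, -1, axis=1) - u_curr) / dx    # periodic d/dy
    return 0.5 * np.sum(v**2 + c**2 * (gx**2 + gy**2)) * dx**2

def laplacian(u, dx):
    # Same 5-point periodic stencil as the example's generator
    return (
        np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
        + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1)
        - 4.0 * u
    ) / dx**2

n, c, cfl, n_steps = 64, 1.0, 0.5, 200
dx = 1.0 / n
dt = cfl * dx / c
x = np.linspace(0.0, 1.0, n, endpoint=False)
xx, yy = np.meshgrid(x, x, indexing="ij")
u_prev = np.sin(2 * np.pi * (xx + yy))                  # single mode, v(0) = 0

# Taylor bootstrap, then leapfrog (same scheme as the generator)
u_curr = u_prev + 0.5 * dt**2 * c**2 * laplacian(u_prev, dx)
e0 = wave_energy(u_curr, u_prev, dt, dx, c)
for _ in range(n_steps):
    u_next = 2.0 * u_curr - u_prev + dt**2 * c**2 * laplacian(u_curr, dx)
    u_prev, u_curr = u_curr, u_next
drift = abs(wave_energy(u_curr, u_prev, dt, dx, c) - e0) / e0
```

With cfl = 0.5 (below the 2D stability bound of 1/sqrt(2)), the relative drift stays well under a few percent over hundreds of steps.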

Add a complete Fourier Neural Operator example for the 2D wave equation,
filling the hyperbolic PDE gap in PhysicsNeMo examples.

- train_fno_wave.py: Training script following darcy_fno patterns
- wave_data.py: On-the-fly leapfrog data generator with periodic BCs
- validator.py: Side-by-side visualization (initial/truth/prediction/error)
- config.yaml: Hydra config for 128x128 grid, 4-layer FNO
- README.md: Problem description and usage instructions

The model learns u(x,y,0) -> u(x,y,T) using random Fourier initial
conditions. Data is generated in-situ, matching the Darcy2D pattern.
Copilot AI review requested due to automatic review settings March 11, 2026 18:58

greptile-apps bot commented Mar 11, 2026

Greptile Summary

This PR adds a complete FNO training example for the 2D wave equation, filling the hyperbolic PDE gap alongside the existing elliptic (Darcy) and parabolic (Navier-Stokes) examples. The implementation closely follows the established PhysicsNeMo FNO recipe — Hydra config, DistributedManager, StaticCapture* wrappers, checkpointing — and the leapfrog finite-difference solver is mathematically correct.

Key findings:

  • Dead code in wave_data.py: c2_ratio is precomputed on line 75 but never referenced; the leapfrog update computes dt**2 * wave_speed**2 * lap directly. This should either be used to optimise the inner loop (avoiding redundant / dx**2 division) or removed entirely.
  • Minor style issue in train_fno_wave.py: The checkpoint path uses an unnecessary f-string f"./checkpoints" with no interpolation.
  • The CPU-only, per-sample NumPy data generation will be the bottleneck at large batch sizes, but this matches the documented design intent ("Pure NumPy/PyTorch, no Warp requirement").

Important Files Changed

Filename Overview
examples/wave/wave_fno/wave_data.py Data generator for 2D wave equation; leapfrog integration is mathematically correct, but c2_ratio is computed and never used (dead code on line 75). CPU-only NumPy generation with per-sample Python loops may be slow at large batch sizes.
examples/wave/wave_fno/train_fno_wave.py Training script follows the established PhysicsNeMo FNO pattern. Minor: f-string without interpolation on path string (line 82). Validation accumulates a tensor sum which works correctly since StaticCaptureEvaluateNoGrad disables grad tracking.
examples/wave/wave_fno/validator.py Validation helper that plots initial condition, ground truth, prediction, and error side-by-side. Straightforward implementation; .detach() is correctly applied before .numpy() on the prediction.
examples/wave/wave_fno/config.yaml Hydra config covering architecture, scheduler, training, validation, and wave physics parameters. Values are reasonable for a 128x128 4-layer FNO example.
examples/wave/wave_fno/README.md Clear README describing the problem setup, prerequisites, and usage. References are relevant.
examples/wave/wave_fno/requirements.txt Minimal requirements file; only hydra-core and termcolor are listed as extra dependencies beyond the main package.

Last reviewed commit: 2f511db

y = np.linspace(0, 1, resolution, endpoint=False)
xx, yy = np.meshgrid(x, y, indexing="ij")

c2_ratio = (wave_speed * dt / dx) ** 2

Unused precomputed variable (dead code)

c2_ratio is computed here but is never referenced anywhere in the function. The leapfrog update on line 115 computes dt**2 * wave_speed**2 * lap directly, which is equivalent to c2_ratio * dx**2 * lap — but the code never uses c2_ratio. This variable was likely intended to optimise the inner-loop computation (pre-dividing out the dx**2 denominator to avoid recomputing it each iteration), but the optimisation was never wired up. Either use it to simplify the inner loop or remove it:

Suggested change
c2_ratio = (wave_speed * dt / dx) ** 2
c2_ratio = (wave_speed * dt / dx) ** 2 # precomputed for inner loop

Or, if the intent was to use it in the leapfrog steps, the inner loop Laplacian should be computed without the /dx**2 division and then multiplied by c2_ratio:

# Optimised leapfrog using c2_ratio
raw_lap = (
    np.roll(u_curr, 1, axis=0) + np.roll(u_curr, -1, axis=0)
    + np.roll(u_curr, 1, axis=1) + np.roll(u_curr, -1, axis=1)
    - 4.0 * u_curr
)
u_next = 2.0 * u_curr - u_prev + c2_ratio * raw_lap

If neither usage is intended, remove line 75 entirely.

Comment on lines +100 to +107
u_prev = u.copy()
# Taylor expansion for first step: u(dt) = u(0) + 0.5*dt^2*c^2*laplacian(u)
lap = (
    np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
    + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1)
    - 4.0 * u
) / dx**2
u_curr = u + 0.5 * dt**2 * wave_speed**2 * lap

First leapfrog step computes one fewer step than intended

After the Taylor-expansion bootstrap step, u_curr holds the field at t = dt. The loop then runs n_steps - 1 additional steps, advancing to t = n_steps * dt = T. That means only n_steps total time advances are taken (1 bootstrap + n_steps - 1 leapfrog), ending correctly at T.

However, u_prev is set to u.copy() (i.e., u(0)) before the Taylor-expansion step. The leapfrog formula in the loop is u_next = 2*u_curr - u_prev + ... which on the first iteration computes u(2dt) = 2*u(dt) - u(0) + .... This is the correct 3-point leapfrog recurrence and gives the right answer.

No action needed — the integration is mathematically correct — but the comment on line 101 is slightly misleading. It says "Taylor expansion for first step", but the result (0.5 * dt**2 * c**2 * lap) is exactly what you'd get from the standard ghost-point initialisation u(-dt) = u(0) - dt*v(0) + 0.5*dt^2*c^2*lap. Clarifying this in the comment would make it easier for readers to verify correctness:

Suggested change (expand the bootstrap comment):

u_prev = u.copy()
# Bootstrap first step from zero initial velocity using second-order accurate
# Taylor expansion: u(dt) = u(0) + dt*v(0) + 0.5*dt^2*c^2*lap(u(0))
#                         = u(0) + 0.5*dt^2*c^2*lap(u(0))   [since v(0) = 0]
# (Equivalent to leapfrog ghost-point u(-dt) = u(0) - dt*v + 0.5*dt^2*c^2*lap)
lap = (
    np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
    + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1)
    - 4.0 * u
) / dx**2
u_curr = u + 0.5 * dt**2 * wave_speed**2 * lap


validator = WaveValidator(loss_fun=MSELoss(reduction="mean"))

ckpt_args = {
    "path": f"./checkpoints",

F-string with no interpolation

f"./checkpoints" is an f-string that contains no placeholders. This has no effect but is misleading — it looks like a dynamic path was intended. Use a plain string literal instead:

Suggested change
"path": f"./checkpoints",
"path": "./checkpoints",


Copilot AI left a comment


Pull request overview

Adds a new PhysicsNeMo example that trains a Fourier Neural Operator (FNO) to learn the solution operator for the 2D wave equation, including on-the-fly data generation and validation visualization, mirroring the existing Darcy FNO recipe structure.

Changes:

  • Introduces a NumPy-based leapfrog finite-difference generator to produce unlimited (u(x,y,0) → u(x,y,T)) pairs on demand.
  • Adds a training script using Hydra + PhysicsNeMo distributed utilities, checkpointing, and LaunchLogger integration.
  • Adds a validator that logs side-by-side plots of initial condition, truth, prediction, and error, plus example config/README/requirements.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.

File Description
examples/wave/wave_fno/wave_data.py On-the-fly wave equation batch generator + iterable loader
examples/wave/wave_fno/validator.py Validation loss + visualization logging
examples/wave/wave_fno/train_fno_wave.py End-to-end training recipe (Hydra, distributed, ckpt, validation)
examples/wave/wave_fno/config.yaml Default hyperparameters + wave setup
examples/wave/wave_fno/README.md Problem description and usage instructions
examples/wave/wave_fno/requirements.txt Minimal example-specific Python deps


Comment on lines +158 to +160
    normaliser: dict | None = None,
    device: str = "cuda",
):

Copilot AI Mar 11, 2026


WaveDataLoader also defaults device to "cuda", which will fail in CPU-only environments when used outside the training script. Consider defaulting to "cpu"/None (and documenting/handling auto-device selection) and typing this as str | torch.device for consistency with how dist.device is passed in from the trainer.

Copilot uses AI. Check for mistakes.
Comment on lines +173 to +182
def __next__(self) -> dict[str, torch.Tensor]:
    initial, target = generate_wave_batch(
        batch_size=self.batch_size,
        resolution=self.resolution,
        wave_speed=self.wave_speed,
        target_time=self.target_time,
        nr_modes=self.nr_modes,
        cfl=self.cfl,
        device=self.device,
    )

Copilot AI Mar 11, 2026


WaveDataLoader always calls generate_wave_batch(...) without a seed, which makes runs non-reproducible and can also lead to duplicated random streams across distributed ranks. Consider adding a seed/base_seed to WaveDataLoader, maintaining a persistent RNG (or incrementing seed per batch), and offsetting by DistributedManager().rank when running distributed.
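A minimal sketch of the suggested per-rank seeding (the helper name and the base-seed-plus-rank offset are illustrative assumptions, not code from the PR): each rank derives a persistent generator from a shared base seed, so streams are distinct across ranks yet exactly reproducible per run.

```python
import numpy as np

def make_rank_rng(base_seed: int, rank: int) -> np.random.Generator:
    # Hypothetical helper: offset the base seed by the distributed rank so
    # each rank draws a distinct, reproducible random stream. In the trainer,
    # rank would come from DistributedManager().rank.
    return np.random.default_rng(base_seed + rank)

# Different ranks get different streams...
a = make_rank_rng(42, 0).standard_normal(4)
b = make_rank_rng(42, 1).standard_normal(4)
assert not np.allclose(a, b)
# ...and re-creating the generator reproduces the stream exactly.
assert np.allclose(make_rank_rng(42, 0).standard_normal(4), a)
```

Keeping one `Generator` per loader instance (rather than reseeding per batch) also avoids correlated draws between consecutive batches.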

    cfl=cfg.wave.cfl,
    device=dist.device,
)
validator = WaveValidator(loss_fun=MSELoss(reduction="mean"))

Copilot AI Mar 11, 2026


loss_fun is already defined above; here a new MSELoss instance is constructed instead of reusing it. Reusing loss_fun keeps the configuration consistent (e.g., if reduction/weighting ever changes) and avoids duplicate objects.

Suggested change
validator = WaveValidator(loss_fun=MSELoss(reduction="mean"))
validator = WaveValidator(loss_fun=loss_fun)

Comment on lines +65 to +69
dx = 1.0 / resolution
dt = cfl * dx / wave_speed
n_steps = int(np.ceil(target_time / dt))
dt = target_time / n_steps # adjust for exact target time


Copilot AI Mar 11, 2026


dt = cfl * dx / wave_speed, n_steps = ceil(target_time / dt), and dt = target_time / n_steps will raise on invalid inputs (e.g., wave_speed <= 0 division by zero, or target_time <= 0 producing n_steps == 0). Add explicit validation for resolution > 0, batch_size > 0, wave_speed > 0, cfl > 0, and target_time > 0 (and fail fast with a clear error).
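A possible shape for the suggested fail-fast validation (illustrative only; the PR contains no such helper, and the parameter names simply mirror `generate_wave_batch`):

```python
def validate_wave_params(resolution, batch_size, wave_speed, cfl, target_time):
    # Hypothetical guard: reject invalid inputs before any dt/n_steps math,
    # so failures surface as clear errors rather than ZeroDivisionError
    # or a zero-step "simulation".
    if resolution <= 0:
        raise ValueError(f"resolution must be > 0, got {resolution}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be > 0, got {batch_size}")
    if wave_speed <= 0:
        raise ValueError(f"wave_speed must be > 0, got {wave_speed}")
    if cfl <= 0:
        raise ValueError(f"cfl must be > 0, got {cfl}")
    if target_time <= 0:
        raise ValueError(f"target_time must be > 0, got {target_time}")
```

Called at the top of the generator, this turns the failure modes the review lists (division by zero, `n_steps == 0`) into explicit, named errors.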

Comment on lines +75 to +76
c2_ratio = (wave_speed * dt / dx) ** 2


Copilot AI Mar 11, 2026


c2_ratio is computed but never used. Please remove it or use it (e.g., for an explicit CFL/stability check) to avoid dead code.

Suggested change
c2_ratio = (wave_speed * dt / dx) ** 2

Comment on lines +80 to +120
for b in range(batch_size):
    # Random superposition of Fourier modes
    u = np.zeros((resolution, resolution), dtype=np.float64)
    for _ in range(nr_modes):
        kx = rng.integers(-nr_modes, nr_modes + 1)
        ky = rng.integers(-nr_modes, nr_modes + 1)
        amp = rng.standard_normal()
        phase = rng.uniform(0, 2 * np.pi)
        u += amp * np.sin(
            2 * np.pi * (kx * xx + ky * yy) + phase
        )

    # Normalize to unit variance
    std = np.std(u)
    if std > 1e-10:
        u /= std

    u0_all[b] = u.astype(np.float32)

    # Leapfrog time integration with zero initial velocity
    u_prev = u.copy()
    # Taylor expansion for first step: u(dt) = u(0) + 0.5*dt^2*c^2*laplacian(u)
    lap = (
        np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
        + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1)
        - 4.0 * u
    ) / dx**2
    u_curr = u + 0.5 * dt**2 * wave_speed**2 * lap

    for _ in range(n_steps - 1):
        lap = (
            np.roll(u_curr, 1, axis=0) + np.roll(u_curr, -1, axis=0)
            + np.roll(u_curr, 1, axis=1) + np.roll(u_curr, -1, axis=1)
            - 4.0 * u_curr
        ) / dx**2
        u_next = 2.0 * u_curr - u_prev + dt**2 * wave_speed**2 * lap
        u_prev = u_curr
        u_curr = u_next

    uT_all[b] = u_curr.astype(np.float32)

Copilot AI Mar 11, 2026


This batch generator is likely to be a major training bottleneck: it loops over each sample in Python and then performs n_steps iterations of np.roll-based Laplacians per sample. Consider vectorizing over the batch dimension and/or implementing the time stepping in torch on-device (or using an FFT-based Laplacian for periodic BCs) to avoid spending most of the wall time in NumPy on the host.

Suggested change
for b in range(batch_size):
    # Random superposition of Fourier modes
    u = np.zeros((resolution, resolution), dtype=np.float64)
    for _ in range(nr_modes):
        kx = rng.integers(-nr_modes, nr_modes + 1)
        ky = rng.integers(-nr_modes, nr_modes + 1)
        amp = rng.standard_normal()
        phase = rng.uniform(0, 2 * np.pi)
        u += amp * np.sin(
            2 * np.pi * (kx * xx + ky * yy) + phase
        )
    # Normalize to unit variance
    std = np.std(u)
    if std > 1e-10:
        u /= std
    u0_all[b] = u.astype(np.float32)
    # Leapfrog time integration with zero initial velocity
    u_prev = u.copy()
    # Taylor expansion for first step: u(dt) = u(0) + 0.5*dt^2*c^2*laplacian(u)
    lap = (
        np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
        + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1)
        - 4.0 * u
    ) / dx**2
    u_curr = u + 0.5 * dt**2 * wave_speed**2 * lap
    for _ in range(n_steps - 1):
        lap = (
            np.roll(u_curr, 1, axis=0) + np.roll(u_curr, -1, axis=0)
            + np.roll(u_curr, 1, axis=1) + np.roll(u_curr, -1, axis=1)
            - 4.0 * u_curr
        ) / dx**2
        u_next = 2.0 * u_curr - u_prev + dt**2 * wave_speed**2 * lap
        u_prev = u_curr
        u_curr = u_next
    uT_all[b] = u_curr.astype(np.float32)
# -------------------------------
# Batched random superposition of Fourier modes
# -------------------------------
# Draw mode parameters for the whole batch at once: shape (batch_size, nr_modes)
kx = rng.integers(-nr_modes, nr_modes + 1, size=(batch_size, nr_modes))
ky = rng.integers(-nr_modes, nr_modes + 1, size=(batch_size, nr_modes))
amp = rng.standard_normal(size=(batch_size, nr_modes))
phase = rng.uniform(0.0, 2 * np.pi, size=(batch_size, nr_modes))
# Broadcast over spatial grid to build initial fields: u.shape = (batch_size, N, N)
xx_b = xx[None, None, :, :]        # (1, 1, N, N)
yy_b = yy[None, None, :, :]        # (1, 1, N, N)
kx_b = kx[:, :, None, None]        # (B, M, 1, 1)
ky_b = ky[:, :, None, None]        # (B, M, 1, 1)
amp_b = amp[:, :, None, None]      # (B, M, 1, 1)
phase_b = phase[:, :, None, None]  # (B, M, 1, 1)
arg = 2 * np.pi * (kx_b * xx_b + ky_b * yy_b) + phase_b
u = np.sum(amp_b * np.sin(arg), axis=1)  # (batch_size, N, N)
# Normalize each sample to unit variance (if std is not too small)
std = np.std(u, axis=(1, 2), keepdims=True)  # (batch_size, 1, 1)
# Avoid division by very small std by using 1.0 where std <= 1e-10
safe_std = np.where(std > 1e-10, std, 1.0)
u = u / safe_std
u0_all[:] = u.astype(np.float32)
# -------------------------------
# Batched leapfrog time integration with zero initial velocity
# -------------------------------
# Use c2_ratio = (c * dt / dx)^2 and the undivided Laplacian stencil
u_prev = u.copy()  # (batch_size, N, N), at time t = 0
# Taylor expansion for first step: u(dt) = u(0) + 0.5 * c^2 * dt^2 * laplacian(u)
lap_core = (
    np.roll(u_prev, 1, axis=1)
    + np.roll(u_prev, -1, axis=1)
    + np.roll(u_prev, 1, axis=2)
    + np.roll(u_prev, -1, axis=2)
    - 4.0 * u_prev
)
u_curr = u_prev + 0.5 * c2_ratio * lap_core
for _ in range(n_steps - 1):
    lap_core = (
        np.roll(u_curr, 1, axis=1)
        + np.roll(u_curr, -1, axis=1)
        + np.roll(u_curr, 1, axis=2)
        + np.roll(u_curr, -1, axis=2)
        - 4.0 * u_curr
    )
    u_next = 2.0 * u_curr - u_prev + c2_ratio * lap_core
    u_prev = u_curr
    u_curr = u_next
uT_all[:] = u_curr.astype(np.float32)

Comment on lines +35 to +37
    device: str = "cuda",
    seed: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:

Copilot AI Mar 11, 2026


device is annotated as str and defaults to "cuda". With the defaults, calling this function on a CPU-only machine will throw when .to("cuda") is executed. Consider (1) widening the type to str | torch.device, and (2) defaulting to None / "cpu" / an auto-selected device so the example can run without requiring CUDA.
