Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 118 additions & 15 deletions CODING_STANDARDS/FUNCTIONAL_APIS.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ This document is structured in two main sections:
| [`FNC-002`](#fnc-002-file-layout-for-functionals) | File layout for functionals | Adding or refactoring functional files |
| [`FNC-003`](#fnc-003-registration-and-dispatch-rules) | Registration and dispatch rules | Registering implementations |
| [`FNC-004`](#fnc-004-optional-dependency-handling) | Optional dependency handling | Using optional backends |
| [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs`/`compare` |
| [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs_forward`/`make_inputs_backward`/`compare_forward` |
| [`FNC-006`](#fnc-006-testing-functionals) | Testing functionals | Adding functional tests |
| [`FNC-007`](#fnc-007-benchmark-registry) | Benchmark registry | Adding a functional to ASV |

Expand All @@ -72,7 +72,8 @@ This document is structured in two main sections:

All functionals must be implemented with `FunctionSpec`, even if only a single
implementation exists. This ensures the operation participates in validation
and benchmarking via `make_inputs` and `compare`.
and benchmarking via input generators and `compare_forward` (and
`compare_backward` where needed).

**Rationale:**

Expand Down Expand Up @@ -148,14 +149,41 @@ class Identity(FunctionSpec):
return identity_torch(x)

@classmethod
def make_inputs(cls, device: torch.device | str = "cpu"):
def make_inputs_forward(cls, device: torch.device | str = "cpu"):
device = torch.device(device)
yield ("small", (torch.randn(1024, device=device),), {})
yield ("medium", (torch.randn(4096, device=device),), {})
yield ("large", (torch.randn(16384, device=device),), {})

@classmethod
def compare(cls, output: torch.Tensor, reference: torch.Tensor) -> None:
def make_inputs_backward(cls, device: torch.device | str = "cpu"):
device = torch.device(device)
yield (
"small-bwd",
(torch.randn(1024, device=device, requires_grad=True),),
{},
)
yield (
"medium-bwd",
(torch.randn(4096, device=device, requires_grad=True),),
{},
)
yield (
"large-bwd",
(torch.randn(16384, device=device, requires_grad=True),),
{},
)

@classmethod
def compare_forward(
cls, output: torch.Tensor, reference: torch.Tensor
) -> None:
torch.testing.assert_close(output, reference)

@classmethod
def compare_backward(
cls, output: torch.Tensor, reference: torch.Tensor
) -> None:
torch.testing.assert_close(output, reference)

identity = Identity.make_function("identity")
Expand Down Expand Up @@ -210,6 +238,14 @@ __all__ = ["knn"]
`physicsnemo/nn/functional/<name>/`.
- Keep each backend in its own module (e.g., `_torch_impl.py`).
- Keep shared helpers in `utils.py`.
- For complex Warp backends, prefer a dedicated `_warp_impl/` package with:
- `op.py` for torch custom-op registration and validation
- `launch_forward.py` for forward launch dispatch
- `launch_backward.py` for backward launch dispatch
- `_kernels/` with one kernel per file
- `utils.py` for shared Warp constants/functions
- Keep `launch_forward.py` and `launch_backward.py` as the only launch
surfaces; avoid extra launch helper modules unless there is a strong reason.

**Rationale:**

Expand All @@ -228,6 +264,21 @@ physicsnemo/nn/functional/knn/
utils.py
```

```text
physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/
grid_to_point_interpolation.py
_torch_impl.py
_warp_impl/
__init__.py
op.py
launch_forward.py
launch_backward.py
_kernels/
forward_3d_stride2.py
backward_3d_stride2.py
utils.py
```

**Anti-pattern:**

```text
Expand Down Expand Up @@ -308,11 +359,20 @@ import missing_dep # raises at import time

**Description:**

Implement `make_inputs` and `compare` for every functional. `make_inputs` should
yield labeled inputs ordered from smaller to larger cases. Labels do not have to
be exactly "small/medium/large", and you can provide more than three cases.
`compare` should validate output consistency. Labels are used for benchmark
plots and summaries.
Implement `make_inputs_forward` for every functional so it can be benchmarked.
Implement `compare_forward` when a functional has multiple implementations and
needs cross-backend parity checks in tests.

Implement `make_inputs_backward` only for functionals with a meaningful
backward pass (for example differentiable functionals). Implement
`compare_backward` when a functional has backward support and multiple
implementations that need backward parity checks.

Input generators should yield labeled inputs ordered from smaller to larger
cases. Labels do not have to be exactly "small/medium/large", and you can
provide more than three cases. Compare hooks should validate output
consistency where implemented. Labels are used for benchmark plots and
summaries.

**Rationale:**

Expand All @@ -323,17 +383,30 @@ backends.

```python
@classmethod
def make_inputs(cls, device="cpu"):
def make_inputs_forward(cls, device="cpu"):
yield ("small", (torch.randn(1024, device=device),), {})
yield ("medium", (torch.randn(4096, device=device),), {})
yield ("large", (torch.randn(16384, device=device),), {})

@classmethod
def make_inputs_backward(cls, device="cpu"):
x = torch.randn(4096, device=device, requires_grad=True)
yield ("medium", (x,), {})

@classmethod
def compare_forward(cls, output, reference):
torch.testing.assert_close(output, reference)

@classmethod
def compare_backward(cls, output, reference):
torch.testing.assert_close(output, reference)
```

**Anti-pattern:**

```python
@classmethod
def make_inputs(cls, device="cpu"):
def make_inputs_forward(cls, device="cpu"):
pass
```

Expand All @@ -346,6 +419,23 @@ def make_inputs(cls, device="cpu"):
Add tests under `test/nn/functional/` to validate selection, optional
dependencies, and output correctness.

Use a consistent test layout when possible. This is **highly recommended** for
readability and review speed, but it is **not strictly required** when a
functional needs a different shape.

Suggested naming/structure:

1. Backend/reference correctness:
- `test_<functional_name>_<implementation_name>`
2. Cross-backend parity:
- `test_<functional_name>_backend_forward_parity`
- `test_<functional_name>_backend_backward_parity` (only for differentiable ops)
3. Deprecation + validation paths:
   - `test_<functional_name>_error_handling`

Where possible, keep all backend parity checks in one functional test file and
use the functional's `compare_forward`/`compare_backward` hooks for consistency.

**Rationale:**

Functional APIs are public entry points and need coverage for both the API and
Expand All @@ -354,8 +444,20 @@ backend behavior.
**Example:**

```python
def test_knn_cpu():
indices, distances = knn(points, queries, k=4)
def test_grid_to_point_interpolation_torch():
...

def test_grid_to_point_interpolation_warp():
...

def test_grid_to_point_interpolation_backend_forward_parity():
...

def test_grid_to_point_interpolation_backend_backward_parity():
...

def test_grid_to_point_interpolation_error_handling():
...
```

**Anti-pattern:**
Expand All @@ -372,7 +474,8 @@ def test_knn_cpu():

Functionals that should be benchmarked must be added to
`benchmarks/physicsnemo/nn/functional/registry.py`. Only add a functional once
its `make_inputs` implementation yields labeled inputs.
its input generators (`make_inputs_forward`, and optionally
`make_inputs_backward`) yield labeled inputs.

**Rationale:**

Expand All @@ -392,6 +495,6 @@ FUNCTIONAL_SPECS = (KNN, RadiusSearch)
**Anti-pattern:**

```python
# Adding a functional before make_inputs is implemented.
# Adding a functional before input generators are implemented.
FUNCTIONAL_SPECS = (MyFunctionalWithoutInputs,)
```
25 changes: 17 additions & 8 deletions benchmarks/physicsnemo/nn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ This directory contains ASV benchmarks for `physicsnemo.nn`.
For functionals, the benchmark flow is intentionally simple:

1. Implement or update the functional `FunctionSpec`.
2. Add representative `make_inputs(device=...)` cases to that `FunctionSpec`.
3. Register the `FunctionSpec` in `benchmarks/physicsnemo/nn/functional/registry.py`.
4. Run ASV and regenerate plots.
2. Add representative `make_inputs_forward(device=...)` cases.
3. Add `make_inputs_backward(device=...)` when backward benchmarking is needed.
4. Register the `FunctionSpec` in `benchmarks/physicsnemo/nn/functional/registry.py`.
5. Run ASV and regenerate plots.

## Where to read more

Expand All @@ -27,9 +28,9 @@ For functionals, the benchmark flow is intentionally simple:

## Example functionals to copy

- `physicsnemo/nn/functional/interpolation/interpolation.py`
- `physicsnemo/nn/functional/radius_search/radius_search.py`
- `physicsnemo/nn/functional/knn/knn.py`
- `physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/grid_to_point_interpolation.py`
- `physicsnemo/nn/functional/neighbors/radius_search/radius_search.py`
- `physicsnemo/nn/functional/neighbors/knn/knn.py`

## Common commands

Expand All @@ -42,9 +43,17 @@ Run benchmarks (repo root):
Run only selected functionals while iterating:

```bash
PHYSICSNEMO_ASV_FUNCTIONALS=Interpolation,RadiusSearch ./benchmarks/run_benchmarks.sh
PHYSICSNEMO_ASV_FUNCTIONALS=GridToPointInterpolation,RadiusSearch ./benchmarks/run_benchmarks.sh
```

Run only selected benchmark phases:

```bash
PHYSICSNEMO_ASV_PHASES=forward ./benchmarks/run_benchmarks.sh
PHYSICSNEMO_ASV_PHASES=forward,backward ./benchmarks/run_benchmarks.sh
```

Plots are written under:

- `docs/nn/functional/<functional_name>/benchmark.png`
- `docs/nn/functional/<category>/<functional_name>/benchmark_forward.png` (forward)
- `docs/nn/functional/<category>/<functional_name>/benchmark_backward.png` (backward)
102 changes: 102 additions & 0 deletions benchmarks/physicsnemo/nn/functional/_spec_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code in the benchmarks is still a bit junky.

# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared helpers for functional ASV benchmark scripts."""

from __future__ import annotations

from typing import Any

import torch

from physicsnemo.core.function_spec import FunctionSpec

# Benchmark phases in canonical execution/reporting order; this tuple is also
# the set of valid ``phase`` argument values accepted by the helpers below.
PHASE_ORDER = ("forward", "backward")


def supports_backward_inputs(spec: type) -> bool:
    """Return True when a spec overrides backward input generation.

    ``make_inputs_backward`` is a ``classmethod`` on ``FunctionSpec`` (per the
    functional coding standards), so both ``spec.make_inputs_backward`` and
    ``FunctionSpec.make_inputs_backward`` resolve to *bound* methods.  The
    underlying function object (``__func__``) must therefore be compared on
    BOTH sides: comparing ``spec``'s ``__func__`` against the base class's
    bound method (as the previous version did) is never an identity match, so
    every spec incorrectly appeared to support backward inputs.

    Parameters
    ----------
    spec : type
        A ``FunctionSpec`` subclass.

    Returns
    -------
    bool
        True when ``spec`` defines its own ``make_inputs_backward``.
    """

    return (
        spec.make_inputs_backward.__func__
        is not FunctionSpec.make_inputs_backward.__func__
    )


def _metadata_case_labels(spec: type) -> list[str]:
"""Return benchmark case labels from optional spec metadata."""

benchmark_cases = getattr(spec, "_BENCHMARK_CASES", None)
if isinstance(benchmark_cases, (list, tuple)):
labels = [
case[0]
for case in benchmark_cases
if isinstance(case, tuple) and case and isinstance(case[0], str)
]
if labels:
return labels

benchmark_cases_fn = getattr(spec, "_benchmark_cases", None)
if callable(benchmark_cases_fn):
labels = [
case[0]
for case in benchmark_cases_fn()
if isinstance(case, tuple) and case and isinstance(case[0], str)
]
if labels:
return labels

return []


def case_labels(spec: type, phase: str, device: torch.device | str) -> list[str]:
    """Resolve labeled benchmark cases for one phase.

    Parameters
    ----------
    spec : type
        A ``FunctionSpec`` subclass.
    phase : str
        One of ``PHASE_ORDER`` (``"forward"`` or ``"backward"``).
    device : torch.device | str
        Device passed through to the spec's input generator.

    Returns
    -------
    list[str]
        Case labels, or an empty list for an unsupported backward phase.

    Raises
    ------
    ValueError
        If ``phase`` is not a known benchmark phase.
    """

    if phase not in PHASE_ORDER:
        raise ValueError(f"Unsupported benchmark phase: {phase}")

    # Backward cases only exist for specs that override backward input
    # generation.
    if phase == "backward" and not supports_backward_inputs(spec):
        return []

    # Cheap metadata-declared labels take precedence over materializing
    # (potentially large) inputs just to read their labels.
    from_metadata = _metadata_case_labels(spec)
    if from_metadata:
        return from_metadata

    generator = (
        spec.make_inputs_forward if phase == "forward" else spec.make_inputs_backward
    )
    return [label for label, _, _ in generator(device=device)]


def case_by_index(
spec: type,
phase: str,
case_index: int,
device: torch.device | str,
) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
"""Materialize one case from the phase-specific input generator."""

if phase == "forward":
case_iter = spec.make_inputs_forward(device=device)
elif phase == "backward":
case_iter = spec.make_inputs_backward(device=device)
else:
raise ValueError(f"Unsupported benchmark phase: {phase}")

for index, case in enumerate(case_iter):
if index == case_index:
return case
raise IndexError(
f"Case index {case_index} out of range for {spec.__name__} phase={phase}"
)


# Explicit public surface of this benchmark-helper module.
__all__ = ["PHASE_ORDER", "supports_backward_inputs", "case_labels", "case_by_index"]
Loading