204 changes: 148 additions & 56 deletions CODING_STANDARDS/FUNCTIONAL_APIS.md
@@ -58,7 +58,7 @@ This document is structured in two main sections:
| [`FNC-002`](#fnc-002-file-layout-for-functionals) | File layout for functionals | Adding or refactoring functional files |
| [`FNC-003`](#fnc-003-registration-and-dispatch-rules) | Registration and dispatch rules | Registering implementations |
| [`FNC-004`](#fnc-004-optional-dependency-handling) | Optional dependency handling | Using optional backends |
| [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs_forward`/`make_inputs_backward`/`compare_forward` |
| [`FNC-006`](#fnc-006-testing-functionals) | Testing functionals | Adding functional tests |
| [`FNC-007`](#fnc-007-benchmark-registry) | Benchmark registry | Adding a functional to ASV |

@@ -72,7 +72,8 @@ This document is structured in two main sections:

All functionals must be implemented with `FunctionSpec`, even if only a single
implementation exists. This ensures the operation participates in validation
and benchmarking via input generators and `compare_forward` (and
`compare_backward` where needed).

**Rationale:**

@@ -82,52 +83,41 @@ selection, benchmarking and verification across the codebase.
**Example:**

```python
import torch
import warp as wp

from physicsnemo.core.function_spec import FunctionSpec

wp.init()
wp.config.quiet = True

@wp.kernel
def _identity_kernel(
x: wp.array(dtype=wp.float32),
y: wp.array(dtype=wp.float32),
):
i = wp.tid()
y[i] = x[i]

@torch.library.custom_op("physicsnemo::identity_warp", mutates_args=())
def identity_impl(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
device, stream = FunctionSpec.warp_launch_context(x)
wp_x = wp.from_torch(x, dtype=wp.float32, return_ctype=True)
wp_y = wp.from_torch(out, dtype=wp.float32, return_ctype=True)
with wp.ScopedStream(stream):
wp.launch(
kernel=_identity_kernel,
dim=x.numel(),
inputs=[wp_x, wp_y],
device=device,
stream=stream,
)
return out

@identity_impl.register_fake
def identity_impl_fake(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)

def identity_torch(x: torch.Tensor) -> torch.Tensor:
return x.clone()
@@ -148,14 +138,41 @@ class Identity(FunctionSpec):
return identity_torch(x)

@classmethod
def make_inputs_forward(cls, device: torch.device | str = "cpu"):
device = torch.device(device)
yield ("small", (torch.randn(1024, device=device),), {})
yield ("medium", (torch.randn(4096, device=device),), {})
yield ("large", (torch.randn(16384, device=device),), {})

@classmethod
def make_inputs_backward(cls, device: torch.device | str = "cpu"):
device = torch.device(device)
yield (
"small-bwd",
(torch.randn(1024, device=device, requires_grad=True),),
{},
)
yield (
"medium-bwd",
(torch.randn(4096, device=device, requires_grad=True),),
{},
)
yield (
"large-bwd",
(torch.randn(16384, device=device, requires_grad=True),),
{},
)

@classmethod
def compare_forward(
cls, output: torch.Tensor, reference: torch.Tensor
) -> None:
torch.testing.assert_close(output, reference)

@classmethod
def compare_backward(
cls, output: torch.Tensor, reference: torch.Tensor
) -> None:
torch.testing.assert_close(output, reference)

identity = Identity.make_function("identity")
@@ -210,6 +227,14 @@ __all__ = ["knn"]
`physicsnemo/nn/functional/<name>/`.
- Keep each backend in its own module (e.g., `_torch_impl.py`).
- Keep shared helpers in `utils.py`.
- For complex Warp backends, prefer a dedicated `_warp_impl/` package with:
- `op.py` for torch custom-op registration and validation
- `launch_forward.py` for forward launch dispatch
- `launch_backward.py` for backward launch dispatch
- `_kernels/` with one kernel per file
- `utils.py` for shared Warp constants/functions
- Keep `launch_forward.py` and `launch_backward.py` as the only launch
surfaces; avoid extra launch helper modules unless there is a strong reason.

**Rationale:**

Expand All @@ -228,6 +253,21 @@ physicsnemo/nn/functional/knn/
utils.py
```

```text
physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/
grid_to_point_interpolation.py
_torch_impl.py
_warp_impl/
__init__.py
op.py
launch_forward.py
launch_backward.py
_kernels/
forward_3d_stride2.py
backward_3d_stride2.py
utils.py
```
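
The point of keeping `launch_forward.py` as the only launch surface can be sketched with a plain-Python dispatch table. All names below (`launch_forward`, the kernel stand-ins, the `(ndim, stride)` key) are hypothetical illustrations of the pattern, not physicsnemo APIs:

```python
# Hypothetical sketch of a launch_forward.py-style module: one dispatch
# table is the only place that maps a problem configuration to a kernel,
# so call sites never import kernel modules directly.

def _forward_3d_stride2(x):
    # Stand-in for a Warp kernel launch: keep every second element.
    return x[::2]

def _forward_3d_stride1(x):
    # Stand-in for the dense case: plain copy.
    return list(x)

_FORWARD_KERNELS = {
    ("3d", 2): _forward_3d_stride2,
    ("3d", 1): _forward_3d_stride1,
}

def launch_forward(x, ndim="3d", stride=1):
    """Single entry point for all forward launches."""
    try:
        kernel = _FORWARD_KERNELS[(ndim, stride)]
    except KeyError:
        raise NotImplementedError(f"no forward kernel for {(ndim, stride)}")
    return kernel(x)
```

Adding a kernel then means one new file under `_kernels/` plus one dispatch-table entry, which keeps launch logic reviewable in a single place.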

**Anti-pattern:**

```text
@@ -308,11 +348,20 @@ import missing_dep # raises at import time

**Description:**

Implement `make_inputs_forward` for every functional so it can be benchmarked.
Implement `compare_forward` when a functional has multiple implementations and
needs cross-backend parity checks in tests.

Implement `make_inputs_backward` only for functionals with a meaningful
backward pass (for example differentiable functionals). Implement
`compare_backward` when a functional has backward support and multiple
implementations that need backward parity checks.

Input generators should yield labeled inputs ordered from smaller to larger
cases. Labels do not have to be exactly "small/medium/large", and you can
provide more than three cases. Compare hooks should validate output
consistency where implemented. Labels are used for benchmark plots and
summaries.

**Rationale:**

@@ -323,17 +372,30 @@ backends.

```python
@classmethod
def make_inputs_forward(cls, device="cpu"):
yield ("small", (torch.randn(1024, device=device),), {})
yield ("medium", (torch.randn(4096, device=device),), {})
yield ("large", (torch.randn(16384, device=device),), {})

@classmethod
def make_inputs_backward(cls, device="cpu"):
x = torch.randn(4096, device=device, requires_grad=True)
yield ("medium", (x,), {})

@classmethod
def compare_forward(cls, output, reference):
torch.testing.assert_close(output, reference)

@classmethod
def compare_backward(cls, output, reference):
torch.testing.assert_close(output, reference)
```
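
To make the contract concrete, here is a self-contained sketch of how a harness can consume any generator that yields `(label, args, kwargs)` tuples. It uses plain lists and `sum` instead of tensors and real backends, and is not the actual ASV integration:

```python
import time

def make_inputs_forward(device="cpu"):
    # Stand-in generator following the (label, args, kwargs) contract,
    # ordered from smaller to larger cases.
    yield ("small", (list(range(1_000)),), {})
    yield ("large", (list(range(100_000)),), {})

def bench(fn, make_inputs):
    # Time one labeled case at a time; labels become plot/summary keys.
    results = {}
    for label, args, kwargs in make_inputs():
        start = time.perf_counter()
        fn(*args, **kwargs)
        results[label] = time.perf_counter() - start
    return results

timings = bench(sum, make_inputs_forward)
print(sorted(timings))  # → ['large', 'small']
```

Because the harness only sees labels and callables, any functional that implements the generators can be benchmarked without harness changes.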

**Anti-pattern:**

```python
@classmethod
def make_inputs_forward(cls, device="cpu"):
pass
```

@@ -346,6 +408,23 @@
Add tests under `test/nn/functional/` to validate selection, optional
dependencies, and output correctness.

Use a consistent test layout when possible. This is **highly recommended** for
readability and review speed, but it is **not strictly required** when a
functional genuinely needs a different structure.

Suggested naming/structure:

1. Backend/reference correctness:
- `test_<functional_name>_<implementation_name>`
2. Cross-backend parity:
- `test_<functional_name>_backend_forward_parity`
- `test_<functional_name>_backend_backward_parity` (only for differentiable ops)
3. Deprecation + validation paths:
- `test_<functional_name>_error_handling`

Where possible, keep all backend parity checks in one functional test file and
use the functional's `compare_forward`/`compare_backward` hooks for consistency.

**Rationale:**

Functional APIs are public entry points and need coverage for both the API and
@@ -354,8 +433,20 @@ backend behavior.
**Example:**

```python
def test_grid_to_point_interpolation_torch():
...

def test_grid_to_point_interpolation_warp():
...

def test_grid_to_point_interpolation_backend_forward_parity():
...

def test_grid_to_point_interpolation_backend_backward_parity():
...

def test_grid_to_point_interpolation_error_handling():
...
```
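
The parity tests can reuse the functional's compare hooks directly. Below is a self-contained sketch of the pattern with toy stand-ins for the two backends and for `compare_forward` (the real hook would wrap `torch.testing.assert_close`); the `sum_squares` names are hypothetical:

```python
def sum_squares_reference(xs):
    # Reference backend: straightforward Python.
    return sum(x * x for x in xs)

def sum_squares_fast(xs):
    # Stand-in for an optimized backend under test.
    total = 0
    for x in xs:
        total += x * x
    return total

def compare_forward(output, reference):
    # Toy stand-in for the functional's compare_forward hook.
    assert output == reference, f"{output} != {reference}"

def test_sum_squares_backend_forward_parity():
    cases = [
        ("small", ([1, 2, 3],), {}),
        ("large", (list(range(100)),), {}),
    ]
    for label, args, kwargs in cases:
        compare_forward(sum_squares_fast(*args, **kwargs),
                        sum_squares_reference(*args, **kwargs))

test_sum_squares_backend_forward_parity()
```

Routing every backend through the same compare hook keeps the tolerance policy in one place instead of scattered across test files.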

**Anti-pattern:**
@@ -372,7 +463,8 @@

Functionals that should be benchmarked must be added to
`benchmarks/physicsnemo/nn/functional/registry.py`. Only add a functional once
its input generators (`make_inputs_forward`, and optionally
`make_inputs_backward`) yield labeled inputs.

**Rationale:**

@@ -392,6 +484,6 @@ FUNCTIONAL_SPECS = (KNN, RadiusSearch)
**Anti-pattern:**

```python
# Adding a functional before input generators are implemented.
FUNCTIONAL_SPECS = (MyFunctionalWithoutInputs,)
```
25 changes: 17 additions & 8 deletions benchmarks/physicsnemo/nn/README.md
@@ -5,9 +5,10 @@ This directory contains ASV benchmarks for `physicsnemo.nn`.
For functionals, the benchmark flow is intentionally simple:

1. Implement or update the functional `FunctionSpec`.
2. Add representative `make_inputs_forward(device=...)` cases.
3. Add `make_inputs_backward(device=...)` when backward benchmarking is needed.
4. Register the `FunctionSpec` in `benchmarks/physicsnemo/nn/functional/registry.py`.
5. Run ASV and regenerate plots.

## Where to read more

@@ -27,9 +28,9 @@ For functionals, the benchmark flow is intentionally simple:

## Example functionals to copy

- `physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/grid_to_point_interpolation.py`
- `physicsnemo/nn/functional/neighbors/radius_search/radius_search.py`
- `physicsnemo/nn/functional/neighbors/knn/knn.py`

## Common commands

@@ -42,9 +43,17 @@ Run benchmarks (repo root):
Run only selected functionals while iterating:

```bash
PHYSICSNEMO_ASV_FUNCTIONALS=GridToPointInterpolation,RadiusSearch ./benchmarks/run_benchmarks.sh
```

Run only selected benchmark phases:

```bash
PHYSICSNEMO_ASV_PHASES=forward ./benchmarks/run_benchmarks.sh
PHYSICSNEMO_ASV_PHASES=forward,backward ./benchmarks/run_benchmarks.sh
```

Plots are written under:

- `docs/nn/functional/<category>/<functional_name>/benchmark_forward.png` (forward)
- `docs/nn/functional/<category>/<functional_name>/benchmark_backward.png` (backward)