v1.14.0

@github-actions released this 01 Jun 15:29 · tag v1.14.0 · commit b943176

Warp v1.14.0

Warp v1.14 expands serialized CPU capture support: captured graphs can now include backward launches, tiled kernels, richer launch arguments, and .wrp files with arrays nested inside structs or wp.indexedarray arguments. This release also adds multi-environment FEM support for batched simulations, reusable and batched linear solvers, pluggable logging, portable tile FFT and solver fallbacks, stable JAX integration APIs, and relaxed CPU/GPU array access for Heterogeneous Memory Management (HMM) and Address Translation Services (ATS) systems.

New features

API Capture expands to cover more workflows

Building on the initial API Capture serialization support in Warp v1.13, Warp v1.14 primarily broadens the set of CPU graph patterns that can be saved and replayed. CPU captures can now include forward execution and reverse-mode passes from wp.Tape().backward(), wp.launch_tiled() kernels, and scalar parameters of any size (#1431). The shared .wrp serialization format also now supports @wp.struct arguments that contain arrays and wp.indexedarray arguments that carry data, gradient, and index buffers.

Important

Upgrade impact for APIC users:

  • Recapture .wrp files saved by Warp v1.13. Warp v1.14 writes APIC format version 10 and rejects the previous format.
  • Update native C/C++ APIC handle declarations to explicit pointers such as APICState* and APICGraph*. Ownership and destroy calls are unchanged. See the APIC migration diff.

Saved APIC graphs can still be consumed from standalone C++ through the C API declared in warp/native/apic.h. Native replay behavior is unchanged apart from the explicit pointer spelling for APIC handles.

Key capabilities:

  • Reverse-mode replay on CPU: adjoint launches emitted by wp.Tape().backward() are recorded into the CPU APIC stream and replayed from live captures or loaded graphs.
  • Richer launch arguments: APIC now relocates array data pointers, gradient pointers, index pointers for wp.indexedarray arguments, and handles inside serialized launch value blobs.
  • Tiled kernels: CPU captures can replay kernels that use tiles, including reductions and scans that previously fell outside the captured operation set.

Known limitations:

  • wp.utils.array_scan() is still not recorded into CPU APIC and raises NotImplementedError in CPU capture.
  • Nonzero array.fill_() operations on CPU are not recorded.
  • APIC serializes wp.Mesh handles, but wp.Volume and wp.Bvh handles are not yet supported.
  • Loading CPU .wrp graphs requires the warp-clang backend and the companion _modules/ directory with the recorded CPU kernel objects.

Multi-environment warp.fem

warp.fem can now represent many independent simulation environments inside one geometry and one solve setup (#1407). Colocated Grid2D and Grid3D geometries expose an env_count, sparse Nanogrid and AdaptiveNanogrid geometries can pack per-environment voxels into one NanoVDB volume, and unstructured meshes can carry per-cell environment metadata through cell_env and env_count.

This feature changes two positional call signatures. See the FEM migration diff if your code passes requires_grad, device, or temporary_store positionally.

import warp as wp
import warp.fem as fem

geo = fem.Grid3D(res=(8, 8, 8), bounds_lo=wp.vec3(0.0), bounds_hi=wp.vec3(1.0), env_count=4)
pressure_space = fem.make_polynomial_space(geo, degree=0, discontinuous=True)
partition = fem.make_space_partition(space_topology=pressure_space.topology, environment_first=True)

# For scalar pressure spaces, these node offsets can be passed directly to
# warp.optim.linear.LinearOperator(batch_offsets=...).
pressure_batch_offsets = partition.env_offsets

Environment-aware lookup keeps colocated environments from interacting accidentally. When a geometry has more than one environment, pass an environment index to fem.lookup() and pass env_indices to fem.PicQuadrature when particles are binned from world-space positions. The new warp/examples/fem/example_apic_fluid_multi_env.py example uses these APIs to run colocated APIC fluid environments with environment-aware particle quadrature and batched pressure solves.

Key capabilities:

  • Colocated environments: grid environments can overlap in world coordinates while remaining topologically independent.
  • Sparse packed environments: Nanogrid.from_environment_voxels() and AdaptiveNanogrid.from_environment_voxels() build packed sparse grids with per-cell cell_env metadata and hidden offsets for packed grids.
  • Mesh environments: FEM mesh constructors can use cell_env and env_count so grouped BVH lookup only traverses the requested environment.
  • Batched solves: make_space_partition(..., environment_first=True) exposes env_offsets that line up with the new linear-solver batch_offsets support for scalar spaces.

Known limitations:

  • environment_first=True does not support halo nodes for partitions that do not cover a whole geometry.
  • Mesh environment indices are lookup and partition metadata, so callers must still provide disconnected mesh topology for independent mesh environments.
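To make the env_offsets layout concrete, the sketch below computes the prefix-sum offsets one would expect for the Grid3D example above. It rests on three assumptions: each of the 4 environments owns a full 8x8x8 grid, a degree-0 discontinuous space has one node per cell, and env_offsets holds env_count + 1 entries in the same style as batch_offsets. The helper expected_env_offsets is hypothetical, and this is not how warp.fem computes the array.

```python
from itertools import accumulate


def expected_env_offsets(nodes_per_env):
    """Prefix sum over per-environment node counts, starting at 0."""
    return [0, *accumulate(nodes_per_env)]


res = (8, 8, 8)
env_count = 4
nodes_per_cell = 1  # assumption: degree-0 discontinuous space, one node per cell
nodes_per_env = res[0] * res[1] * res[2] * nodes_per_cell

# One interval [offsets[k], offsets[k + 1]) per environment.
offsets = expected_env_offsets([nodes_per_env] * env_count)
print(offsets)
```

Under these assumptions, environment k owns the node range from offsets[k] up to but not including offsets[k + 1].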

Reusable and batched linear solvers

The iterative solvers in warp.optim.linear can now preallocate their temporary buffers and reuse them across compatible solves (#1391). Passing run=False to cg(), cr(), bicgstab(), or gmres() returns a solver object that can be called repeatedly with replacement operands that have the same shape, dtype, device, and batch layout.

import numpy as np
import warp as wp
from warp.optim.linear import aslinearoperator, cg

diag = wp.array([2.0, 2.0, 5.0, 5.0], dtype=float, device="cpu")
b = wp.array([2.0, 4.0, 10.0, 15.0], dtype=float, device="cpu")
x = wp.zeros_like(b)
offsets = wp.array(np.array([0, 2, 4], dtype=np.int32), dtype=int, device="cpu")

A = aslinearoperator(diag, batch_offsets=offsets)
state = cg(A, b, x, maxiter=100, run=False)
state()  # solve the original system

b2 = wp.array([4.0, 2.0, 15.0, 10.0], dtype=float, device="cpu")
x2 = wp.zeros_like(b2)
state(b=b2, x=x2)  # reuse temporary buffers for a compatible system

batch_offsets is independent of solver-state reuse. It partitions a LinearOperator into scalar degree-of-freedom intervals that are solved as independent subproblems in one solver launch sequence, with convergence checked per batch and reported through the worst residual. For one-shot solves, call the solver directly as before.
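As a rough illustration of what the batch_offsets intervals mean, the pure-Python reference below solves the same diagonal system as the example above, one interval at a time, and reports the worst per-batch residual. It is a direct solve written for clarity, not Warp's iterative solver, and the function name solve_diagonal_batches is hypothetical.

```python
def solve_diagonal_batches(diag, b, offsets):
    """Solve diag[i] * x[i] = b[i] independently on each [offsets[k], offsets[k + 1])."""
    x = [0.0] * len(b)
    residuals = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        for i in range(start, end):
            x[i] = b[i] / diag[i]
        # Residual norm ||b - A x|| restricted to this batch.
        squared = sum((b[i] - diag[i] * x[i]) ** 2 for i in range(start, end))
        residuals.append(squared ** 0.5)
    # Convergence is judged per batch; the reported value is the worst one.
    return x, max(residuals)


diag = [2.0, 2.0, 5.0, 5.0]
b = [2.0, 4.0, 10.0, 15.0]
offsets = [0, 2, 4]  # two batches: indices 0-1 and 2-3

x, worst_residual = solve_diagonal_batches(diag, b, offsets)
print(x, worst_residual)
```

The offsets array has one more entry than the number of batches, so [0, 2, 4] describes two independent two-unknown systems.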

Pluggable logging

Warp's Python-side diagnostics now flow through a configurable logger (#1315, #1434). Install a custom logger with wp.set_logger(), scope it with wp.ScopedLogger, and control verbosity through wp.config.log_level or wp.ScopedLogLevel.

wp.config.log_level accepts the standard Warp log-level constants:

  • wp.LOG_DEBUG: most verbose, including code-generation details and module loads.
  • wp.LOG_INFO: the default, including the init banner and compile timings.
  • wp.LOG_WARNING: warnings and errors only.
  • wp.LOG_ERROR: errors only.

import warp as wp

wp.config.log_level = wp.LOG_WARNING

class ListLogger:
    def __init__(self):
        self.records = []
    def debug(self, message):
        self.records.append(("debug", message))
    def info(self, message):
        self.records.append(("info", message))
    def warning(self, message, category=None, stacklevel=1):
        self.records.append(("warning", message))
    def error(self, message):
        self.records.append(("error", message))

logger = ListLogger()
with wp.ScopedLogger(logger), wp.ScopedLogLevel(wp.LOG_INFO):
    wp.get_logger().info("captured by application logger")
print(logger.records)  # [('info', 'captured by application logger')]
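An application that already uses Python's standard logging module might prefer a thin bridge over a list-backed logger. The sketch below assumes the logger protocol is exactly the four methods shown in ListLogger. The class name StdlibBridgeLogger, the helper handler, and the logger name "myapp.warp" are hypothetical, and the block runs without Warp so that only the forwarding logic is exercised.

```python
import logging


class StdlibBridgeLogger:
    """Forwards the four ListLogger-style methods to a stdlib logger."""

    def __init__(self, name):
        self._log = logging.getLogger(name)

    def debug(self, message):
        self._log.debug(message)

    def info(self, message):
        self._log.info(message)

    def warning(self, message, category=None, stacklevel=1):
        # category and stacklevel are accepted to match the signature shown
        # above; this sketch does not use them.
        self._log.warning(message)

    def error(self, message):
        self._log.error(message)


class ListHandler(logging.Handler):
    """Collects (level name, message) pairs so the result can be inspected."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append((record.levelname, record.getMessage()))


handler = ListHandler()
app_log = logging.getLogger("myapp.warp")
app_log.addHandler(handler)
app_log.setLevel(logging.INFO)

bridge = StdlibBridgeLogger("myapp.warp")
bridge.debug("filtered out below INFO")
bridge.info("module loaded")
bridge.warning("deprecated flag", category=DeprecationWarning)
captured = handler.records
print(captured)
```

With Warp installed, an instance of such a class could presumably be passed to wp.set_logger() or wp.ScopedLogger in place of ListLogger.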

The default Warp logger now routes warnings emitted by Warp's Python code through Python's warnings.warn(), so application warning filters can suppress Warp deprecation warnings again. wp.config.verbose and wp.config.quiet still work during the deprecation window, but they now emit one-time DeprecationWarnings and map to wp.config.log_level.
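Because these warnings go through warnings.warn(), the standard library filter machinery applies to them. The sketch below uses a stand-in function, emit_legacy_warning (hypothetical, with a made-up message), rather than a real Warp call, so it demonstrates only the filter mechanics.

```python
import warnings


def emit_legacy_warning():
    # Stand-in for library code that reports a deprecation via warnings.warn().
    warnings.warn("legacy_flag is deprecated", DeprecationWarning, stacklevel=2)


def count_warnings(suppress):
    """Return how many warnings reach the caller, with or without an ignore filter."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        if suppress:
            # Inserted in front of the "always" rule, so it takes precedence.
            warnings.filterwarnings("ignore", category=DeprecationWarning)
        emit_legacy_warning()
    return len(caught)


seen_without_filter = count_warnings(suppress=False)
seen_with_filter = count_warnings(suppress=True)
print(seen_without_filter, seen_with_filter)
```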

Relaxed CPU/GPU array access

Warp no longer enforces the old rule that every Warp array passed to a kernel must be allocated on the launch device. wp.config.launch_array_access_mode now defaults to wp.config.LaunchArrayAccessMode.RELAXED, so CUDA kernels can receive CPU arrays on hardware where CUDA reports pageable CPU memory as GPU-accessible (#1461). The main targets are Linux Heterogeneous Memory Management (HMM) systems and NVIDIA Address Translation Services (ATS) platforms such as GH200, GB200, DGX Spark / GB10, and Jetson Thor, where ordinary CPU allocations can be GPU-accessible without an explicit .to(device) copy. Use wp.can_access() for concrete arrays and Device.can_access() for coarse device checks when code needs to choose a direct launch or an explicit copy at runtime.

import warp as wp

@wp.kernel
def increment(data: wp.array[float]):
    i = wp.tid()
    data[i] = data[i] + 1.0

device = wp.get_device("cuda:0")
cpu_data = wp.zeros(1024, dtype=float, device="cpu")  # initialized so the increment reads defined values

if wp.can_access(device, cpu_data):
    wp.launch(increment, dim=cpu_data.size, inputs=[cpu_data], device=device)
else:
    gpu_data = cpu_data.to(device)
    wp.launch(increment, dim=gpu_data.size, inputs=[gpu_data], device=device)

Choose the validation mode based on how much pre-launch checking you want:

  • wp.config.LaunchArrayAccessMode.RELAXED is the default. It checks type, dtype, and dimensions, but does not reject cross-device array arguments before launch. If a GPU kernel dereferences CPU memory that the device cannot access, the failure surfaces as a CUDA runtime error instead of Warp's previous Python same-device error.
  • wp.config.LaunchArrayAccessMode.CHECKED raises a Python error before launch when Warp can prove that an array is not accessible from the launch device. It warns and proceeds for custom or externally wrapped allocations whose provenance Warp cannot verify.
  • wp.config.LaunchArrayAccessMode.STRICT restores the old same-device rule and requires every Warp array argument to be allocated on the launch device.

Most users can keep the default. Use wp.config.LaunchArrayAccessMode.CHECKED when diagnosing mixed-device launches, and use wp.config.LaunchArrayAccessMode.STRICT in tests or libraries that depend on pre-launch same-device validation.

CUDA graph capture modes

CUDA graph capture now exposes CUDA's stream-capture mode through the capture_mode keyword on wp.ScopedCapture and wp.capture_begin() (#1410). Choose the mode based on how strictly CUDA should reject capture-unsafe runtime calls while capture is active:

  • wp.CaptureMode.THREAD_LOCAL is the default and preserves Warp's historical behavior. Capture-unsafe runtime calls from the capturing thread invalidate the capture, while other threads are unaffected.
  • wp.CaptureMode.GLOBAL is the strictest mode. Capture-unsafe runtime calls from any thread invalidate the capture.
  • wp.CaptureMode.RELAXED tolerates capture-unsafe runtime calls. Use it when composing with libraries that may lazily initialize CUDA contexts or allocators during capture.

import warp as wp

@wp.kernel
def add_one(x: wp.array[float]):
    i = wp.tid()
    x[i] = x[i] + 1.0

x = wp.zeros(1024, dtype=float, device="cuda:0")

with wp.ScopedCapture(device="cuda:0", capture_mode=wp.CaptureMode.RELAXED) as capture:
    wp.launch(add_one, dim=x.size, inputs=[x], device="cuda:0")

wp.capture_launch(capture.graph)

The function form uses the same capture_mode keyword, for example wp.capture_begin(device="cuda:0", capture_mode=wp.CaptureMode.RELAXED).

Tile programming enhancements

Portable tile FFT and solver fallbacks

wp.tile_fft() and wp.tile_ifft() now run on CPU and on GPU builds that do not include libmathdx (#1396). CPU supports any power-of-two FFT length and non-power-of-two lengths up to 4096 elements. The GPU fallback, selected automatically when libmathdx is unavailable or explicitly with enable_mathdx_fft=False, supports power-of-two FFT lengths divisible by block_dim.

import warp as wp

@wp.kernel
def filter_signal(x: wp.array2d[wp.vec2f], y: wp.array2d[wp.vec2f]):
    t = wp.tile_load(x, shape=(64, 64))
    wp.tile_fft(t)
    # Apply a spectral filter here.
    wp.tile_ifft(t)
    wp.tile_store(y, t)
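The length rules stated above can be written down as a small pre-flight check. This is a sketch of the constraints as described in these notes, not a Warp API: the helper names are hypothetical, and which tile dimension counts as the FFT length is left to the caller.

```python
CPU_MAX_NON_POWER_OF_TWO = 4096  # limit quoted in the notes above


def is_power_of_two(n):
    return n > 0 and (n & (n - 1)) == 0


def cpu_fft_length_ok(n):
    """CPU path: any power of two, or a non-power-of-two length up to 4096."""
    return is_power_of_two(n) or 0 < n <= CPU_MAX_NON_POWER_OF_TWO


def gpu_fallback_fft_length_ok(n, block_dim):
    """GPU fallback path: a power-of-two length that is divisible by block_dim."""
    return is_power_of_two(n) and n % block_dim == 0


checks = {
    "cpu 8192": cpu_fft_length_ok(8192),
    "cpu 3000": cpu_fft_length_ok(3000),
    "cpu 5000": cpu_fft_length_ok(5000),
    "gpu 64 / block 64": gpu_fallback_fft_length_ok(64, 64),
    "gpu 64 / block 128": gpu_fallback_fft_length_ok(64, 128),
    "gpu 96 / block 32": gpu_fallback_fft_length_ok(96, 32),
}
print(checks)
```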

GPU scalar fallbacks are also available for wp.tile_cholesky(), wp.tile_cholesky_solve(), wp.tile_lower_solve(), and wp.tile_upper_solve(), including in-place variants and the wp.tile_cholesky() adjoint (#1402). Select them with wp.config.enable_mathdx_solver=False or module_options={"enable_mathdx_solver": False}. The fallback avoids a libmathdx dependency and reduces compile cost, at the expense of runtime performance. One fallback limitation is that differentiated wp.tile_cholesky() allocates extra per-block scratch storage, so large GPU tiles can exceed the device's shared-memory budget. If you hit that limit, reduce the Cholesky tile size, switch to a narrower dtype, or use a libmathdx-enabled build with wp.config.enable_mathdx_solver=True.

wp.tile_empty()

wp.tile_empty() allocates an uninitialized register or shared-memory tile for kernels that overwrite every element before the first read (#1312). Use it instead of wp.tile_zeros() for full overwrites, because skipping zero-fill work can improve performance. Keep wp.tile_zeros() for accumulators or partial writes where the initial zeros are part of correctness.

Tile reliability fixes

Several tile fixes improve correctness in edge cases:

  • wp.tile_load(), wp.tile_store(), and indexed tile operations now use 64-bit byte offsets, fixing overflows on arrays larger than 2 GiB (#1422). This correctness fix widens address arithmetic in the tile load/store paths, so tile-heavy kernels may see slightly higher address-calculation overhead.
  • Register-to-shared and shared-to-register tile assignment now works in both directions, including pipelined wp.tile_matmul() kernels and their adjoints (#1439, #1440).
  • wp.tile_matmul() now rejects wp.bfloat16 output tiles. It also rejects wp.bfloat16 input tiles when backward compilation is enabled, because the backend cannot use bfloat16 accumulators for those paths (#1427).
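To see why byte offsets overflow past 2 GiB, the arithmetic sketch below compares a wrapped 32-bit signed offset with a 64-bit one for the last element of a 600-million-element float32 array. The signed 32-bit offset type is an assumption for illustration only (the notes do not state the previous type), and the helper names are hypothetical.

```python
def wrap_signed(value, bits):
    """Reduce an integer to the range of a two's-complement signed type."""
    half = 1 << (bits - 1)
    return ((value + half) % (1 << bits)) - half


def byte_offset(index, itemsize, bits):
    return wrap_signed(index * itemsize, bits)


itemsize = 4  # bytes per float32 element
last_index = 600_000_000 - 1  # array of roughly 2.4 GB, past the 2 GiB mark

offset_32 = byte_offset(last_index, itemsize, 32)
offset_64 = byte_offset(last_index, itemsize, 64)
print(offset_32, offset_64)
```

The 32-bit result is negative, which would address memory before the start of the array, while the 64-bit result matches the true byte position.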

JAX integration graduates to stable API

The JAX integration has been promoted from warp.jax_experimental into Warp's stable public API (#1370). New code should import warp.jax_kernel, warp.jax_callable, warp.clear_jax_callable_graph_cache, warp.JaxCallableGraphMode, and warp.JaxModulePreloadMode directly from the top-level warp namespace.

- from warp.jax_experimental import GraphMode, ModulePreloadMode, clear_jax_callable_graph_cache
- from warp.jax_experimental import jax_callable, jax_kernel
+ from warp import JaxCallableGraphMode, JaxModulePreloadMode, clear_jax_callable_graph_cache
+ from warp import jax_callable, jax_kernel

As part of the promotion, warp.jax_experimental is now a deprecated compatibility namespace and will be removed in Warp 1.16. The warp.jax_experimental.get_jax_callable_default_graph_cache_max() and warp.jax_experimental.set_jax_callable_default_graph_cache_max() helpers are also deprecated. Pass graph_cache_max to warp.jax_callable() or update the returned callable's graph_cache_max attribute instead. Top-level warp.jax_callable() defaults to graph_cache_max=32. Pass graph_cache_max=None for an unlimited graph cache.

Differentiable warp.jax_kernel() wrappers now accept launch_dims together with enable_backward=True (#1380). The dimensions are fixed at wrapper construction and reused for both the forward and adjoint launches, which is useful when the input array includes batch or channel dimensions that are not part of the kernel's wp.tid() iteration space.

# API shape example. Requires JAX installed.
import warp as wp
from warp import jax_kernel

@wp.kernel
def spatial_update(x: wp.array4d[float], out: wp.array4d[float]):
    i, j, k = wp.tid()
    out[0, i, j, k] = x[0, i, j, k] * 2.0

jax_update = jax_kernel(
    spatial_update,
    num_outputs=1,
    launch_dims=(16, 16, 16),
    enable_backward=True,
)

When enable_backward=True, launch_dims cannot be overridden per call, and output_dims remains unsupported.

Compilation and source-build tooling

Source builds gain two new build options. build_lib.py --use-dynamic-cuda links Warp's native library against shared CUDA libraries instead of embedding them statically, for deployments that already provide the matching CUDA shared libraries at runtime (#1334). build_lib.py --sanitize=address builds the native libraries with AddressSanitizer instrumentation, a compiler/runtime memory-error detector for native out-of-bounds accesses, use-after-free, double-free, and similar bugs (#1387). Use it for debugging source builds when you want a failing test or repro to report the invalid memory access closer to where it happens.

uv run build_lib.py --use-dynamic-cuda
uv run build_lib.py --sanitize=address --mode debug

On Linux, install Python development headers if a source build fails with a missing Python.h, for example python3-dev or libpython3-dev on Debian and Ubuntu (#1339). Warp now compiles a small Python C API extension for faster wp.float16 conversion to and from Python float. The extension uses CPython's vectorcall protocol, the public fast-call convention introduced by PEP 590. Linux and macOS debug builds now use -Og -g, which keeps debug information while preserving useful compiler diagnostics such as uninitialized-value analysis (#1414).

Math, autodiff, and correctness fixes

Floating-point wp.min(), wp.max(), wp.clamp(), wp.atomic_min(), and wp.atomic_max() now use NaN-as-missing semantics matching C fmin() and fmax() (#1376). When exactly one operand is NaN, the non-NaN operand wins. Vector reductions and wp.argmin() / wp.argmax() skip NaN slots. Adjoint and atomic variants route gradients to the operand chosen by the forward pass.

import numpy as np
import warp as wp

@wp.kernel
def math_kernel(values: wp.array[float], out: wp.array[float]):
    out[0] = wp.min(values[0], 2.0)
    out[1] = wp.max(values[0], -3.0)
    out[2] = wp.copysign(3.0, values[1])

values = wp.array([np.nan, -0.0], dtype=float, device="cpu")
out = wp.zeros(3, dtype=float, device="cpu")
wp.launch(math_kernel, dim=1, inputs=[values], outputs=[out], device="cpu")
print(out.numpy())  # [ 2. -3. -3.]
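For readers who want the rule spelled out, here is a pure-Python reference for the NaN-as-missing behavior described above. It mirrors C fmin() and fmax() semantics and is not Warp's implementation. The names fmin_ref, fmax_ref, and nan_skipping_min are hypothetical, and returning NaN for an all-NaN reduction is an assumption of this sketch.

```python
import math


def fmin_ref(a, b):
    """Minimum where a single NaN operand is treated as missing."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return min(a, b)


def fmax_ref(a, b):
    """Maximum where a single NaN operand is treated as missing."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)


def nan_skipping_min(values):
    """Reduction that skips NaN slots; NaN only if every slot is NaN (assumption)."""
    result = math.nan
    for v in values:
        result = fmin_ref(result, v)
    return result


nan = math.nan
min_result = fmin_ref(nan, 2.0)
max_result = fmax_ref(nan, -3.0)
both_nan = fmin_ref(nan, nan)
reduced = nan_skipping_min([nan, 4.0, 1.5, nan])
signed = math.copysign(3.0, -0.0)
print(min_result, max_result, both_nan, reduced, signed)
```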

Other math and autodiff fixes are grouped by affected path:

  • New builtin: wp.copysign() is now available in kernels (#1444).
  • Curl-noise gradients: gradients now propagate through wp.curlnoise() in 2D, 3D, and 4D, so differentiable curl-noise force fields no longer produce zero gradients (#1012).
  • Component and field writes: assignments on elements of a wp.array, such as arr[i].y = rhs, m[i][r, c] = rhs, transform .p / .q writes, and scalar or composite struct-field writes, now propagate gradients (#583, #248, #1174).
  • Geometry queries: wp.closest_point_edge_edge() is more reliable for near-parallel float32 segments (#1437). Well-conditioned inputs keep the same output, while near-parallel cases now use a more stable closest-point computation and a bounded analytic adjoint.
  • CUDA graph capture fills: wp.array.fill_() now passes fill values up to 3968 bytes inline through kernel arguments on both non-contiguous and contiguous array fill paths (#1412). This fixes forked-stream capture failures for scalar, vector/matrix, and many struct fills. Larger fill values still use the previous temporary-storage fallback and are not guaranteed to compose in the same forked-stream capture scenario.
  • Cleanup fixes: retained-graph allocation leaks are fixed (#1429), repeated module="unique" kernel declarations that depend on other Warp functions no longer retain stale Python module references (#1462), and autodiff metadata is now cleaned up for non-differentiable builtins (#988, #1466).
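The 3968-byte inline limit for fill values is easier to reason about with concrete sizes. The sketch below uses ctypes structures as stand-ins for Warp value types; the actual Warp struct layout may differ, and the names Vec3f, LargeStruct, and fits_inline are hypothetical.

```python
import ctypes

INLINE_FILL_LIMIT_BYTES = 3968  # limit quoted in the notes above


class Vec3f(ctypes.Structure):
    """Stand-in for a three-component float32 vector."""

    _fields_ = [("x", ctypes.c_float), ("y", ctypes.c_float), ("z", ctypes.c_float)]


class LargeStruct(ctypes.Structure):
    """Stand-in for a struct that carries 1024 float32 values."""

    _fields_ = [("data", ctypes.c_float * 1024)]


def fits_inline(ctype):
    """True if a value of this type is within the inline fill limit."""
    return ctypes.sizeof(ctype) <= INLINE_FILL_LIMIT_BYTES


vec3_size = ctypes.sizeof(Vec3f)
large_size = ctypes.sizeof(LargeStruct)
print(vec3_size, fits_inline(Vec3f), large_size, fits_inline(LargeStruct))
```

Under these stand-in layouts, the vector fill would take the inline path while the 4096-byte struct would fall back to temporary storage.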

The component-write fix means the example below now gives x.grad == [2, 2, 2] because the adjoint crosses the stored out[i].y component instead of stopping at the component assignment.

import numpy as np
import warp as wp

@wp.kernel
def write_y(x: wp.array[float], out: wp.array[wp.vec3]):
    i = wp.tid()
    out[i].y = 2.0 * x[i]

@wp.kernel
def sum_y(out: wp.array[wp.vec3], loss: wp.array[float]):
    i = wp.tid()
    wp.atomic_add(loss, 0, out[i].y)

x = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu", requires_grad=True)
out = wp.zeros(x.size, dtype=wp.vec3, device="cpu", requires_grad=True)
loss = wp.zeros(1, dtype=float, device="cpu", requires_grad=True)

with wp.Tape() as tape:
    wp.launch(write_y, dim=x.size, inputs=[x, out], device="cpu")
    wp.launch(sum_y, dim=x.size, inputs=[out, loss], device="cpu")

tape.backward(loss)
np.testing.assert_allclose(x.grad.numpy(), [2.0, 2.0, 2.0])
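The expected gradient can be cross-checked without Warp. The two kernels compute loss = sum(2 * x[i]), so each partial derivative is 2. The finite-difference sketch below confirms that value numerically; the function names are hypothetical and the step size is an arbitrary choice.

```python
def loss(x):
    # Mirrors the kernels above: out[i].y = 2 * x[i], then loss = sum of out[i].y.
    return sum(2.0 * xi for xi in x)


def finite_difference_grad(f, x, eps=1e-3):
    """Central-difference estimate of the gradient of f at x."""
    grad = []
    for i in range(len(x)):
        plus = list(x)
        minus = list(x)
        plus[i] += eps
        minus[i] -= eps
        grad.append((f(plus) - f(minus)) / (2.0 * eps))
    return grad


grad = finite_difference_grad(loss, [1.0, 2.0, 3.0])
print(grad)
```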

Breaking changes

APIC file format and native graph handles (#1431)

Warp v1.14 writes APIC format version 10. .wrp files captured by Warp v1.13 used the older format and must be recaptured. Native C/C++ code that includes apic.h must update APIC handles from the old typedef form, which hid the pointer, to explicit pointers. Ownership and destroy calls are unchanged. The migration is a source-level spelling change:

- APICState state = wp_apic_create_state();
+ APICState* state = wp_apic_create_state();
  wp_apic_destroy_state(state);

- APICGraph graph = wp_apic_load_graph(context, "simulation", APIC_DEVICE_CUDA);
+ APICGraph* graph = wp_apic_load_graph(context, "simulation", APIC_DEVICE_CUDA);
  wp_apic_destroy_graph(graph);

FEM positional signatures (#1407)

warp.fem.PicQuadrature() now accepts env_indices before requires_grad, and warp.fem.make_space_partition() now accepts environment_first before the keyword-only device and temporary_store. Pass affected arguments by keyword to preserve behavior:

- q = fem.PicQuadrature(domain, particles, measures, True)
+ q = fem.PicQuadrature(domain, particles, measures, requires_grad=True)

- p = fem.make_space_partition(space_topology, geometry_partition, True, -1, device)
+ p = fem.make_space_partition(
+     space_topology=space_topology,
+     geometry_partition=geometry_partition,
+     with_halo=True,
+     max_node_count=-1,
+     device=device,
+ )
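The reason keyword arguments preserve behavior can be shown with two stand-in functions. They are hypothetical and copy only the parameter order described above, not the real warp.fem signatures or defaults.

```python
def pic_quadrature_old(domain, particles, measures, requires_grad=False):
    """Stand-in for the earlier parameter order."""
    return {"env_indices": None, "requires_grad": requires_grad}


def pic_quadrature_new(domain, particles, measures, env_indices=None, requires_grad=False):
    """Stand-in for the new order, with env_indices inserted before requires_grad."""
    return {"env_indices": env_indices, "requires_grad": requires_grad}


# A positional True used to mean requires_grad=True.
old_call = pic_quadrature_old("domain", "particles", "measures", True)

# The same positional call now binds True to env_indices instead.
new_positional = pic_quadrature_new("domain", "particles", "measures", True)

# Passing the argument by keyword keeps the original meaning.
new_keyword = pic_quadrature_new("domain", "particles", "measures", requires_grad=True)

print(old_call, new_positional, new_keyword)
```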

HashGrid query type annotations (#1452)

Generated docs and public stubs now expose wp.HashGridQuery as the single query type. Runtime aliases wp.HashGridQueryH and wp.HashGridQueryD still warn and forward during the deprecation window, but function annotations should migrate to avoid IDE or type-checker failures. Use unparameterized wp.HashGridQuery for the default wp.float32 query type, and parameterize it when the query uses another coordinate precision:

- query: wp.HashGridQueryH
+ query: wp.HashGridQuery[wp.float16]

- query: wp.HashGridQueryD
+ query: wp.HashGridQuery[wp.float64]

This is a source-typing change only. Runtime query objects and aliases remain compatible during the deprecation window.

Linux source builds require Python.h (#1339)

Linux source builds now compile a small Python C API extension for faster wp.float16 conversions, so the build needs Python.h. If a source build fails with a missing Python.h, install your distribution's Python development package before rebuilding Warp. The extension uses CPython's vectorcall protocol, the public fast-call convention introduced by PEP 590.

Announcements

Upcoming removals

  • warp.jax_experimental is deprecated. Import jax_kernel, jax_callable, clear_jax_callable_graph_cache, JaxCallableGraphMode, and JaxModulePreloadMode from top-level warp instead. The graph-cache default helpers are also deprecated. See JAX integration graduates to stable API for the import diff and graph-cache migration. The deprecated namespace will be removed in Warp 1.16 (#1370).
  • warp.config.verbose and warp.config.quiet are deprecated. Use warp.config.log_level = wp.LOG_DEBUG for verbose diagnostics and warp.config.log_level = wp.LOG_WARNING to suppress the init banner. See Pluggable logging for accepted log-level values. The legacy flags will be removed in a future feature release per the standard deprecation timeline (#1315).
  • wp.HashGridQueryH and wp.HashGridQueryD are deprecated. Use wp.HashGridQuery[wp.float16] or wp.HashGridQuery[wp.float64] in function annotations that need explicit query precision. Use plain wp.HashGridQuery for the default wp.float32 query type. See the HashGrid query annotation migration. Runtime aliases remain available during the deprecation window, but public stubs now expose only wp.HashGridQuery (#1452).
  • The scheduled FEM argument removals are deferred to Warp 1.15. The deprecated quadrature and domain arguments of warp.fem.interpolate(), plus the space argument of warp.fem.make_space_restriction() and warp.fem.make_space_partition(), remain available in Warp 1.14.

Acknowledgments

We thank the following contributors from outside the core Warp development team:

  • @flferretti for adding dynamic CUDA linking support to the build system.
  • @kshy0519 for allowing launch_dims with differentiable warp.jax_kernel() wrappers (#1380).

For a complete list of changes, see the full changelog.