Skip to content

Commit 8859672

Browse files
committed
Surface CUDA launch errors
CUDA launch failures were previously written to native stderr but ignored by the Python launch callers. That let simulations and tape backward replay continue with stale outputs or gradients after CUDA rejected a launch.

Check the existing wp_cuda_launch_kernel return value in direct launches, recorded Launch replay, JAX FFI, and APIC loaded-graph replay. This keeps the hot path to a single branch and avoids adding synchronization or extra CUDA queries.

Add CUDA regression tests for oversized block dimensions, launch_bounds violations, recorded commands, adjoint launches, and Tape.backward, and record the user-facing fix in the changelog.

Signed-off-by: Eric Shi <ershi@nvidia.com>
1 parent 2e8482b commit 8859672

6 files changed

Lines changed: 132 additions & 11 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
instead of an internal `AttributeError` ([GH-1487](https://github.com/NVIDIA/warp/issues/1487)).
5555
- Fix stale gradient keepalive references when replacing a `@wp.struct` plain-array field with `None` or a
5656
non-gradient array ([GH-1520](https://github.com/NVIDIA/warp/issues/1520)).
57+
- Fix CUDA kernel launch failures to raise a Python `RuntimeError` instead of only logging the error to native stderr and
58+
continuing with stale outputs or gradients ([GH-1535](https://github.com/NVIDIA/warp/issues/1535)).
5759

5860
### Documentation
5961

warp/_src/context.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8580,6 +8580,11 @@ def invoke(kernel, hooks, params: Sequence[Any], adjoint: bool):
85808580
hooks.backward(ctypes.byref(params[0]), ctypes.byref(args), ctypes.byref(adj_args))
85818581

85828582

8583+
def _raise_cuda_launch_error(kernel: Kernel, device: Device) -> None:
8584+
"""Raise a RuntimeError for the current CUDA launch failure."""
8585+
raise RuntimeError(f"Error launching kernel: {kernel.key} on device {device}: {runtime.get_error_string()}")
8586+
8587+
85838588
class Launch:
85848589
"""Represent all data required for a kernel launch so that launches can be replayed quickly.
85858590

@@ -8786,7 +8791,7 @@ def launch(self, stream: Stream | None = None) -> None:
87868791
graph._retain_module_exec(self.module_exec)
87878792

87888793
if self.adjoint:
8789-
runtime.core.wp_cuda_launch_kernel(
8794+
if runtime.core.wp_cuda_launch_kernel(
87908795
self.device.context,
87918796
self.hooks.backward,
87928797
self.bounds.size,
@@ -8796,9 +8801,10 @@ def launch(self, stream: Stream | None = None) -> None:
87968801
self.params_addr,
87978802
stream.cuda_stream,
87988803
None, # apic_info: replayed launches don't re-record
8799-
)
8804+
):
8805+
_raise_cuda_launch_error(self.kernel, self.device)
88008806
else:
8801-
runtime.core.wp_cuda_launch_kernel(
8807+
if runtime.core.wp_cuda_launch_kernel(
88028808
self.device.context,
88038809
self.hooks.forward,
88048810
self.bounds.size,
@@ -8808,7 +8814,8 @@ def launch(self, stream: Stream | None = None) -> None:
88088814
self.params_addr,
88098815
stream.cuda_stream,
88108816
None, # apic_info: replayed launches don't re-record
8811-
)
8817+
):
8818+
_raise_cuda_launch_error(self.kernel, self.device)
88128819

88138820

88148821
def _canonicalize_dim(dim: int | Sequence[int]) -> tuple[int, ...]:
@@ -9127,7 +9134,7 @@ def pack_args(args, params, adjoint=False):
91279134
"Backward kernel launches are not supported during APIC graph capture. "
91289135
"Use wp.Tape outside of capture scope instead."
91299136
)
9130-
runtime.core.wp_cuda_launch_kernel(
9137+
if runtime.core.wp_cuda_launch_kernel(
91319138
device.context,
91329139
hooks.backward,
91339140
bounds.size,
@@ -9137,7 +9144,8 @@ def pack_args(args, params, adjoint=False):
91379144
kernel_params,
91389145
stream.cuda_stream,
91399146
None,
9140-
)
9147+
):
9148+
_raise_cuda_launch_error(kernel, device)
91419149

91429150
else:
91439151
if hooks.forward is None:
@@ -9170,7 +9178,7 @@ def pack_args(args, params, adjoint=False):
91709178
False,
91719179
)
91729180
apic_info_ptr = ctypes.byref(apic_info)
9173-
runtime.core.wp_cuda_launch_kernel(
9181+
if runtime.core.wp_cuda_launch_kernel(
91749182
device.context,
91759183
hooks.forward,
91769184
bounds.size,
@@ -9180,7 +9188,8 @@ def pack_args(args, params, adjoint=False):
91809188
kernel_params,
91819189
stream.cuda_stream,
91829190
apic_info_ptr,
9183-
)
9191+
):
9192+
_raise_cuda_launch_error(kernel, device)
91849193

91859194
try:
91869195
runtime.verify_cuda_device(device)

warp/_src/jax/ffi.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def ffi_callback(self, call_frame):
451451
assert hooks.forward, "Failed to find kernel entry point"
452452

453453
# launch the kernel
454-
wp._src.context.runtime.core.wp_cuda_launch_kernel(
454+
if wp._src.context.runtime.core.wp_cuda_launch_kernel(
455455
device.context,
456456
hooks.forward,
457457
launch_bounds.size,
@@ -461,7 +461,11 @@ def ffi_callback(self, call_frame):
461461
kernel_params,
462462
stream,
463463
None, # apic_info
464-
)
464+
):
465+
raise RuntimeError(
466+
f"Error launching kernel: {self.kernel.key} on device {device}: "
467+
f"{wp._src.context.runtime.get_error_string()}"
468+
)
465469

466470
except Exception as e:
467471
print(traceback.format_exc())

warp/native/apic.cu

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,13 +372,17 @@ static bool apic_rebuild_cuda_graph(APICGraph* graph, CUstream stream)
372372
// Replay via the same wp_cuda_launch_kernel that captured this op.
373373
// apic_info=nullptr is safe: g_apic_state is null during replay, so
374374
// the recording branch in wp_cuda_launch_kernel is a no-op.
375-
wp_cuda_launch_kernel(
375+
size_t launch_result = wp_cuda_launch_kernel(
376376
graph->cuda_context, kernel, rec->dim, rec->max_blocks, rec->block_dim, rec->smem_bytes, args.data(),
377377
stream, /*apic_info=*/nullptr
378378
);
379379

380380
for (uint8_t* p : arg_storage)
381381
delete[] p;
382+
if (launch_result) {
383+
success = false;
384+
break;
385+
}
382386
break;
383387
}
384388

warp/tests/test_launch.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def square_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
4848
output[i] = input[i] * input[i]
4949

5050

51+
@wp.kernel
52+
def noop_kernel():
53+
tid = wp.tid()
54+
55+
5156
def test1d(test, device):
5257
a = np.arange(0, dim_x).reshape(dim_x)
5358

@@ -400,6 +405,12 @@ def kernel_single_tuple_bound(x: wp.array(dtype=float)):
400405
x[tid] = x[tid] * 2.0
401406

402407

408+
@wp.kernel(launch_bounds=256)
409+
def bounded_square_kernel(data: wp.array(dtype=float), output: wp.array(dtype=float)):
410+
i = wp.tid()
411+
output[i] = data[i] * data[i]
412+
413+
403414
def test_launch_bounds_none(test, device):
404415
"""Test kernel without launch_bounds"""
405416
n = 1024
@@ -436,7 +447,61 @@ def test_launch_bounds_single_tuple(test, device):
436447
assert_np_equal(x.numpy(), np.full(n, 2.0, dtype=np.float32))
437448

438449

450+
def test_launch_device_block_dim_failure(test, device):
451+
"""Raise when CUDA rejects an oversized launch block.
452+
453+
Protects users from continuing after native stderr with kernel outputs left unchanged.
454+
"""
455+
with test.assertRaisesRegex(RuntimeError, r"Error launching kernel: .*noop_kernel.*Warp CUDA error"):
456+
wp.launch(noop_kernel, dim=1, block_dim=2048, device=device)
457+
458+
459+
def test_launch_bounds_block_dim_failure(test, device):
460+
"""Raise when CUDA rejects a launch-bounds violation.
461+
462+
Protects users from silently skipping kernels whose outputs feed later simulation stages.
463+
"""
464+
x = wp.ones(1, dtype=float, device=device)
465+
466+
with test.assertRaisesRegex(RuntimeError, r"Error launching kernel: .*kernel_single_bound.*Warp CUDA error"):
467+
wp.launch(kernel_single_bound, dim=1, inputs=[x], block_dim=512, device=device)
468+
469+
470+
def test_launch_cmd_block_dim_failure(test, device):
471+
"""Raise when recorded launches hit CUDA launch errors.
472+
473+
Protects recorded command replay from returning normally with stale outputs.
474+
"""
475+
x = wp.ones(1, dtype=float, device=device)
476+
cmd = wp.launch(kernel_single_bound, dim=1, inputs=[x], block_dim=512, device=device, record_cmd=True)
477+
478+
with test.assertRaisesRegex(RuntimeError, r"Error launching kernel: .*kernel_single_bound.*Warp CUDA error"):
479+
cmd.launch()
480+
481+
482+
def test_launch_adjoint_block_dim_failure(test, device):
483+
"""Raise when adjoint launches hit CUDA launch errors.
484+
485+
Protects differentiable simulations from using missing or partial gradients.
486+
"""
487+
input_arr = wp.array([1.0], dtype=float, requires_grad=True, device=device)
488+
output_arr = wp.empty_like(input_arr)
489+
output_arr.grad.fill_(1.0)
490+
491+
with test.assertRaisesRegex(RuntimeError, r"Error launching kernel: .*bounded_square_kernel.*Warp CUDA error"):
492+
wp.launch(
493+
bounded_square_kernel,
494+
dim=input_arr.size,
495+
inputs=[input_arr, output_arr],
496+
adj_inputs=[None, None],
497+
adjoint=True,
498+
block_dim=512,
499+
device=device,
500+
)
501+
502+
439503
devices = get_test_devices()
504+
cuda_devices = get_cuda_test_devices()
440505

441506

442507
class TestLaunch(unittest.TestCase):
@@ -462,6 +527,18 @@ class TestLaunch(unittest.TestCase):
462527
add_function_test(TestLaunch, "test_launch_bounds_single", test_launch_bounds_single, devices=devices)
463528
add_function_test(TestLaunch, "test_launch_bounds_tuple", test_launch_bounds_tuple, devices=devices)
464529
add_function_test(TestLaunch, "test_launch_bounds_single_tuple", test_launch_bounds_single_tuple, devices=devices)
530+
add_function_test(
531+
TestLaunch, "test_launch_device_block_dim_failure", test_launch_device_block_dim_failure, devices=cuda_devices
532+
)
533+
add_function_test(
534+
TestLaunch, "test_launch_bounds_block_dim_failure", test_launch_bounds_block_dim_failure, devices=cuda_devices
535+
)
536+
add_function_test(
537+
TestLaunch, "test_launch_cmd_block_dim_failure", test_launch_cmd_block_dim_failure, devices=cuda_devices
538+
)
539+
add_function_test(
540+
TestLaunch, "test_launch_adjoint_block_dim_failure", test_launch_adjoint_block_dim_failure, devices=cuda_devices
541+
)
465542

466543

467544
if __name__ == "__main__":

warp/tests/test_tape.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,29 @@ def test_tape_visualize_subscript(test, device):
303303
test.assertIn("array: dtype=", dot_code)
304304

305305

306+
def test_tape_backward_cuda_launch_failure(test, device):
307+
"""Raise when Tape backward hits CUDA launch errors.
308+
309+
Corrupt the recorded backward launch block size to reproduce the stale-gradient failure mode.
310+
Protects ``Tape.backward()`` from returning after a failed replay with stale or missing gradients.
311+
"""
312+
x = wp.array([1.0], dtype=wp.float32, device=device, requires_grad=True)
313+
y = wp.empty_like(x, requires_grad=True)
314+
315+
tape = wp.Tape()
316+
with tape:
317+
wp.launch(kernel=mul_constant, dim=x.size, inputs=[x], outputs=[y], block_dim=256, device=device)
318+
319+
launch = tape.launches[0]
320+
test.assertEqual(len(launch), 8)
321+
launch[6] = 2048 # block_dim
322+
323+
with test.assertRaisesRegex(RuntimeError, r"Error launching kernel: .*mul_constant.*Warp CUDA error"):
324+
tape.backward(grads={y: wp.ones_like(y)})
325+
326+
306327
devices = get_test_devices()
328+
cuda_devices = get_cuda_test_devices()
307329

308330

309331
class TestTape(unittest.TestCase):
@@ -367,6 +389,9 @@ def test_tape_empty_nested_scope_markers_removed(self):
367389
add_function_test(TestTape, "test_tape_struct_subscript", test_tape_struct_subscript, devices=devices)
368390
add_function_test(TestTape, "test_tape_nested_struct_subscript", test_tape_nested_struct_subscript, devices=devices)
369391
add_function_test(TestTape, "test_tape_visualize_subscript", test_tape_visualize_subscript, devices=devices)
392+
add_function_test(
393+
TestTape, "test_tape_backward_cuda_launch_failure", test_tape_backward_cuda_launch_failure, devices=cuda_devices
394+
)
370395

371396

372397
if __name__ == "__main__":

0 commit comments

Comments
 (0)