Skip to content

Commit 4dfd684

Browse files
authored
Improve the capture of fatal cuda error (#10884)
This PR is a follow-up PR of #10630, which is to improve the capture of fatal cuda errors in libcudf and cudf java package. 1. libcudf: Removes the redundent call of `cudaGetLastError` in throw_cuda_error, since the call returning the cuda error can be deemed as the first call. 2. JNI: Leverages similar logic to discern fatal cuda errors from catched exceptions. The check at the JNI level is necessary because fatal cuda errors due to rmm APIs can not be distinguished. 3. Add C++ unit test for the capture of fatal cuda error 4. Add Java unit test for the capture of fatal cuda error Authors: - Alfred Xu (https://github.com/sperlingxx) Approvers: - Jake Hemstad (https://github.com/jrhemstad) - Jason Lowe (https://github.com/jlowe) URL: #10884
1 parent 4d138ef commit 4dfd684

File tree

6 files changed

+215
-50
lines changed

6 files changed

+215
-50
lines changed

cpp/include/cudf/utilities/error.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,10 @@ namespace detail {
136136
// @cond
137137
inline void throw_cuda_error(cudaError_t error, const char* file, unsigned int line)
138138
{
139-
// Calls cudaGetLastError twice. It is nearly certain that a fatal error occurred if the second
140-
// call doesn't return with cudaSuccess.
139+
// Calls cudaGetLastError to clear the error status. It is nearly certain that a fatal error
140+
// occurred if it still returns the same error after a cleanup.
141141
cudaGetLastError();
142-
auto const last = cudaGetLastError();
142+
auto const last = cudaFree(0);
143143
auto const msg = std::string{"CUDA error encountered at: " + std::string{file} + ":" +
144144
std::to_string(line) + ": " + std::to_string(error) + " " +
145145
cudaGetErrorName(error) + " " + cudaGetErrorString(error)};

cpp/tests/error/error_handling_test.cu

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616

1717
#include <cudf_test/base_fixture.hpp>
1818

19+
#include <cudf/filling.hpp>
1920
#include <cudf/utilities/error.hpp>
2021

2122
#include <rmm/cuda_stream.hpp>
2223

23-
#include <cstring>
24-
2524
TEST(ExpectsTest, FalseCondition)
2625
{
2726
EXPECT_THROW(CUDF_EXPECTS(false, "condition is false"), cudf::logic_error);
@@ -84,6 +83,25 @@ TEST(StreamCheck, CatchFailedKernel)
8483
"invalid configuration argument");
8584
}
8685

86+
__global__ void kernel(int* p) { *p = 42; }
87+
88+
TEST(DeathTest, CudaFatalError)
89+
{
90+
testing::FLAGS_gtest_death_test_style = "threadsafe";
91+
auto call_kernel = []() {
92+
int* p;
93+
cudaMalloc(&p, 2 * sizeof(int));
94+
int* misaligned = (int*)(reinterpret_cast<char*>(p) + 1);
95+
kernel<<<1, 1>>>(misaligned);
96+
try {
97+
CUDF_CUDA_TRY(cudaDeviceSynchronize());
98+
} catch (const cudf::fatal_cuda_error& fe) {
99+
std::abort();
100+
}
101+
};
102+
ASSERT_DEATH(call_kernel(), "");
103+
}
104+
87105
#ifndef NDEBUG
88106

89107
__global__ void assert_false_kernel() { cudf_assert(false && "this kernel should die"); }

java/pom.xml

Lines changed: 109 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136
<groupId>org.apache.arrow</groupId>
137137
<artifactId>arrow-vector</artifactId>
138138
<version>${arrow.version}</version>
139-
<scope>test</scope>
139+
<scope>test</scope>
140140
</dependency>
141141
<dependency>
142142
<groupId>org.apache.parquet</groupId>
@@ -184,6 +184,42 @@
184184
<cxx.flags>-Wno-deprecated-declarations</cxx.flags>
185185
</properties>
186186
</profile>
187+
<profile>
188+
<id>default-tests</id>
189+
<build>
190+
<plugins>
191+
<plugin>
192+
<artifactId>maven-surefire-plugin</artifactId>
193+
<configuration>
194+
<excludes>
195+
<exclude>**/CudaFatalTest.java</exclude>
196+
</excludes>
197+
</configuration>
198+
<executions>
199+
<execution>
200+
<id>main-tests</id>
201+
<goals>
202+
<goal>test</goal>
203+
</goals>
204+
</execution>
205+
<execution>
206+
<id>fatal-cuda-test</id>
207+
<goals>
208+
<goal>test</goal>
209+
</goals>
210+
<configuration>
211+
<includes>
212+
<include>**/CudaFatalTest.java</include>
213+
</includes>
214+
<reuseForks>false</reuseForks>
215+
<test>*/CudaFatalTest.java</test>
216+
</configuration>
217+
</execution>
218+
</executions>
219+
</plugin>
220+
</plugins>
221+
</build>
222+
</profile>
187223
<profile>
188224
<id>no-cufile-tests</id>
189225
<activation>
@@ -199,8 +235,30 @@
199235
<configuration>
200236
<excludes>
201237
<exclude>**/CuFileTest.java</exclude>
238+
<exclude>**/CudaFatalTest.java</exclude>
202239
</excludes>
203240
</configuration>
241+
<executions>
242+
<execution>
243+
<id>main-tests</id>
244+
<goals>
245+
<goal>test</goal>
246+
</goals>
247+
</execution>
248+
<execution>
249+
<id>fatal-cuda-test</id>
250+
<goals>
251+
<goal>test</goal>
252+
</goals>
253+
<configuration>
254+
<includes>
255+
<include>**/CudaFatalTest.java</include>
256+
</includes>
257+
<reuseForks>false</reuseForks>
258+
<test>*/CudaFatalTest.java</test>
259+
</configuration>
260+
</execution>
261+
</executions>
204262
</plugin>
205263
</plugins>
206264
</build>
@@ -280,7 +338,7 @@
280338
<nexusUrl>https://oss.sonatype.org/</nexusUrl>
281339
<autoReleaseAfterClose>false</autoReleaseAfterClose>
282340
</configuration>
283-
</plugin>
341+
</plugin>
284342
</plugins>
285343
</build>
286344
</profile>
@@ -289,16 +347,16 @@
289347
<build>
290348
<resources>
291349
<resource>
292-
<!-- Include the properties file to provide the build information. -->
293-
<directory>${project.build.directory}/extra-resources</directory>
294-
<filtering>true</filtering>
350+
<!-- Include the properties file to provide the build information. -->
351+
<directory>${project.build.directory}/extra-resources</directory>
352+
<filtering>true</filtering>
295353
</resource>
296354
<resource>
297-
<directory>${basedir}/..</directory>
298-
<targetPath>META-INF</targetPath>
299-
<includes>
300-
<include>LICENSE</include>
301-
</includes>
355+
<directory>${basedir}/..</directory>
356+
<targetPath>META-INF</targetPath>
357+
<includes>
358+
<include>LICENSE</include>
359+
</includes>
302360
</resource>
303361
</resources>
304362
<pluginManagement>
@@ -339,6 +397,12 @@
339397
<artifactId>junit-jupiter-engine</artifactId>
340398
<version>5.4.2</version>
341399
</dependency>
400+
<dependency>
401+
<!-- to get around bug https://github.com/junit-team/junit5/issues/1367 -->
402+
<groupId>org.apache.maven.surefire</groupId>
403+
<artifactId>surefire-logger-api</artifactId>
404+
<version>2.21.0</version>
405+
</dependency>
342406
</dependencies>
343407
</plugin>
344408
<plugin>
@@ -404,9 +468,10 @@
404468
<arg value="${parallel.level}"/>
405469
</exec>
406470
<mkdir dir="${project.build.directory}/extra-resources"/>
407-
<exec executable="bash" output="${project.build.directory}/extra-resources/cudf-java-version-info.properties">
408-
<arg value="${project.basedir}/buildscripts/build-info"/>
409-
<arg value="${project.version}"/>
471+
<exec executable="bash"
472+
output="${project.build.directory}/extra-resources/cudf-java-version-info.properties">
473+
<arg value="${project.basedir}/buildscripts/build-info"/>
474+
<arg value="${project.version}"/>
410475
</exec>
411476
</tasks>
412477
</configuration>
@@ -428,31 +493,31 @@
428493
</goals>
429494
<configuration>
430495
<source>
431-
def sout = new StringBuffer(), serr = new StringBuffer()
432-
//This only works on linux
433-
def proc = 'ldd ${native.build.path}/libcudfjni.so'.execute()
434-
proc.consumeProcessOutput(sout, serr)
435-
proc.waitForOrKill(10000)
436-
def libcudf = ~/libcudf.*\\.so\\s+=>\\s+(.*)libcudf.*\\.so\\s+.*/
437-
def cudfm = libcudf.matcher(sout)
438-
if (cudfm.find()) {
439-
pom.properties['native.cudf.path'] = cudfm.group(1)
440-
} else {
441-
fail("Could not find cudf as a dependency of libcudfjni out> $sout err> $serr")
442-
}
496+
def sout = new StringBuffer(), serr = new StringBuffer()
497+
//This only works on linux
498+
def proc = 'ldd ${native.build.path}/libcudfjni.so'.execute()
499+
proc.consumeProcessOutput(sout, serr)
500+
proc.waitForOrKill(10000)
501+
def libcudf = ~/libcudf.*\\.so\\s+=>\\s+(.*)libcudf.*\\.so\\s+.*/
502+
def cudfm = libcudf.matcher(sout)
503+
if (cudfm.find()) {
504+
pom.properties['native.cudf.path'] = cudfm.group(1)
505+
} else {
506+
fail("Could not find cudf as a dependency of libcudfjni out> $sout err> $serr")
507+
}
443508

444-
def nvccout = new StringBuffer(), nvccerr = new StringBuffer()
445-
def nvccproc = 'nvcc --version'.execute()
446-
nvccproc.consumeProcessOutput(nvccout, nvccerr)
447-
nvccproc.waitForOrKill(10000)
448-
def cudaPattern = ~/Cuda compilation tools, release ([0-9]+)/
449-
def cm = cudaPattern.matcher(nvccout)
450-
if (cm.find()) {
451-
def classifier = 'cuda' + cm.group(1)
452-
pom.properties['cuda.classifier'] = classifier
453-
} else {
454-
fail('could not find CUDA version')
455-
}
509+
def nvccout = new StringBuffer(), nvccerr = new StringBuffer()
510+
def nvccproc = 'nvcc --version'.execute()
511+
nvccproc.consumeProcessOutput(nvccout, nvccerr)
512+
nvccproc.waitForOrKill(10000)
513+
def cudaPattern = ~/Cuda compilation tools, release ([0-9]+)/
514+
def cm = cudaPattern.matcher(nvccout)
515+
if (cm.find()) {
516+
def classifier = 'cuda' + cm.group(1)
517+
pom.properties['cuda.classifier'] = classifier
518+
} else {
519+
fail('could not find CUDA version')
520+
}
456521
</source>
457522
</configuration>
458523
</execution>
@@ -480,13 +545,13 @@
480545
<groupId>org.apache.maven.plugins</groupId>
481546
<artifactId>maven-surefire-plugin</artifactId>
482547
<configuration>
483-
<!-- you can turn this off, by passing -DtrimStackTrace=true when running tests -->
484-
<trimStackTrace>false</trimStackTrace>
485-
<redirectTestOutputToFile>true</redirectTestOutputToFile>
486-
<systemPropertyVariables>
487-
<ai.rapids.refcount.debug>${ai.rapids.refcount.debug}</ai.rapids.refcount.debug>
488-
<ai.rapids.cudf.nvtx.enabled>${ai.rapids.cudf.nvtx.enabled}</ai.rapids.cudf.nvtx.enabled>
489-
</systemPropertyVariables>
548+
<!-- you can turn this off, by passing -DtrimStackTrace=true when running tests -->
549+
<trimStackTrace>false</trimStackTrace>
550+
<redirectTestOutputToFile>true</redirectTestOutputToFile>
551+
<systemPropertyVariables>
552+
<ai.rapids.refcount.debug>${ai.rapids.refcount.debug}</ai.rapids.refcount.debug>
553+
<ai.rapids.cudf.nvtx.enabled>${ai.rapids.cudf.nvtx.enabled}</ai.rapids.cudf.nvtx.enabled>
554+
</systemPropertyVariables>
490555
</configuration>
491556
</plugin>
492557
<plugin>

java/src/main/native/include/jni_utils.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,16 @@ inline void jni_cuda_check(JNIEnv *const env, cudaError_t cuda_status) {
862862
JNI_CHECK_CUDA_ERROR(env, cudf::jni::CUDA_ERROR_CLASS, e, ret_val); \
863863
} \
864864
catch (const std::exception &e) { \
865+
/* Double check whether the thrown exception is unrecoverable CUDA error or not. */ \
866+
/* Like cudf::detail::throw_cuda_error, it is nearly certain that a fatal error */ \
867+
/* occurred if the second call doesn't return with cudaSuccess. */ \
868+
cudaGetLastError(); \
869+
auto const last = cudaFree(0); \
870+
if (cudaSuccess != last && last == cudaDeviceSynchronize()) { \
871+
auto msg = e.what() == nullptr ? std::string{""} : e.what(); \
872+
auto cuda_error = cudf::fatal_cuda_error{msg, last}; \
873+
JNI_CHECK_CUDA_ERROR(env, cudf::jni::CUDA_FATAL_ERROR_CLASS, cuda_error, ret_val); \
874+
} \
865875
/* If jni_exception caught then a Java exception is pending and this will not overwrite it. */ \
866876
JNI_CHECK_THROW_NEW(env, class_name, e.what(), ret_val); \
867877
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) 2022, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package ai.rapids.cudf;
18+
19+
import org.junit.jupiter.api.Test;
20+
21+
import static org.junit.jupiter.api.Assertions.assertEquals;
22+
import static org.junit.jupiter.api.Assertions.assertThrows;
23+
24+
public class CudaFatalTest {
25+
26+
@Test
27+
public void testCudaFatalException() {
28+
try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3, 4, 5)) {
29+
30+
try (ColumnView badCv = ColumnView.fromDeviceBuffer(new BadDeviceBuffer(), 0, DType.INT8, 256);
31+
ColumnView ret = badCv.sub(badCv);
32+
HostColumnVector hcv = ret.copyToHost()) {
33+
} catch (CudaException ignored) {
34+
}
35+
36+
// CUDA API invoked by libcudf failed because of previous unrecoverable fatal error
37+
assertThrows(CudaFatalException.class, () -> {
38+
try (ColumnVector cv2 = cv.asLongs()) {
39+
} catch (CudaFatalException ex) {
40+
assertEquals(CudaException.CudaError.cudaErrorIllegalAddress, ex.cudaError);
41+
throw ex;
42+
}
43+
});
44+
}
45+
46+
// CUDA API invoked by RMM failed because of previous unrecoverable fatal error
47+
assertThrows(CudaFatalException.class, () -> {
48+
try (ColumnVector cv = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5)) {
49+
} catch (CudaFatalException ex) {
50+
assertEquals(CudaException.CudaError.cudaErrorIllegalAddress, ex.cudaError);
51+
throw ex;
52+
}
53+
});
54+
}
55+
56+
private static class BadDeviceBuffer extends BaseDeviceMemoryBuffer {
57+
public BadDeviceBuffer() {
58+
super(256L, 256L, (MemoryBufferCleaner) null);
59+
}
60+
61+
@Override
62+
public MemoryBuffer slice(long offset, long len) {
63+
return null;
64+
}
65+
}
66+
67+
}

java/src/test/java/ai/rapids/cudf/CudaTest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21-
import static org.junit.jupiter.api.Assertions.*;
21+
import static org.junit.jupiter.api.Assertions.assertEquals;
22+
import static org.junit.jupiter.api.Assertions.assertThrows;
2223

2324
public class CudaTest {
2425

@@ -44,5 +45,9 @@ public void testCudaException() {
4445
}
4546
}
4647
);
48+
// non-fatal CUDA error will not fail subsequent CUDA calls
49+
try (ColumnVector cv = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5)) {
50+
}
4751
}
52+
4853
}

0 commit comments

Comments
 (0)