@@ -76,6 +76,21 @@ list(REMOVE_ITEM TEST_SRC_FILES ${UNMATCH_FILES})
7676file (GLOB_RECURSE TEST_BASE_FILES ${PROJECT_SOURCE_DIR} /src/*.cpp)
7777set (PADDLE_TARGET_FOLDER ${CMAKE_BINARY_DIR} /paddle)
7878
79+ # ---------------------------------------------------------------------------
80+ # CUDA Toolkit (needed for CUDA-specific test headers in the Torch build)
81+ # ---------------------------------------------------------------------------
82+ find_package (CUDAToolkit QUIET )
83+ if (CUDAToolkit_FOUND)
84+ message (STATUS "Found CUDA Toolkit: ${CUDAToolkit_INCLUDE_DIRS} " )
85+ set (CUDA_INCLUDE_DIRS "${CUDAToolkit_INCLUDE_DIRS} " )
86+ elseif (EXISTS "/usr/local/cuda/include" )
87+ set (CUDA_INCLUDE_DIRS "/usr/local/cuda/include" )
88+ message (STATUS "Using default CUDA include dir: ${CUDA_INCLUDE_DIRS} " )
89+ else ()
90+ message (WARNING "CUDA headers not found; CUDA tests may not compile." )
91+ set (CUDA_INCLUDE_DIRS "" )
92+ endif ()
93+
7994# ---------------------------------------------------------------------------
8095# Build Torch test case
8196# ---------------------------------------------------------------------------
@@ -87,8 +102,9 @@ set(TORCH_LIBRARIES "")
87102file (GLOB_RECURSE TORCH_LIBRARIES "${TORCH_DIR} /lib/*.so"
88103 "${TORCH_DIR} /lib/*.a" )
89104
90- set (TORCH_INCLUDE_DIR "${TORCH_DIR} /include"
91- "${TORCH_DIR} /include/torch/csrc/api/include/" )
105+ set (TORCH_INCLUDE_DIR
106+ "${TORCH_DIR} /include" "${TORCH_DIR} /include/torch/csrc/api/include/"
107+ "${CUDA_INCLUDE_DIRS} " )
92108
93109set (TORCH_TARGET_FOLDER ${CMAKE_BINARY_DIR} /torch)
94110set (BIN_PREFIX "torch_" )
@@ -119,7 +135,8 @@ set(PADDLE_INCLUDE_DIR
119135 "${PADDLE_DIR} /include/third_party"
120136 "${PADDLE_DIR} /include/paddle/phi/api/include/compat/"
121137 "${PADDLE_DIR} /include/paddle/phi/api/include/compat/torch/csrc/api/include/"
122- )
138+ "${CUDA_INCLUDE_DIRS} "
139+ "${CUDA_INCLUDE_DIRS} /cccl" )
123140
124141set (PADDLE_LIBRARIES
125142 "${PADDLE_DIR} /base/libpaddle.so" "${PADDLE_DIR} /libs/libcommon.so"
0 commit comments