Skip to content

Commit 47e652e

Browse files
committed
Change _CUDAX_GET_DRIVER_FUNCTION and minor fixes
1 parent 97ebb96 commit 47e652e

File tree

2 files changed

+51
-42
lines changed

2 files changed

+51
-42
lines changed

cudax/include/cuda/experimental/__driver/driver_api.cuh

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,18 @@
2727
# include <cuda/std/cstddef>
2828

2929
# include <cuda.h>
30+
# include <cudaTypedefs.h>
3031

3132
# include <cuda/std/__cccl/prologue.h>
3233

33-
// Redefine the macros locally — they are #undef'd at the end of driver_api.h.
34-
# define _CUDAX_GET_DRIVER_FUNCTION(function_name) \
35-
reinterpret_cast<decltype(::function_name)*>(::cuda::__driver::__get_driver_entry_point(#function_name))
36-
37-
# define _CUDAX_GET_DRIVER_FUNCTION_VERSIONED(function_name, versioned_fn_name, major, minor) \
38-
reinterpret_cast<decltype(::versioned_fn_name)*>( \
39-
::cuda::__driver::__get_driver_entry_point(#function_name, major, minor))
34+
// Get a driver function pointer, casting to the PFN typedef for type safety.
35+
// Uses PFN_ typedefs from cudaTypedefs.h to avoid ABI mismatches caused by
36+
// #define'd version aliases in cuda.h (e.g. #define cuFoo cuFoo_v2).
37+
// The ## operator suppresses macro expansion of the function name, so this is
38+
// safe even for names that are #define'd to versioned variants.
39+
# define _CUDAX_GET_DRIVER_FUNCTION(pfn_name, major, minor) \
40+
reinterpret_cast<PFN_##pfn_name##_v##major##0##minor##0>( \
41+
::cuda::__driver::__get_driver_entry_point(#pfn_name, major, minor))
4042

4143
namespace cuda::experimental::__driver
4244
{
@@ -53,7 +55,7 @@ namespace cuda::experimental::__driver
5355
::cuda::std::size_t __width,
5456
::cuda::std::size_t __height)
5557
{
56-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddMemsetNode);
58+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddMemsetNode, 10, 0);
5759
::CUgraphNode __node{};
5860
::CUDA_MEMSET_NODE_PARAMS __params{};
5961
__params.dst = __dst;
@@ -78,7 +80,7 @@ namespace cuda::experimental::__driver
7880
::CUdeviceptr __src,
7981
::cuda::std::size_t __byte_count)
8082
{
81-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddMemcpyNode);
83+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddMemcpyNode, 10, 0);
8284
::CUgraphNode __node{};
8385
::CUDA_MEMCPY3D __params{};
8486
__params.srcMemoryType = ::CU_MEMORYTYPE_UNIFIED;
@@ -99,7 +101,7 @@ namespace cuda::experimental::__driver
99101
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphNode __graphAddHostNode(
100102
::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::std::size_t __ndeps, ::CUhostFn __fn, void* __user_data)
101103
{
102-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddHostNode);
104+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddHostNode, 10, 0);
103105
::CUgraphNode __node{};
104106
::CUDA_HOST_NODE_PARAMS __params{};
105107
__params.fn = __fn;
@@ -114,7 +116,7 @@ namespace cuda::experimental::__driver
114116
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphNode __graphAddChildGraphNode(
115117
::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::std::size_t __ndeps, ::CUgraph __child_graph)
116118
{
117-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddChildGraphNode);
119+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddChildGraphNode, 10, 0);
118120
::CUgraphNode __node{};
119121
::cuda::__driver::__call_driver_fn(
120122
__driver_fn, "Failed to add a child graph node", &__node, __graph, __deps, __ndeps, __child_graph);
@@ -124,24 +126,24 @@ namespace cuda::experimental::__driver
124126
// ── Graph: event record node ────────────────────────────────────────────────
125127

126128
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphNode __graphAddEventRecordNode(
127-
::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::std::size_t __ndeps, ::CUevent __event)
129+
::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::std::size_t __ndeps, ::CUevent __ev)
128130
{
129-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddEventRecordNode);
131+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddEventRecordNode, 11, 1);
130132
::CUgraphNode __node{};
131133
::cuda::__driver::__call_driver_fn(
132-
__driver_fn, "Failed to add an event record node to graph", &__node, __graph, __deps, __ndeps, __event);
134+
__driver_fn, "Failed to add an event record node to graph", &__node, __graph, __deps, __ndeps, __ev);
133135
return __node;
134136
}
135137

136138
// ── Graph: event wait node ──────────────────────────────────────────────────
137139

138140
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphNode
139-
__graphAddEventWaitNode(::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::std::size_t __ndeps, ::CUevent __event)
141+
__graphAddEventWaitNode(::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::std::size_t __ndeps, ::CUevent __ev)
140142
{
141-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddEventWaitNode);
143+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddEventWaitNode, 11, 1);
142144
::CUgraphNode __node{};
143145
::cuda::__driver::__call_driver_fn(
144-
__driver_fn, "Failed to add an event wait node to graph", &__node, __graph, __deps, __ndeps, __event);
146+
__driver_fn, "Failed to add an event wait node to graph", &__node, __graph, __deps, __ndeps, __ev);
145147
return __node;
146148
}
147149

@@ -152,7 +154,7 @@ __graphAddEventWaitNode(::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::
152154
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphConditionalHandle
153155
__graphConditionalHandleCreate(::CUgraph __graph, unsigned int __default_val, unsigned int __flags)
154156
{
155-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphConditionalHandleCreate);
157+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphConditionalHandleCreate, 12, 3);
156158
::CUgraphConditionalHandle __handle{};
157159
::CUcontext __ctx = ::cuda::__driver::__ctxGetCurrent();
158160
::cuda::__driver::__call_driver_fn(
@@ -165,7 +167,7 @@ __graphConditionalHandleCreate(::CUgraph __graph, unsigned int __default_val, un
165167
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphNode __graphAddNode(
166168
::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::std::size_t __ndeps, ::CUgraphNodeParams* __params)
167169
{
168-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddNode);
170+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddNode, 12, 2);
169171
::CUgraphNode __node{};
170172
::cuda::__driver::__call_driver_fn(
171173
__driver_fn, "Failed to add a node to graph", &__node, __graph, __deps, __ndeps, __params);
@@ -178,7 +180,7 @@ __graphConditionalHandleCreate(::CUgraph __graph, unsigned int __default_val, un
178180

179181
[[nodiscard]] _CCCL_HOST_API inline ::CUgraph __graphCreate()
180182
{
181-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphCreate);
183+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphCreate, 10, 0);
182184
::CUgraph __graph{};
183185
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to create graph", &__graph, 0u);
184186
return __graph;
@@ -188,15 +190,15 @@ __graphConditionalHandleCreate(::CUgraph __graph, unsigned int __default_val, un
188190

189191
[[nodiscard]] _CCCL_HOST_API inline ::cudaError_t __graphDestroyNoThrow(::CUgraph __graph) noexcept
190192
{
191-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphDestroy);
193+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphDestroy, 10, 0);
192194
return static_cast<::cudaError_t>(__driver_fn(__graph));
193195
}
194196

195197
// ── Graph: clone ────────────────────────────────────────────────────────────
196198

197199
[[nodiscard]] _CCCL_HOST_API inline ::CUgraph __graphClone(::CUgraph __original)
198200
{
199-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphClone);
201+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphClone, 10, 0);
200202
::CUgraph __clone{};
201203
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to clone graph", &__clone, __original);
202204
return __clone;
@@ -206,7 +208,7 @@ __graphConditionalHandleCreate(::CUgraph __graph, unsigned int __default_val, un
206208

207209
[[nodiscard]] _CCCL_HOST_API inline ::cuda::std::size_t __graphGetNodeCount(::CUgraph __graph)
208210
{
209-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphGetNodes);
211+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphGetNodes, 10, 0);
210212
::cuda::std::size_t __count = 0;
211213
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to get graph node count", __graph, nullptr, &__count);
212214
return __count;
@@ -216,7 +218,7 @@ __graphConditionalHandleCreate(::CUgraph __graph, unsigned int __default_val, un
216218

217219
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphExec __graphInstantiate(::CUgraph __graph, unsigned long long __flags = 0)
218220
{
219-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphInstantiate);
221+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphInstantiateWithFlags, 11, 4);
220222
::CUgraphExec __exec{};
221223
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to instantiate graph", &__exec, __graph, __flags);
222224
return __exec;
@@ -226,15 +228,15 @@ __graphConditionalHandleCreate(::CUgraph __graph, unsigned int __default_val, un
226228

227229
_CCCL_HOST_API inline void __graphLaunch(::CUgraphExec __exec, ::CUstream __stream)
228230
{
229-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphLaunch);
231+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphLaunch, 10, 0);
230232
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to launch graph", __exec, __stream);
231233
}
232234

233235
// ── Graph exec: destroy (no-throw, for use in destructors) ──────────────────
234236

235237
[[nodiscard]] _CCCL_HOST_API inline ::cudaError_t __graphExecDestroyNoThrow(::CUgraphExec __exec) noexcept
236238
{
237-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphExecDestroy);
239+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphExecDestroy, 10, 0);
238240
return static_cast<::cudaError_t>(__driver_fn(__exec));
239241
}
240242

@@ -243,7 +245,7 @@ _CCCL_HOST_API inline void __graphLaunch(::CUgraphExec __exec, ::CUstream __stre
243245
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphNode
244246
__graphAddEmptyNode(::CUgraph __graph, const ::CUgraphNode* __deps, ::cuda::std::size_t __ndeps)
245247
{
246-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddEmptyNode);
248+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddEmptyNode, 10, 0);
247249
::CUgraphNode __node{};
248250
::cuda::__driver::__call_driver_fn(
249251
__driver_fn, "Failed to add an empty node to graph", &__node, __graph, __deps, __ndeps);
@@ -256,26 +258,32 @@ _CCCL_HOST_API inline void __graphAddDependencies(
256258
::CUgraph __graph,
257259
const ::CUgraphNode* __from,
258260
const ::CUgraphNode* __to,
259-
::cuda::std::size_t __ndeps,
260-
const ::CUgraphEdgeData* __edge_data = nullptr)
261+
::cuda::std::size_t __ndeps)
261262
{
263+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddDependencies, 10, 0);
264+
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to add graph dependencies", __graph, __from, __to, __ndeps);
265+
}
266+
262267
# if _CCCL_CTK_AT_LEAST(12, 3)
268+
_CCCL_HOST_API inline void __graphAddDependencies(
269+
::CUgraph __graph,
270+
const ::CUgraphNode* __from,
271+
const ::CUgraphNode* __to,
272+
::cuda::std::size_t __ndeps,
273+
const ::CUgraphEdgeData* __edge_data)
274+
{
263275
static auto __driver_fn =
264-
_CUDAX_GET_DRIVER_FUNCTION_VERSIONED(cuGraphAddDependencies, cuGraphAddDependencies_v2, 12, 3);
276+
_CUDAX_GET_DRIVER_FUNCTION(cuGraphAddDependencies, 12, 3);
265277
::cuda::__driver::__call_driver_fn(
266278
__driver_fn, "Failed to add graph dependencies", __graph, __from, __to, __edge_data, __ndeps);
267-
# else
268-
_CCCL_ASSERT(__edge_data == nullptr, "Edge data requires CUDA Toolkit 12.3 or later");
269-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphAddDependencies);
270-
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to add graph dependencies", __graph, __from, __to, __ndeps);
271-
# endif
272279
}
280+
# endif // _CCCL_CTK_AT_LEAST(12, 3)
273281

274282
// ── Graph node: get type ────────────────────────────────────────────────────
275283

276284
[[nodiscard]] _CCCL_HOST_API inline ::CUgraphNodeType __graphNodeGetType(::CUgraphNode __node)
277285
{
278-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphNodeGetType);
286+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuGraphNodeGetType, 10, 0);
279287
::CUgraphNodeType __type{};
280288
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to get graph node type", __node, &__type);
281289
return __type;
@@ -292,7 +300,8 @@ _CCCL_HOST_API inline void __streamBeginCaptureToGraph(
292300
::cuda::std::size_t __ndeps,
293301
::CUstreamCaptureMode __mode)
294302
{
295-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuStreamBeginCaptureToGraph);
303+
static auto __driver_fn =
304+
_CUDAX_GET_DRIVER_FUNCTION(cuStreamBeginCaptureToGraph, 12, 3);
296305
::cuda::__driver::__call_driver_fn(
297306
__driver_fn, "Failed to begin stream capture to graph", __stream, __graph, __deps, nullptr, __ndeps, __mode);
298307
}
@@ -313,7 +322,7 @@ __streamGetCaptureInfo(::CUstream __stream, const ::CUgraphEdgeData** __edge_dat
313322
__stream_capture_info __info{};
314323
# if _CCCL_CTK_AT_LEAST(12, 4)
315324
static auto __driver_fn =
316-
_CUDAX_GET_DRIVER_FUNCTION_VERSIONED(cuStreamGetCaptureInfo, cuStreamGetCaptureInfo_v3, 12, 4);
325+
_CUDAX_GET_DRIVER_FUNCTION(cuStreamGetCaptureInfo, 12, 3);
317326
::cuda::__driver::__call_driver_fn(
318327
__driver_fn,
319328
"Failed to get stream capture info",
@@ -327,7 +336,8 @@ __streamGetCaptureInfo(::CUstream __stream, const ::CUgraphEdgeData** __edge_dat
327336
# else
328337
_CCCL_ASSERT(__edge_data_out == nullptr, "Edge data requires CUDA Toolkit 12.4 or later");
329338
__info.__edge_data = nullptr;
330-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuStreamGetCaptureInfo);
339+
static auto __driver_fn =
340+
_CUDAX_GET_DRIVER_FUNCTION(cuStreamGetCaptureInfo, 11, 3);
331341
::cuda::__driver::__call_driver_fn(
332342
__driver_fn,
333343
"Failed to get stream capture info",
@@ -345,15 +355,14 @@ __streamGetCaptureInfo(::CUstream __stream, const ::CUgraphEdgeData** __edge_dat
345355

346356
_CCCL_HOST_API inline void __streamEndCapture(::CUstream __stream, ::CUgraph* __graph_out)
347357
{
348-
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuStreamEndCapture);
358+
static auto __driver_fn = _CUDAX_GET_DRIVER_FUNCTION(cuStreamEndCapture, 10, 0);
349359
::cuda::__driver::__call_driver_fn(__driver_fn, "Failed to end stream capture", __stream, __graph_out);
350360
}
351361

352362
# endif // _CCCL_CTK_AT_LEAST(12, 3)
353363
} // namespace cuda::experimental::__driver
354364

355365
# undef _CUDAX_GET_DRIVER_FUNCTION
356-
# undef _CUDAX_GET_DRIVER_FUNCTION_VERSIONED
357366

358367
# include <cuda/std/__cccl/epilogue.h>
359368

libcudacxx/include/cuda/__driver/driver_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ _CCCL_SUPPRESS_DEPRECATED_POP
9595
//! @brief Makes the driver version from major and minor version.
9696
[[nodiscard]] _CCCL_HOST_API constexpr int __make_version(int __major, int __minor) noexcept
9797
{
98-
_CCCL_ASSERT(__major >= 12, "invalid major CUDA Driver version");
98+
_CCCL_ASSERT(__major >= 2, "invalid major CUDA Driver version");
9999
_CCCL_ASSERT(__minor >= 0 && __minor < 100, "invalid minor CUDA Driver version");
100100
return __major * 1000 + __minor * 10;
101101
}

0 commit comments

Comments
 (0)