Skip to content

Commit 7268fa9

Browse files
CUTLASS example added, license headers added, fixes
- Add license header to each example file. - Fixed broken runs caused by type declarations. - Fixed hang in throughput.py when --run-once by doing a manual warm-up step, like in auto_throughput.py
1 parent ead9965 commit 7268fa9

File tree

10 files changed

+226
-3
lines changed

10 files changed

+226
-3
lines changed

python/examples/auto_throughput.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def as_cuda_Stream(cs: nvbench.CudaStream) -> cuda.cudadrv.driver.Stream:
2525
return cuda.external_stream(cs.addressof())
2626

2727

28-
def make_kernel(items_per_thread: int) -> cuda.compiler.AutoJitCUDAKernel:
28+
def make_kernel(items_per_thread: int) -> cuda.dispatcher.CUDADispatcher:
2929
@cuda.jit
3030
def kernel(stride: np.uintp, elements: np.uintp, in_arr, out_arr):
3131
tid = cuda.grid(1)
@@ -59,7 +59,8 @@ def throughput_bench(state: nvbench.State) -> None:
5959
krn = make_kernel(ipt)
6060

6161
# warm-up call ensures that kernel is loaded into context
62-
# before blocking kernel is launched
62+
# before blocking kernel is launched. Kernel loading may cause
63+
# a synchronization to occur.
6364
krn[blocks_in_grid, threads_per_block, alloc_stream, 0](
6465
stride, elements, inp_arr, out_arr
6566
)

python/examples/axes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Copyright 2025 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 with the LLVM exception
4+
# (the "License"); you may not use this file except in compliance with
5+
# the License.
6+
#
7+
# You may obtain a copy of the License at
8+
#
9+
# http://llvm.org/foundation/relicensing/LICENSE.txt
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
117
import ctypes
218
import sys
319
from typing import Dict, Optional, Tuple

python/examples/cccl_parallel_segmented_reduce.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Copyright 2025 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 with the LLVM exception
4+
# (the "License"); you may not use this file except in compliance with
5+
# the License.
6+
#
7+
# You may obtain a copy of the License at
8+
#
9+
# http://llvm.org/foundation/relicensing/LICENSE.txt
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
117
import sys
218

319
import cuda.cccl.parallel.experimental.algorithms as algorithms

python/examples/cpu_activity.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Copyright 2025 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 with the LLVM exception
4+
# (the "License"); you may not use this file except in compliance with
5+
# the License.
6+
#
7+
# You may obtain a copy of the License at
8+
#
9+
# http://llvm.org/foundation/relicensing/LICENSE.txt
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
117
import sys
218
import time
319

python/examples/cupy_extract.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Copyright 2025 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 with the LLVM exception
4+
# (the "License"); you may not use this file except in compliance with
5+
# the License.
6+
#
7+
# You may obtain a copy of the License at
8+
#
9+
# http://llvm.org/foundation/relicensing/LICENSE.txt
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
117
import sys
218

319
import cuda.nvbench as nvbench

python/examples/cutlass_gemm.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2025 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 with the LLVM exception
4+
# (the "License"); you may not use this file except in compliance with
5+
# the License.
6+
#
7+
# You may obtain a copy of the License at
8+
#
9+
# http://llvm.org/foundation/relicensing/LICENSE.txt
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
18+
import sys
19+
20+
import cuda.bindings.driver as driver
21+
import cuda.core.experimental as core
22+
import cupy as cp
23+
import cutlass
24+
import numpy as np
25+
26+
import nvbench
27+
28+
29+
def as_bindings_Stream(cs: nvbench.CudaStream) -> driver.CUstream:
30+
return driver.CUstream(cs.addressof())
31+
32+
33+
def as_core_Stream(cs: nvbench.CudaStream) -> core.Stream:
34+
return core.Stream.from_handle(cs.addressof())
35+
36+
37+
def make_cp_array(arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int) -> cp.ndarray:
38+
cp_memview = cp.cuda.UnownedMemory(
39+
int(dev_buf.handle), dev_buf.size, dev_buf, dev_id
40+
)
41+
zero_offset = 0
42+
return cp.ndarray(
43+
arr_h.shape,
44+
dtype=arr_h.dtype,
45+
memptr=cp.cuda.MemoryPointer(cp_memview, zero_offset),
46+
)
47+
48+
49+
def cutlass_gemm(state: nvbench.State) -> None:
50+
n = state.get_int64("N")
51+
r = state.get_int64("R")
52+
53+
alpha = state.get_float64("alpha")
54+
55+
dt = np.float64
56+
A_h = np.random.randn(n, r).astype(dt)
57+
B_h = np.copy(A_h.mT)
58+
C_h = np.eye(n, dtype=dt)
59+
D_h = np.zeros_like(C_h)
60+
61+
if n >= 1024:
62+
# allow more time for large inputs
63+
state.set_timeout(360)
64+
65+
dev_id = state.get_device()
66+
cs = state.get_stream()
67+
s = as_bindings_Stream(cs)
68+
core_s = as_core_Stream(cs)
69+
70+
A_d = core.DeviceMemoryResource(dev_id).allocate(A_h.nbytes, core_s)
71+
B_d = core.DeviceMemoryResource(dev_id).allocate(B_h.nbytes, core_s)
72+
C_d = core.DeviceMemoryResource(dev_id).allocate(C_h.nbytes, core_s)
73+
D_d = core.DeviceMemoryResource(dev_id).allocate(D_h.nbytes, core_s)
74+
75+
driver.cuMemcpyAsync(A_d.handle, A_h.ctypes.data, A_h.nbytes, s)
76+
driver.cuMemcpyAsync(B_d.handle, B_h.ctypes.data, B_h.nbytes, s)
77+
driver.cuMemcpyAsync(C_d.handle, C_h.ctypes.data, C_h.nbytes, s)
78+
driver.cuMemcpyAsync(D_d.handle, D_h.ctypes.data, D_h.nbytes, s)
79+
80+
A_cp = make_cp_array(A_h, A_d, dev_id)
81+
B_cp = make_cp_array(B_h, B_d, dev_id)
82+
C_cp = make_cp_array(C_h, C_d, dev_id)
83+
D_cp = make_cp_array(D_h, D_d, dev_id)
84+
85+
plan = cutlass.op.Gemm(
86+
A=A_cp,
87+
B=B_cp,
88+
C=C_cp,
89+
D=D_cp,
90+
element=dt,
91+
alpha=alpha,
92+
beta=1,
93+
layout=cutlass.LayoutType.RowMajor,
94+
)
95+
# warm-up to ensure compilation is not timed
96+
plan.run(stream=s)
97+
98+
def launcher(launch: nvbench.Launch) -> None:
99+
s = as_bindings_Stream(launch.get_stream())
100+
plan.run(stream=s, sync=False)
101+
102+
state.exec(launcher)
103+
104+
105+
if __name__ == "__main__":
106+
gemm_b = nvbench.register(cutlass_gemm)
107+
gemm_b.add_int64_axis("R", [16, 64, 256])
108+
gemm_b.add_int64_axis("N", [256, 512, 1024, 2048])
109+
110+
gemm_b.add_float64_axis("alpha", [1e-2])
111+
112+
nvbench.run_all_benchmarks(sys.argv)

python/examples/exec_tag_sync.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Copyright 2025 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 with the LLVM exception
4+
# (the "License"); you may not use this file except in compliance with
5+
# the License.
6+
#
7+
# You may obtain a copy of the License at
8+
#
9+
# http://llvm.org/foundation/relicensing/LICENSE.txt
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
117
import ctypes
218
import sys
319
from typing import Optional

python/examples/requirements.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
numpy
2+
numba
3+
cupy
4+
nvidia-cutlass
5+
cuda-cccl
6+
cuda-core
7+
cuda-bindings

python/examples/skip.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Copyright 2025 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 with the LLVM exception
4+
# (the "License"); you may not use this file except in compliance with
5+
# the License.
6+
#
7+
# You may obtain a copy of the License at
8+
#
9+
# http://llvm.org/foundation/relicensing/LICENSE.txt
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
117
import sys
218

319
import cuda.cccl.headers as headers

python/examples/throughput.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def as_cuda_Stream(cs: nvbench.CudaStream) -> cuda.cudadrv.driver.Stream:
2525
return cuda.external_stream(cs.addressof())
2626

2727

28-
def make_kernel(items_per_thread: int) -> cuda.compiler.AutoJitCUDAKernel:
28+
def make_kernel(items_per_thread: int) -> cuda.dispatcher.CUDADispatcher:
2929
@cuda.jit
3030
def kernel(stride: np.uintp, elements: np.uintp, in_arr, out_arr):
3131
tid = cuda.grid(1)
@@ -59,6 +59,13 @@ def throughput_bench(state: nvbench.State) -> None:
5959

6060
krn = make_kernel(ipt)
6161

62+
# warm-up call ensures that kernel is loaded into context
63+
# before blocking kernel is launched. Kernel loading may
64+
# cause synchronization to occur.
65+
krn[blocks_in_grid, threads_per_block, alloc_stream, 0](
66+
stride, elements, inp_arr, out_arr
67+
)
68+
6269
def launcher(launch: nvbench.Launch):
6370
exec_stream = as_cuda_Stream(launch.get_stream())
6471
krn[blocks_in_grid, threads_per_block, exec_stream, 0](

0 commit comments

Comments
 (0)