Skip to content

Commit b67cee6

Browse files
committed
Support Nvidia Hopper GPUs
1 parent d046063 commit b67cee6

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ CUDAExt = "CUDA"
2828
CairoMakieExt = "CairoMakie"
2929

3030
[compat]
31-
CUDA = "3.8.4, 3.12, 4.4"
31+
CUDA = "3.8.4, 3.12, 4.4, 5"
3232
CairoMakie = "0.7, 0.10.7"
3333
CpuId = "0.3"
3434
DocStringExtensions = "0.9"

ext/CUDAExt/implementations/peakflops_gpu.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function _theoretical_peakflops_gpu_cudacores(; device, dtype)
5151
elseif dtype == Float64
5252
max_peakflops *= 1
5353
else
54-
throw(ArgumentError("Unsupported dtype."))
54+
throw(ArgumentError("Unsupported dtype $(dtype)."))
5555
end
5656
return max_peakflops
5757
end
@@ -60,7 +60,9 @@ function _theoretical_peakflops_gpu_tensorcores(;
6060
device=CUDA.device(), dtype=Float16, verbose=true
6161
)
6262
cap = CUDA.capability(device)
63-
if cap == v"8.0.0"
63+
if cap == v"9.0.0"
64+
devtype = :Hopper
65+
elseif cap == v"8.0.0"
6466
devtype = :A100
6567
elseif cap == v"7.0.0"
6668
devtype = :V100
@@ -70,10 +72,26 @@ function _theoretical_peakflops_gpu_tensorcores(;
7072
max_clock_rate = CUDA.attribute(device, CUDA.CU_DEVICE_ATTRIBUTE_CLOCK_RATE) # in kHz
7173
num_tensor_cores = ntensorcores(device)
7274
max_peakflops = max_clock_rate * num_tensor_cores * 1e-9 # in TFLOP/s
73-
if devtype == :A100
75+
if devtype == :Hopper
76+
# matrix dimensions 8x8x4, factor 2 for nflops in A*B+C see
77+
# * <https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper> (figures 10-11)
78+
# * <https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/> (figures 5-8)
7479
if Symbol(dtype) == :Float16
75-
# matrix dimensions 8x8x4, factor 2 for nflops in A*B+C
76-
# see e.g. https://peerj.com/articles/cs-330.pdf
80+
max_peakflops *= 2 * 16 * 8 * 4 # XXX: Wrong result!
81+
elseif Symbol(dtype) in (:Float32, :TensorFloat32, :TF32)
82+
max_peakflops *= 2 * 8 * 8 * 4 # XXX: Wrong result!
83+
elseif Symbol(dtype) == :Float64
84+
max_peakflops *= 2 * 4 * 4 * 2
85+
elseif Symbol(dtype) == :Int8
86+
max_peakflops *= 2 * 2 * 32 * 8 * 4 # XXX: Wrong result!
87+
else
88+
throw(ArgumentError("Unsupported dtype $(dtype)."))
89+
end
90+
elseif devtype == :A100
91+
if Symbol(dtype) == :Float16
92+
# matrix dimensions 8x8x4, factor 2 for nflops in A*B+C see
93+
# e.g. <https://doi.org/10.7717/peerj-cs.330> or
94+
# <https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/nvidia-ampere-architecture-whitepaper.pdf>
7795
max_peakflops *= 2 * 8 * 8 * 4
7896
elseif Symbol(dtype) in (:Float32, :TensorFloat32, :TF32)
7997
max_peakflops *= 2 * 4 * 8 * 4
@@ -82,13 +100,13 @@ function _theoretical_peakflops_gpu_tensorcores(;
82100
elseif Symbol(dtype) == :Int8
83101
max_peakflops *= 2 * 2 * 8 * 8 * 4
84102
else
85-
throw(ArgumentError("Unsupported dtype."))
103+
throw(ArgumentError("Unsupported dtype $(dtype)."))
86104
end
87105
elseif devtype == :V100
88106
if Symbol(dtype) == :Float16
89107
max_peakflops *= 2 * 4 * 4 * 4
90108
else
91-
throw(ArgumentError("Unsupported dtype."))
109+
throw(ArgumentError("Unsupported dtype $(dtype)."))
92110
end
93111
end
94112
return max_peakflops

ext/CUDAExt/peakflops_gpu_wmmas.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function _peakflops_gpu_wmmas(;
147147
dtype_a = dtype_b = BFloat16
148148
dtype_c = dtype_d = Float32
149149
else
150-
throw(ArgumentError("Unsupported dtype."))
150+
throw(ArgumentError("Unsupported dtype $(dtype)."))
151151
end
152152
d_a = CUDA.rand(dtype_a, m, k)
153153
d_b = CUDA.rand(dtype_b, k, n)
@@ -165,7 +165,7 @@ function _peakflops_gpu_wmmas(;
165165
elseif Symbol(dtype) in (:BFloat16, :BF16)
166166
kernel = @cuda launch = false _kernel_wmma_bf16_lowlevel(d_a, d_b, d_c, d_d)
167167
else
168-
throw(ArgumentError("Unsupported dtype."))
168+
throw(ArgumentError("Unsupported dtype $(dtype)."))
169169
end
170170
warpsize = CUDA.attribute(device, CUDA.CU_DEVICE_ATTRIBUTE_WARP_SIZE)
171171
# @show threads

0 commit comments

Comments (0)