Skip to content

Commit ee26ad8

Browse files
authored
Add support of CuStream control (#27)
* add stream control * reexport TropicalNumbers * update deps
1 parent dfa895f commit ee26ad8

File tree

6 files changed

+21
-18
lines changed

6 files changed

+21
-18
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
name = "CuTropicalGEMM"
22
uuid = "c2b282c3-c9c2-431d-80f7-a1a0561ebe55"
3-
authors = ["Xuanzhao Gao <[email protected]> and contributors"]
4-
version = "0.1.1"
3+
authors = ["Xuanzhao Gao <[email protected]> and Jin-Guo Liu"]
4+
version = "0.1.2"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
910
TropicalGemmC_jll = "4f4992fb-2984-5eba-87b8-475305d0f5fc"
1011
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
1112

1213
[compat]
1314
CUDA = "5"
14-
TropicalGemmC_jll = "0.1.1"
15+
TropicalGemmC_jll = "0.1.3"
1516
TropicalNumbers = "0.6.2"
17+
Reexport = "1.2.2"
1618
julia = "1"
1719

1820
[extras]

src/CuTropicalGEMM.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
module CuTropicalGEMM
22

3-
using CUDA, TropicalNumbers, LinearAlgebra, TropicalGemmC_jll
3+
using CUDA, LinearAlgebra
4+
using TropicalGemmC_jll
5+
using Reexport
6+
@reexport using TropicalNumbers
7+
48
export matmul!
59

610
function __init__()
711
if CUDA.functional() == true
812
if CUDA.driver_version() < v"11.4"
913
@warn "CUDA.driver_version < v11.4! CuTropicalGEMM may not be available."
10-
elseif CUDA.driver_version() > v"12.2"
11-
@warn "CUDA.driver_version > v12.2! CuTropicalGEMM may not be available."
14+
elseif CUDA.driver_version() > v"12.3"
15+
@warn "CUDA.driver_version > v12.3! CuTropicalGEMM may not be available."
1216
end
1317
elseif CUDA.functional() == false
1418
@warn "CUDA Driver not found! CuTropicalGEMM will not be available."

src/tropical_gemms.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,22 @@ for (TA, tA) in [(:CuVecOrMat, 'N'), (:CTranspose, 'T')]
2020
(:TropicalMinPlusF32, :Cfloat, :FLOAT_minplus, :lib_TropicalMinPlus_FP32), (:TropicalMinPlusF64, :Cdouble, :DOUBLE_minplus, :lib_TropicalMinPlus_FP64),
2121
(:TropicalMaxMulF32, :Cfloat, :FLOAT_maxmul, :lib_TropicalMaxMul_FP32), (:TropicalMaxMulF64, :Cdouble, :DOUBLE_maxmul, :lib_TropicalMaxMul_FP64), (:TropicalMaxMulI32, :Cint, :INT_maxmul, :lib_TropicalMaxMul_INT32), (:TropicalMaxMulI64, :Clong, :LONG_maxmul, :lib_TropicalMaxMul_INT64)
2222
]
23-
@eval function matmul!(C::CuVecOrMat{T}, A::$TA{T}, B::$TB{T}, α::T, β::T) where {T<:$TT}
23+
@eval function matmul!(C::CuVecOrMat{T}, A::$TA{T}, B::$TB{T}, α::T, β::T, stream::CuStream = stream()) where {T<:$TT}
2424
M, N, K = dims_match(A, B, C)
2525
if K == 0 && M * N != 0
2626
return rmul!(C, β)
2727
elseif M * N == 0
2828
return C
2929
else
30-
@ccall $lib.$funcname(M::Cint, N::Cint, K::Cint, pointer(parent(A))::CuPtr{$CT}, pointer(parent(B))::CuPtr{$CT}, pointer(C)::CuPtr{$CT}, content(α)::$CT, content(β)::$CT, $tA::Cchar, $tB::Cchar)::Cvoid
30+
@ccall $lib.$funcname(M::Cint, N::Cint, K::Cint, pointer(parent(A))::CuPtr{$CT}, pointer(parent(B))::CuPtr{$CT}, pointer(C)::CuPtr{$CT}, content(α)::$CT, content(β)::$CT, $tA::Cchar, $tB::Cchar, stream::CUDA.CUstream)::Cvoid
3131
end
3232
return C
3333
end
3434
end
3535
end
3636
end
3737

38-
const CuTropicalBlasTypes = Union{TropicalAndOr, TropicalMaxPlusF32, TropicalMaxPlusF64, TropicalMaxMulF32, TropicalMaxMulF64, TropicalMaxMulI32, TropicalMaxMulI64}
38+
const CuTropicalBlasTypes = Union{TropicalAndOr, TropicalMaxPlusF32, TropicalMaxPlusF64, TropicalMinPlusF32, TropicalMinPlusF64, TropicalMaxMulF32, TropicalMaxMulF64, TropicalMaxMulI32, TropicalMaxMulI64}
3939

4040
# overload the LinearAlgebra.mul!
4141
for TA in [:CuVecOrMat, :CTranspose]

test/Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
33
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5-
TropicalGemmC_jll = "4f4992fb-2984-5eba-87b8-475305d0f5fc"
6-
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
4+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using CuTropicalGEMM
22
using Test
33
using CUDA
4-
using TropicalNumbers
54
using LinearAlgebra
65

76
@testset "CuTropicalGEMM.jl" begin

test/tropical_gemms.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
hB = Array(B)
3131
hC = Array(C)
3232

33-
C = CuTropicalGEMM.matmul!(C, A, B, α, β)
33+
CUDA.@sync C = CuTropicalGEMM.matmul!(C, A, B, α, β)
3434

3535
hC .= α .* hA * hB .+ β .* hC
3636

@@ -55,8 +55,8 @@ end
5555
@testset "$testname" begin
5656
if !(size(A) == (1,4) && size(B) == (4,))
5757
res0 = Array(A) * Array(B)
58-
res1 = A * B
59-
res2 = LinearAlgebra.mul!(MT.(CUDA.zeros(T, size(res0)...)), A, B)
58+
CUDA.@sync res1 = A * B
59+
CUDA.@sync res2 = LinearAlgebra.mul!(MT.(CUDA.zeros(T, size(res0)...)), A, B)
6060
@test Array(res1) res0
6161
@test Array(res2) res0
6262
end
@@ -79,8 +79,8 @@ end
7979
for B in [transpose(a), a, b]
8080
if !(size(A) == (1,4) && size(B) == (4,))
8181
res0 = Array(A) * Array(B)
82-
res1 = A * B
83-
res2 = LinearAlgebra.mul!(MT.(CUDA.zeros(T, size(res0)...)), A, B, true, false)
82+
CUDA.@sync res1 = A * B
83+
CUDA.@sync res2 = LinearAlgebra.mul!(MT.(CUDA.zeros(T, size(res0)...)), A, B, true, false)
8484
@test Array(res1) res0
8585
@test Array(res2) res0
8686
end

0 commit comments

Comments
 (0)