@@ -20,22 +20,22 @@ for (TA, tA) in [(:CuVecOrMat, 'N'), (:CTranspose, 'T')]
2020 (:TropicalMinPlusF32 , :Cfloat , :FLOAT_minplus , :lib_TropicalMinPlus_FP32 ), (:TropicalMinPlusF64 , :Cdouble , :DOUBLE_minplus , :lib_TropicalMinPlus_FP64 ),
2121 (:TropicalMaxMulF32 , :Cfloat , :FLOAT_maxmul , :lib_TropicalMaxMul_FP32 ), (:TropicalMaxMulF64 , :Cdouble , :DOUBLE_maxmul , :lib_TropicalMaxMul_FP64 ), (:TropicalMaxMulI32 , :Cint , :INT_maxmul , :lib_TropicalMaxMul_INT32 ), (:TropicalMaxMulI64 , :Clong , :LONG_maxmul , :lib_TropicalMaxMul_INT64 )
2222 ]
23- @eval function matmul! (C:: CuVecOrMat{T} , A:: $TA{T} , B:: $TB{T} , α:: T , β:: T ) where {T<: $TT }
23+ @eval function matmul! (C:: CuVecOrMat{T} , A:: $TA{T} , B:: $TB{T} , α:: T , β:: T , stream :: CuStream = stream () ) where {T<: $TT }
2424 M, N, K = dims_match (A, B, C)
2525 if K == 0 && M * N != 0
2626 return rmul! (C, β)
2727 elseif M * N == 0
2828 return C
2929 else
30- @ccall $ lib.$ funcname (M:: Cint , N:: Cint , K:: Cint , pointer (parent (A)):: CuPtr{$CT} , pointer (parent (B)):: CuPtr{$CT} , pointer (C):: CuPtr{$CT} , content (α):: $CT , content (β):: $CT , $ tA:: Cchar , $ tB:: Cchar ):: Cvoid
30+ @ccall $ lib.$ funcname (M:: Cint , N:: Cint , K:: Cint , pointer (parent (A)):: CuPtr{$CT} , pointer (parent (B)):: CuPtr{$CT} , pointer (C):: CuPtr{$CT} , content (α):: $CT , content (β):: $CT , $ tA:: Cchar , $ tB:: Cchar , stream :: CUDA.CUstream ):: Cvoid
3131 end
3232 return C
3333 end
3434 end
3535 end
3636end
3737
38- const CuTropicalBlasTypes = Union{TropicalAndOr, TropicalMaxPlusF32, TropicalMaxPlusF64, TropicalMaxMulF32, TropicalMaxMulF64, TropicalMaxMulI32, TropicalMaxMulI64}
38+ const CuTropicalBlasTypes = Union{TropicalAndOr, TropicalMaxPlusF32, TropicalMaxPlusF64, TropicalMinPlusF32, TropicalMinPlusF64, TropicalMaxMulF32, TropicalMaxMulF64, TropicalMaxMulI32, TropicalMaxMulI64}
3939
4040# overload the LinearAlgebra.mul!
4141for TA in [:CuVecOrMat , :CTranspose ]
0 commit comments