Context
I am working on Magma.jl, which provides Julia bindings for MAGMA, a GPU-accelerated linear algebra library. The bindings are auto-generated using Clang.jl.
MAGMA provides batched operations that perform the same operation on multiple matrices simultaneously. For example, magma_sgemm_batched computes C[i] = α*A[i]*B[i] + β*C[i] for all matrices in a batch. These functions accept arrays of device pointers (float const * const * in C), which Clang.jl maps to Ptr{Ptr{Cfloat}}:
# Auto-generated binding
function magma_sgemm_batched(transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, queue)
    ccall((:magma_sgemm_batched, libmagma), Cvoid,
          (magma_trans_t, magma_trans_t, magma_int_t, magma_int_t, magma_int_t,
           Cfloat, Ptr{Ptr{Cfloat}}, magma_int_t,
           Ptr{Ptr{Cfloat}}, magma_int_t,
           Cfloat, Ptr{Ptr{Cfloat}}, magma_int_t,
           magma_int_t, magma_queue_t),
          transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, queue)
end
To call this function, I use CUDA.CUBLAS.unsafe_strided_batch to obtain a CuVector{CuPtr{Float32}} containing device pointers to each matrix slice:
As = CUDA.rand(Float32, 32, 32, 1024)
Bs = CUDA.rand(Float32, 32, 32, 1024)
Cs = CUDA.CuArray{Float32}(undef, 32, 32, 1024)
dA_array = CUDA.CUBLAS.unsafe_strided_batch(As) # CuVector{CuPtr{Float32}}
dB_array = CUDA.CUBLAS.unsafe_strided_batch(Bs)
dC_array = CUDA.CUBLAS.unsafe_strided_batch(Cs)
magma_sgemm_batched(..., dA_array, ..., dB_array, ..., dC_array, ...)
# ERROR: MethodError: Cannot `convert` an object of type CuArray{CuPtr{Float32}, 1, ...} to Ptr{Ptr{Float32}}
Problem
There is no conversion path from CuArray{CuPtr{T}} to Ptr{Ptr{T}} for ccall, so the call fails with:
ERROR: ArgumentError: Illegal conversion of a CUDA.DeviceMemory to a Ptr{CuPtr{Float32}}
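For contrast, the same conversion succeeds on the host. The sketch below builds a host-side analogue of `unsafe_strided_batch` (a plain `Vector{Ptr{Float32}}` of per-slice pointers; the helper names are hypothetical and this is not CUDA.jl's implementation) and then runs the two steps ccall applies to an argument declared as `Ptr{Ptr{Cfloat}}`: `Base.cconvert`, then `Base.unsafe_convert`. Base provides the second step for host `Array`s, whereas for a `CuArray` CUDA.jl rejects the conversion to a host `Ptr`, as the error above shows.

```julia
# Hypothetical host-side analogue of CUDA.CUBLAS.unsafe_strided_batch: one pointer
# to the first element of each slice A[:, :, i] of a column-major 3-D Array.
function strided_slice_pointers(A::Array{T,3}) where {T}
    stride = size(A, 1) * size(A, 2)   # number of elements in one matrix slice
    return Ptr{T}[pointer(A, (i - 1) * stride + 1) for i in 1:size(A, 3)]
end

# The two steps ccall performs for an argument declared as Ptr{Ptr{T}}:
# cconvert produces an object that ccall keeps rooted during the call, and
# unsafe_convert turns that object into the raw address handed to C.
function ccall_style_convert(ptrs::Vector{Ptr{T}}) where {T}
    obj = Base.cconvert(Ptr{Ptr{T}}, ptrs)
    return Base.unsafe_convert(Ptr{Ptr{T}}, obj)
end

A = rand(Float32, 32, 32, 4)
ptrs = strided_slice_pointers(A)
GC.@preserve A ptrs begin
    p = ccall_style_convert(ptrs)                        # Ptr{Ptr{Float32}} to the pointer array
    println(unsafe_load(p, 2) == pointer(A, 32 * 32 + 1)) # prints true
end
```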
The call works if I manually change the generated bindings to use CuPtr{CuPtr{T}}, but that requires post-processing the generated code, which could get complicated: how would the generator know which pointers should be CuPtrs?
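One possible way to answer the "which pointers" question is to key on argument names rather than C types, since a device pointer and a host pointer are both plain `float*` to Clang. MAGMA conventionally prefixes device-resident arguments with `d` (dA, dB_array, dx, dwork), so a post-pass could rewrite only those. This is a sketch under that assumption: the helper names are hypothetical (not a Clang.jl API), it works on type strings only to stay self-contained, and any argument that breaks the convention would still need a manual override.

```julia
# Hypothetical post-processing helpers for generated bindings (not part of Clang.jl).
# Assumption: device-resident arguments are named with a leading `d` followed by a
# letter (dA, dB_array, dx, dwork); host arguments are not (A, ldda, alpha, d).
is_device_arg(name::AbstractString) = occursin(r"^d[A-Za-z]", name)

# For device arguments, rewrite every `Ptr{` level to `CuPtr{`, producing the
# CuPtr{CuPtr{T}} form that works for the batched routines.
function rewrite_arg_type(name::AbstractString, jltype::AbstractString)
    is_device_arg(name) || return String(jltype)
    occursin("CuPtr", jltype) && return String(jltype)   # already rewritten
    return replace(jltype, "Ptr{" => "CuPtr{")
end

# Apply the rule to a signature given as (argument name, type string) pairs.
rewrite_signature(sig) = [(n, rewrite_arg_type(n, t)) for (n, t) in sig]

sig = [("alpha", "Cfloat"), ("dA_array", "Ptr{Ptr{Cfloat}}"), ("ldda", "magma_int_t")]
println(rewrite_signature(sig))
```

In a real generator pass the same rule would be applied to the generated argument-type expressions instead of strings.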
Question
Are there any suggestions on how this should be handled? One option I considered is defining the following conversion in Magma.jl:
Base.unsafe_convert(::Type{Ptr{Ptr{T}}}, x::CuArray{CuPtr{T}}) where T =
    Ptr{Ptr{T}}(pointer(x))
(There are probably other places where different conversions might be needed, but I'm only interested in the batched methods.)
However, I'm unsure whether this type piracy would cause problems, or whether there is a better pattern for this use case. I would appreciate any guidance on the preferred approach.
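A non-piratical variant of the same idea is sketched below, assuming Magma.jl defines its own small wrapper type (`BatchPointers` is a hypothetical name, not an existing Magma.jl or CUDA.jl type). Because the package would own the type, adding an `unsafe_convert` method for it is not piracy, and the generated `Ptr{Ptr{Cfloat}}` signatures stay untouched. It is exercised here with host pointers only; that `reinterpret(UInt, pointer(x))` yields the device address for a `CuVector{CuPtr{T}}` is an assumption.

```julia
# Hypothetical wrapper owned by Magma.jl (illustrative name only).
# T is the matrix element type; V is the vector holding one pointer per matrix
# (on the GPU this would be the CuVector{CuPtr{T}} from unsafe_strided_batch).
struct BatchPointers{T,V<:AbstractVector}
    ptrs::V
end
BatchPointers{T}(ptrs::V) where {T,V<:AbstractVector} = BatchPointers{T,V}(ptrs)

# ccall first calls Base.cconvert; for Ptr argument types the generic method returns
# the wrapper itself, so ccall keeps it (and therefore `ptrs`) rooted for the call.
# ccall then calls this method to obtain the raw address passed to C.
function Base.unsafe_convert(::Type{Ptr{Ptr{T}}}, b::BatchPointers{T}) where {T}
    # Pass the address of the pointer array through unchanged. Assumption: for a
    # CuVector, pointer() returns a CuPtr whose bits are the device address MAGMA
    # expects; Julia never dereferences it on the host.
    return Ptr{Ptr{T}}(reinterpret(UInt, pointer(b.ptrs)))
end
```

At the call site this would read `magma_sgemm_batched(..., BatchPointers{Float32}(dA_array), ...)`. The trade-off is that the wrapper has to be applied explicitly at every batched call.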