Skip to content

Commit 1c10dfe

Browse files
committed
add GPU package import
1 parent 5c1d5cb commit 1c10dfe

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

src/CellArray.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,19 @@ See also: [`CellArray`](@ref), [`CPUCellArray`](@ref), [`ROCCellArray`](@ref)
145145
146146
See also: [`@define_ROCCellArray`](@ref)
147147
"""
148-
macro define_CuCellArray() esc(define_CuCellArray()) end
148+
macro define_CuCellArray()
    # Hand the invoking module (`__module__`) to the generator function, then escape
    # the expression it returns so that it is resolved in the caller's scope rather
    # than being rewritten by macro hygiene.
    generated = define_CuCellArray(__module__)
    return esc(generated)
end
149149

150-
function define_CuCellArray()
150+
function define_CuCellArray(caller::Module)
151+
@eval caller import CUDA # NOTE: this is required for CUDA.@cuda to work (at least when running the unit tests), which is needed for the GPUCompiler bug workaround
151152
quote
152153
const CuCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,CUDA.CuArray{T_elem,CellArrays._N,CUDA.DeviceMemory}}
153154

154-
CuCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = CuCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) @cuda launch=false launch=false f(A) end; A )
155+
CuCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = CuCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) CUDA.@cuda launch=false f(A) end; A )
155156
CuCellArray{T,B}(::UndefInitializer, dims::Vararg{Int, N}) where {T<:CellArrays.Cell,N,B} = CuCellArray{T,B}(undef, dims)
156157
CuCellArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N} = CuCellArray{T,CellArrays.B0}(undef, dims)
157158
CuCellArray{T}(::UndefInitializer, dims::Vararg{Int, N}) where {T<:CellArrays.Cell,N} = CuCellArray{T}(undef, dims)
158159

159-
CuCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = (A = CellArrays.CellArray{T,N,B}(CUDA.CuArray(A.data), A.dims); f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) @cuda launch=false f(A) end; A)
160+
CuCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = (A = CellArrays.CellArray{T,N,B}(CUDA.CuArray(A.data), A.dims); f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) CUDA.@cuda launch=false f(A) end; A)
160161

161162
Base.show(io::IO, A::CuCellArray) = Base.show(io, CellArrays.CPUCellArray(A))
162163
Base.show(io::IO, ::MIME"text/plain", A::CuCellArray{T,N,B}) where {T,N,B} = ( println(io, "$(length(A))-element CuCellArray{$T, $N, $B, $(CellArrays.eltype(T))}:"); Base.print_array(io, CellArrays.CPUCellArray(A)) )
@@ -188,19 +189,20 @@ See also: [`CellArray`](@ref), [`CPUCellArray`](@ref), [`CuCellArray`](@ref)
188189
189190
See also: [`@define_CuCellArray`](@ref)
190191
"""
191-
macro define_ROCCellArray() esc(define_ROCCellArray()) end
192+
macro define_ROCCellArray()
    # Hand the invoking module (`__module__`) to the generator function, then escape
    # the expression it returns so that it is resolved in the caller's scope rather
    # than being rewritten by macro hygiene.
    generated = define_ROCCellArray(__module__)
    return esc(generated)
end
192193

193-
function define_ROCCellArray()
194+
function define_ROCCellArray(caller::Module)
195+
@eval caller import AMDGPU # NOTE: this is required for AMDGPU.@roc to work (at least when running the unit tests), which is needed for the GPUCompiler bug workaround
194196
quote
195197
const ROCCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,AMDGPU.ROCArray{T_elem,CellArrays._N}} # TODO: ,AMDGPU.Runtime.Mem.HIPBuffer should be added here later. At the moment it has no impact (and would require adaptation of the unit tests).
196198
const ROCDeviceCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,AMDGPU.ROCDeviceArray{T_elem,CellArrays._N,AMDGPU.Runtime.Mem.HIPBuffer}}
197199

198-
ROCCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = ROCCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); A ) # TODO: Once reshape is implemented in AMDGPU, the workaround can be applied as well: f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) @roc launch=false f(A) end; A )
200+
ROCCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = ROCCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); A ) # TODO: Once reshape is implemented in AMDGPU, the workaround can be applied as well: f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) AMDGPU.@roc launch=false f(A) end; A )
199201
ROCCellArray{T,B}(::UndefInitializer, dims::Vararg{Int, N}) where {T<:CellArrays.Cell,N,B} = ROCCellArray{T,B}(undef, dims)
200202
ROCCellArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N} = ROCCellArray{T,CellArrays.B0}(undef, dims)
201203
ROCCellArray{T}(::UndefInitializer, dims::Vararg{Int, N}) where {T<:CellArrays.Cell,N} = ROCCellArray{T}(undef, dims)
202204

203-
ROCCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = ( A = CellArrays.CellArray{T,N,B}(AMDGPU.ROCArray(A.data), A.dims); A ) # TODO: Once reshape is implemented in AMDGPU, the workaround can be applied as well: f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) @roc launch=false f(A) end; A )
205+
ROCCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = ( A = CellArrays.CellArray{T,N,B}(AMDGPU.ROCArray(A.data), A.dims); A ) # TODO: Once reshape is implemented in AMDGPU, the workaround can be applied as well: f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) AMDGPU.@roc launch=false f(A) end; A )
204206

205207
Base.show(io::IO, A::ROCCellArray) = Base.show(io, CellArrays.CPUCellArray(A))
206208
Base.show(io::IO, ::MIME"text/plain", A::ROCCellArray{T,N,B}) where {T,N,B} = ( println(io, "$(length(A))-element ROCCellArray{$T, $N, $B, $(CellArrays.eltype(T))}:"); Base.print_array(io, CellArrays.CPUCellArray(A)) )
@@ -234,19 +236,20 @@ See also: [`CellArray`](@ref), [`CPUCellArray`](@ref), [`CuCellArray`](@ref), [`
234236
235237
See also: [`@define_CuCellArray`](@ref), [`@define_ROCCellArray`](@ref)
236238
"""
237-
macro define_MtlCellArray() esc(define_MtlCellArray()) end
239+
macro define_MtlCellArray()
    # Hand the invoking module (`__module__`) to the generator function, then escape
    # the expression it returns so that it is resolved in the caller's scope rather
    # than being rewritten by macro hygiene.
    generated = define_MtlCellArray(__module__)
    return esc(generated)
end
238240

239-
function define_MtlCellArray()
241+
function define_MtlCellArray(caller::Module)
242+
@eval caller import Metal # NOTE: this is required for Metal.@metal to work (at least when running the unit tests), which is needed for the GPUCompiler bug workaround
240243
quote
241244
const MtlCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,Metal.MtlArray{T_elem,CellArrays._N}}
242245
const MtlDeviceCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,Metal.MtlDeviceArray{T_elem,CellArrays._N}}
243246

244-
MtlCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = MtlCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); A) #workaround: f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) @metal launch=false f(A) end; A )
247+
MtlCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = MtlCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); A) #workaround: f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) Metal.@metal launch=false f(A) end; A )
245248
MtlCellArray{T,B}(::UndefInitializer, dims::Vararg{Int, N}) where {T<:CellArrays.Cell,N,B} = MtlCellArray{T,B}(undef, dims)
246249
MtlCellArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N} = MtlCellArray{T,CellArrays.B0}(undef, dims)
247250
MtlCellArray{T}(::UndefInitializer, dims::Vararg{Int, N}) where {T<:CellArrays.Cell,N} = MtlCellArray{T}(undef, dims)
248251

249-
MtlCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = ( A = CellArrays.CellArray{T,N,B}(Metal.MtlArray(A.data), A.dims); A) #workaround: f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) @metal launch=false f(A) end; A )
252+
MtlCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = ( A = CellArrays.CellArray{T,N,B}(Metal.MtlArray(A.data), A.dims); A) #workaround: f(A)=(CellArrays.plain_flat(A); CellArrays.plain_arrayflat(A); return); if (B in (0,1)) Metal.@metal launch=false f(A) end; A )
250253

251254
Base.show(io::IO, A::MtlCellArray) = Base.show(io, CellArrays.CPUCellArray(A))
252255
Base.show(io::IO, ::MIME"text/plain", A::MtlCellArray{T,N,B}) where {T,N,B} = ( println(io, "$(length(A))-element MtlCellArray{$T, $N, $B, $(CellArrays.eltype(T))}:"); Base.print_array(io, CellArrays.CPUCellArray(A)) )

0 commit comments

Comments
 (0)