Skip to content

Commit ff2ce2e

Browse files
committed
deactivate workaround for AMDGPU
1 parent ab4006c commit ff2ce2e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/CellArray.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ function define_CuCellArray()
145145
quote
146146
const CuCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,CUDA.CuArray{T_elem,CellArrays._N,CUDA.DeviceMemory}}
147147

148-
CuCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = (CellArrays.check_T(T); A = CuCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); f(A)=(CellArrays.plain(A); return); if (B in (0,1)) @cuda f(A) end; A)
148+
CuCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = CuCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); f(A)=(CellArrays.plain(A); return); if (B in (0,1)) @cuda f(A) end; A )
149149
CuCellArray{T,B}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell,B} = CuCellArray{T,B}(undef, dims)
150150
CuCellArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N} = CuCellArray{T,0}(undef, dims)
151151
CuCellArray{T}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell} = CuCellArray{T}(undef, dims)
@@ -187,13 +187,13 @@ macro define_ROCCellArray() esc(define_ROCCellArray()) end
187187
function define_ROCCellArray()
188188
quote
189189
const ROCCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,AMDGPU.ROCArray{T_elem,CellArrays._N}}
190-
191-
ROCCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = (CellArrays.check_T(T); ROCCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims))
190+
191+
ROCCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = ROCCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); A ) # TODO: Once reshape is implemented in AMDGPU, the workaround can be applied as well: f(A)=(CellArrays.plain(A); return); if (B in (0,1)) @roc f(A) end; A )
192192
ROCCellArray{T,B}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell,B} = ROCCellArray{T,B}(undef, dims)
193193
ROCCellArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N} = ROCCellArray{T,0}(undef, dims)
194194
ROCCellArray{T}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell} = ROCCellArray{T}(undef, dims)
195195

196-
ROCCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = CellArrays.CellArray{T,N,B}(AMDGPU.ROCArray(A.data), A.dims)
196+
ROCCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = ( A = CellArrays.CellArray{T,N,B}(AMDGPU.ROCArray(A.data), A.dims); A ) # TODO: Once reshape is implemented in AMDGPU, the workaround can be applied as well: f(A)=(CellArrays.plain(A); return); if (B in (0,1)) @roc f(A) end; A )
197197

198198
Base.show(io::IO, A::ROCCellArray) = Base.show(io, CellArrays.CPUCellArray(A))
199199
Base.show(io::IO, ::MIME"text/plain", A::ROCCellArray{T,N,B}) where {T,N,B} = ( println(io, "$(length(A))-element ROCCellArray{$T, $N, $B, $(CellArrays.eltype(T))}:"); Base.print_array(io, CellArrays.CPUCellArray(A)) )
@@ -231,12 +231,12 @@ function define_MtlCellArray()
231231
quote
232232
const MtlCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,Metal.MtlArray{T_elem,CellArrays._N}}
233233

234-
MtlCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = (CellArrays.check_T(T); MtlCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims))
234+
MtlCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = ( CellArrays.check_T(T); A = MtlCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims); f(A)=(CellArrays.plain(A); return); if (B in (0,1)) @metal f(A) end; A )
235235
MtlCellArray{T,B}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell,B} = MtlCellArray{T,B}(undef, dims)
236236
MtlCellArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N} = MtlCellArray{T,0}(undef, dims)
237237
MtlCellArray{T}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell} = MtlCellArray{T}(undef, dims)
238238

239-
MtlCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = CellArrays.CellArray{T,N,B}(Metal.MtlArray(A.data), A.dims)
239+
MtlCellArray(A::CellArrays.CellArray{T,N,B,T_array}) where {T,N,B,T_array} = ( A = CellArrays.CellArray{T,N,B}(Metal.MtlArray(A.data), A.dims); f(A)=(CellArrays.plain(A); return); if (B in (0,1)) @metal f(A) end; A )
240240

241241
Base.show(io::IO, A::MtlCellArray) = Base.show(io, CellArrays.CPUCellArray(A))
242242
Base.show(io::IO, ::MIME"text/plain", A::MtlCellArray{T,N,B}) where {T,N,B} = ( println(io, "$(length(A))-element MtlCellArray{$T, $N, $B, $(CellArrays.eltype(T))}:"); Base.print_array(io, CellArrays.CPUCellArray(A)) )

0 commit comments

Comments
 (0)