Skip to content

Commit 3e84145

Browse files
committed
Fix unified_array
1 parent 5bb91f0 commit 3e84145

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

ext/OceananigansCUDAExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ AC.on_architecture(::CUDAGPU, a::SubArray{<:Any, <:Any, <:Array}) = CuArray(a)
6666
AC.on_architecture(::AC.CPU, a::SubArray{<:Any, <:Any, <:CuArray}) = Array(a)
6767
AC.on_architecture(::CUDAGPU, a::StepRangeLen) = a
6868

69+
# cu alters the type of `a`, so we convert it back to the correct type
70+
unified_array(::CUDAGPU, a::AbstractArray) = map(eltype(a), cu(a; unified = true))
71+
6972
## GPU to GPU copy of contiguous data
7073
@inline function AC.device_copy_to!(dst::CuArray, src::CuArray; async::Bool = false)
7174
n = length(src)

src/Architectures.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,6 @@ cpu_architecture(::ReactantState) = CPU()
104104
unified_array(::CPU, a) = a
105105
unified_array(::GPU, a) = a
106106

107-
# cu alters the type of `a`, so we convert it back to the correct type
108-
unified_array(::GPU, a::AbstractArray) = map(eltype(a), cu(a; unified = true))
109-
110107
@inline device_copy_to!(dst::Array, src::Array; kw...) = Base.copyto!(dst, src)
111108

112109
@inline unsafe_free!(a) = nothing

0 commit comments

Comments
 (0)