11import GPUArrays: allowscalar, @allowscalar
22

## unified memory indexing

# TODO: think about coherency -- the unified-memory fast path accesses the buffer
#       through a host pointer without synchronizing first, so it might crash.
#       Also, this optimization would be relevant for CuArray<->Array memcpy as well.

# Read the single element at linear index `i` of `xs` and return it as a host value.
#
# If the array's buffer is a `Mem.UnifiedBuffer`, the buffer is converted to a host
# pointer and the element is loaded through it directly. Otherwise one element is
# copied into a zero-dimensional host `Array` and unwrapped.
function GPUArrays._getindex(xs::CuArray{T}, i::Integer) where T
    buf = buffer(xs)
    if buf isa Mem.UnifiedBuffer
        # `unsafe_load` does no bounds checking: an out-of-range `i` would read
        # arbitrary host memory, so validate the index before dereferencing.
        @boundscheck checkbounds(xs, i)
        # NOTE(review): the pointer is taken from the buffer itself; if `xs` can start
        # at an offset into its buffer this needs to account for it -- confirm.
        ptr = convert(Ptr{T}, buf)
        return unsafe_load(ptr, i)
    else
        # staging path: copy one element into a 0-dimensional host array
        val = Array{T}(undef)
        copyto!(val, 1, xs, i, 1)
        return val[]
    end
end
820
# Store the value `v` at linear index `i` of `xs`.
#
# If the array's buffer is a `Mem.UnifiedBuffer`, the buffer is converted to a host
# pointer and the value is written through it directly. Otherwise a one-element host
# vector holding `v` is copied into `xs` at position `i`.
function GPUArrays._setindex!(xs::CuArray{T}, v::T, i::Integer) where T
    buf = buffer(xs)
    if buf isa Mem.UnifiedBuffer
        # `unsafe_store!` does no bounds checking: an out-of-range `i` would overwrite
        # arbitrary host memory, so validate the index before writing.
        @boundscheck checkbounds(xs, i)
        # NOTE(review): the pointer is taken from the buffer itself; if `xs` can start
        # at an offset into its buffer this needs to account for it -- confirm.
        # NOTE(review): no synchronization happens before this host-side write --
        # coherency with in-flight device work still needs to be thought through.
        ptr = convert(Ptr{T}, buf)
        return unsafe_store!(ptr, v, i)
    else
        # staging path: wrap `v` in a temporary host vector and copy one element over
        return copyto!(xs, i, T[v], 1, 1)
    end
end
1230
1331
@@ -19,7 +37,7 @@ function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
1937 bools = reshape(bools, prod(size(bools)))
2038 indices = cumsum(bools) # unique indices for elements that are true
2139
22- n = _getindex(indices, length(indices)) # number that are true
40+ n = GPUArrays . _getindex(indices, length(indices)) # number that are true
2341 ys = CuArray{T}(undef, n)
2442
2543 if n > 0
0 commit comments