Skip to content

Commit a2137c7

Browse files
committed
Use unsafe_indices
1 parent af93949 commit a2137c7

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

src/util.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,22 +155,30 @@ end
155155
# On the GPU, we can only loop over `1:N`. Therefore, we loop over `1:length(iterator)`
156156
# and index with `iterator[eachindex(iterator)[i]]`.
157157
# Note that this only works with vector-like iterators that support arbitrary indexing.
158-
indices = eachindex(iterator)
158+
indices = eachindex(IndexLinear(), iterator)
159159
ndrange = length(indices)
160160

161+
# TODO: Is it better to pass `indices` to the kernel,
162+
# or should we "recreate" them inside the kernel.
163+
161164
# Skip empty loops
162165
ndrange == 0 && return
163166

164167
# Call the generic kernel that is defined below, which only calls a function with
165168
# the global GPU index.
166-
generic_kernel(backend)(ndrange = ndrange) do i
167-
@inbounds @inline f(iterator[indices[i]])
168-
end
169+
foreach_ka(backend)(f, iterator, indices, ndrange = ndrange)
169170

170171
KernelAbstractions.synchronize(backend)
171172
end
172173

173-
@kernel function generic_kernel(f)
174-
i = @index(Global)
175-
@inline f(i)
174+
@kernel unsafe_indices=true function foreach_ka(f, iterator, indices)
175+
# Calculate global index
176+
N = @groupsize()[1]
177+
iblock = @index(Group, Linear)
178+
ithread = @index(Local, Linear)
179+
i = ithread + (iblock - Int32(1)) * N
180+
181+
if i <= length(indices)
182+
@inbounds @inline f(iterator[indices[i]])
183+
end
176184
end

0 commit comments

Comments
 (0)