@@ -155,22 +155,30 @@ end
155155 # On the GPU, we can only loop over `1:N`. Therefore, we loop over `1:length(iterator)`
156156 # and index with `iterator[eachindex(iterator)[i]]`.
157157 # Note that this only works with vector-like iterators that support arbitrary indexing.
158- indices = eachindex (iterator)
158+ indices = eachindex (IndexLinear (), iterator)
159159 ndrange = length (indices)
160160
161+ # TODO: Is it better to pass `indices` to the kernel,
162+ # or should we "recreate" them inside the kernel?
163+
161164 # Skip empty loops
162165 ndrange == 0 && return
163166
164167 # Call the generic kernel that is defined below, which only calls a function with
165168 # the global GPU index.
166- generic_kernel (backend)(ndrange = ndrange) do i
167- @inbounds @inline f (iterator[indices[i]])
168- end
169+ foreach_ka (backend)(f, iterator, indices, ndrange = ndrange)
169170
170171 KernelAbstractions. synchronize (backend)
171172end
172173
173- @kernel function generic_kernel (f)
174- i = @index (Global)
175- @inline f (i)
174+ @kernel unsafe_indices= true function foreach_ka (f, iterator, indices)
175+ # Calculate the global linear index manually: local index + (group index - 1) * group size
176+ N = @groupsize ()[1 ]
177+ iblock = @index (Group, Linear)
178+ ithread = @index (Local, Linear)
179+ i = ithread + (iblock - Int32 (1 )) * N
180+ # Only call `f` when `i` is a valid position in `indices`; other threads do nothing
181+ if i <= length (indices)
182+ @inbounds @inline f (iterator[indices[i]])
183+ end
176184end
0 commit comments