@@ -370,6 +370,156 @@ end
370370 end
371371end
372372
# Iteration scheme of the cell-based kernels below
# (one workgroup per cell, one work item per point in that cell):
#
# for cell in cells
#     for point in cell
#         for neighbor_cell in neighbor_cells
#             for neighbor in neighbor_cell
# Launch the `foreach_neighbor_localmem` kernel, which calls
# `f(point, neighbor, pos_diff, distance)` for every point-neighbor pair whose
# distance is at most `search_radius`.
#
# One workgroup of `max_particles_per_cell` work items is launched per non-empty cell.
#
# Keywords:
# - `search_radius`: cutoff distance, defaults to the radius of `neighborhood_search`.
# - `max_particles_per_cell`: workgroup size and capacity of the kernel's local-memory
#   buffers. No cell may contain more points than this.
#
# Throws an `ArgumentError` when a cell contains more than `max_particles_per_cell`
# points. Without this check, the kernel would skip the surplus points of that cell
# and read its local-memory buffers out of bounds.
@inline function foreach_point_neighbor_localmem(
        f, system_coords, neighbor_coords, neighborhood_search;
        search_radius = search_radius(neighborhood_search),
        max_particles_per_cell = 64)
    backend = KernelAbstractions.get_backend(system_coords)

    linear_indices = neighborhood_search.cell_list.linear_indices
    cartesian_indices = CartesianIndices(size(linear_indices))

    # Materialize the per-cell point counts as a plain `Array`,
    # so that they can be indexed element by element here.
    lengths = Array(neighborhood_search.cell_list.cells.lengths)
    points_per_cell(index) = lengths[linear_indices[index]]

    # Only non-empty cells get a workgroup
    host_cells = filter(index -> points_per_cell(index) > 0, cartesian_indices)

    largest_cell = maximum(points_per_cell, host_cells; init = zero(eltype(lengths)))
    if largest_cell > max_particles_per_cell
        throw(ArgumentError("a cell contains $largest_cell points, but " *
                            "`max_particles_per_cell` is $max_particles_per_cell; " *
                            "increase `max_particles_per_cell`"))
    end

    nonempty_cells = Adapt.adapt(backend, host_cells)
    ndrange = max_particles_per_cell * length(nonempty_cells)
    kernel = foreach_neighbor_localmem(backend, (max_particles_per_cell,))
    kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells,
           Val(max_particles_per_cell), search_radius; ndrange)

    KernelAbstractions.synchronize(backend)

    return nothing
end
397+
# GPU-only kernel: each workgroup handles the cell `cells[@index(Group)]`, and each
# work item handles (at most) one point of that cell.
#
# For every neighboring cell, the work items of the group first cooperatively copy the
# neighbor indices and coordinates into `@localmem` buffers. Then every work item that
# owns a point loops over the buffered neighbors and calls
# `f(point, neighbor, pos_diff, distance)` for those with
# `distance <= search_radius`.
#
# Precondition: no cell may contain more than `MAX` points. Work item indices only range
# over `1:MAX`, so surplus points of the current cell are never processed, and surplus
# neighbors are never loaded although the inner loop still reads their buffer slots.
@kernel cpu=false function foreach_neighbor_localmem(
        f::F, system_coords, neighbor_system_coords, neighborhood_search, cells,
        ::Val{MAX}, search_radius) where {F, MAX}
    cell_ = @index(Group)
    cell = @inbounds Tuple(cells[cell_])
    particleidx = @index(Local)
    @assert 1 <= particleidx <= MAX

    # Buffers shared by all work items of this workgroup
    local_points = @localmem Int32 MAX
    local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)

    pv = points_in_cell(cell, neighborhood_search)
    n_particles_in_current_cell = length(pv)
    if particleidx <= n_particles_in_current_cell
        point = @inbounds pv[particleidx]
        point_coords = @inbounds extract_svector(system_coords,
                                                 Val(ndims(neighborhood_search)), point)
    else
        # This work item owns no point. It only helps loading neighbors and has to
        # take part in the barriers below.
        point = zero(Int32)
        point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
    end

    # Loop-invariant, so compute it only once
    search_radius2 = search_radius^2

    # Use `neighboring_cells` (as `foreach_neighbor_cell_blocks` does) instead of
    # hard-coded offsets in three dimensions, so that the loop follows the number
    # of dimensions of the neighborhood search.
    for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
        neighbor_cell = Tuple(neighbor_cell_)
        points_view = points_in_cell(neighbor_cell, neighborhood_search)
        n_particles_in_neighbor_cell = length(points_view)

        # First use all threads to load the neighbors into local memory in parallel
        if particleidx <= n_particles_in_neighbor_cell
            @inbounds p = local_points[particleidx] = points_view[particleidx]
            for d in 1:ndims(neighborhood_search)
                @inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
            end
        end

        # All neighbors must be loaded before any work item reads the buffers
        @synchronize()

        # Now each thread works on one point again
        if particleidx <= n_particles_in_current_cell
            for local_neighbor in 1:n_particles_in_neighbor_cell
                @inbounds neighbor = local_points[local_neighbor]
                @inbounds neighbor_coords = extract_svector(local_neighbor_coords,
                                                            Val(ndims(neighborhood_search)),
                                                            local_neighbor)

                pos_diff = point_coords - neighbor_coords
                distance2 = dot(pos_diff, pos_diff)

                # TODO periodic

                if distance2 <= search_radius2
                    distance = sqrt(distance2) # TODO: possibly use fastmath

                    # Inline to avoid loss of performance
                    # compared to not using `foreach_point_neighbor`.
                    @inline f(point, neighbor, pos_diff, distance)
                end
            end
        end

        # Don't overwrite the buffers for the next cell while they are still being read
        @synchronize()
    end
end
464+
# Launch the `foreach_neighbor_cell_blocks` kernel with one workgroup of
# `max_particles_per_cell` work items per cell.
#
# The cells are all cells of the cell list except the first and last one
# in each dimension.
#
# Keywords:
# - `max_particles_per_cell`: workgroup size. The kernel assigns one work item to each
#   point of a cell, so points beyond this number in a cell are not visited.
#   Choose it at least as large as the fullest cell.
@inline function foreach_point_neighbor_cell_blocks(f, system_coords, neighbor_coords,
                                                    neighborhood_search;
                                                    max_particles_per_cell = 64)
    backend = KernelAbstractions.get_backend(system_coords)

    nhs_size = size(neighborhood_search.cell_list.linear_indices)
    cells = CartesianIndices(ntuple(i -> 2:(nhs_size[i] - 1), ndims(neighborhood_search)))

    ndrange = max_particles_per_cell * length(cells)
    kernel = foreach_neighbor_cell_blocks(backend, (max_particles_per_cell,))
    kernel(f, system_coords, neighbor_coords, neighborhood_search,
           Val(max_particles_per_cell); ndrange)

    KernelAbstractions.synchronize(backend)

    return nothing
end
479+
# GPU-only kernel: each workgroup handles one cell (all cells except the first and last
# one in each dimension, indexed by `@index(Group)`), and each work item handles
# (at most) one point of that cell.
#
# A work item that owns a point loops over all points in the neighboring cells and calls
# `f(point, neighbor, pos_diff, distance)` for those within the search radius of
# `neighborhood_search`.
#
# Work item indices only range over `1:MAX`, so points beyond the first `MAX` of a cell
# are not visited.
@kernel cpu=false function foreach_neighbor_cell_blocks(
        f::F, system_coords, neighbor_coords, neighborhood_search,
        ::Val{MAX}) where {F, MAX}
    cell_ = @index(Group)
    nhs_size = size(neighborhood_search.cell_list.linear_indices)
    @inbounds cells = CartesianIndices(ntuple(i -> 2:(nhs_size[i] - 1),
                                              ndims(neighborhood_search)))
    cell = @inbounds Tuple(cells[cell_])
    particleidx = @index(Local)
    @assert 1 <= particleidx <= MAX

    # The radius is not a kernel argument, so take it from the neighborhood search.
    # Inside this kernel, the bare name `search_radius` refers to the accessor function.
    search_radius2 = search_radius(neighborhood_search)^2

    pv = points_in_cell(cell, neighborhood_search)
    n_particles_in_current_cell = length(pv)
    if particleidx <= n_particles_in_current_cell
        point = @inbounds pv[particleidx]
        point_coords = @inbounds extract_svector(system_coords,
                                                 Val(ndims(neighborhood_search)), point)

        for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
            neighbor_cell = Tuple(neighbor_cell_)
            points_view = points_in_cell(neighbor_cell, neighborhood_search)

            for neighbor in points_view
                # Use a separate name for the coordinates of this neighbor,
                # so that the coordinate array argument `neighbor_coords` is not rebound.
                @inbounds neighbor_point_coords = extract_svector(neighbor_coords,
                                                                  Val(ndims(neighborhood_search)),
                                                                  neighbor)

                pos_diff = point_coords - neighbor_point_coords
                distance2 = dot(pos_diff, pos_diff)

                # TODO periodic

                if distance2 <= search_radius2
                    distance = sqrt(distance2) # TODO: possibly use fastmath

                    # Inline to avoid loss of performance
                    # compared to not using `foreach_point_neighbor`.
                    @inline f(point, neighbor, pos_diff, distance)
                end
            end
        end
    end
end
522+
373523@inline function neighboring_cells (cell, neighborhood_search)
374524 NDIMS = ndims (neighborhood_search)
375525
0 commit comments