Skip to content

Commit 9370790

Browse files
committed
Add GPU kernel using shared memory
1 parent d0fe7f6 commit 9370790

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed

src/nhs_grid.jl

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,156 @@ end
370370
end
371371
end
372372

373+
# for cell in cells
374+
# for point in cell
375+
# for neighbor_cell in neighbor_cells
376+
377+
# for neighbor in neighbor_cell
378+
@inline function foreach_point_neighbor_localmem(f, system_coords, neighbor_coords,
379+
neighborhood_search; search_radius = search_radius(neighborhood_search))
380+
backend = KernelAbstractions.get_backend(system_coords)
381+
max_particles_per_cell = 64
382+
nhs_size = size(neighborhood_search.cell_list.linear_indices)
383+
# cells = CartesianIndices(ntuple(i -> 2:(nhs_size[i] - 1), ndims(neighborhood_search)))
384+
linear_indices = neighborhood_search.cell_list.linear_indices
385+
cartesian_indices = CartesianIndices(size(linear_indices))
386+
lengths = Array(neighborhood_search.cell_list.cells.lengths)
387+
# max_particles_per_cell = maximum(lengths)
388+
nonempty_cells = Adapt.adapt(backend, filter(index -> lengths[linear_indices[index]] > 0, cartesian_indices))
389+
ndrange = max_particles_per_cell * length(nonempty_cells)
390+
kernel = foreach_neighbor_localmem(backend, (max_particles_per_cell,))
391+
kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells, Val(max_particles_per_cell), search_radius; ndrange)
392+
393+
KernelAbstractions.synchronize(backend)
394+
395+
return nothing
396+
end
397+
398+
@kernel cpu=false function foreach_neighbor_localmem(f::F, system_coords, neighbor_system_coords,
399+
neighborhood_search, cells, ::Val{MAX}, search_radius) where {F, MAX}
400+
cell_ = @index(Group)
401+
cell = @inbounds Tuple(cells[cell_])
402+
particleidx = @index(Local)
403+
@assert 1 <= particleidx <= MAX
404+
405+
local_points = @localmem Int32 MAX
406+
local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
407+
408+
pv = points_in_cell(cell, neighborhood_search)
409+
n_particles_in_current_cell = length(pv)
410+
if particleidx <= n_particles_in_current_cell
411+
point = @inbounds pv[particleidx]
412+
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
413+
point)
414+
# KernelAbstractions.@print("Point $point with coords ($(point_coords[1]), $(point_coords[2]))\n")
415+
else
416+
point = zero(Int32)
417+
point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
418+
end
419+
420+
for i in -1:1, j in -1:1, k in -1:1
421+
neighbor_cell = (cell[1] + i, cell[2] + j, cell[3] + k)
422+
# for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
423+
# neighbor_cell = Tuple(neighbor_cell_)
424+
points_view = points_in_cell(neighbor_cell, neighborhood_search)
425+
n_particles_in_neighbor_cell = length(points_view)
426+
# if n_particles_in_neighbor_cell
427+
# continue
428+
# end
429+
430+
# First use all threads to load the neighbors into local memory in parallel
431+
if particleidx <= n_particles_in_neighbor_cell
432+
@inbounds p = local_points[particleidx] = points_view[particleidx]
433+
# KernelAbstractions.@print("Point $point, neighbor $p with coords ($(neighbor_system_coords[1, p]), $(neighbor_system_coords[2, p]))\n")
434+
for d in 1:ndims(neighborhood_search)
435+
@inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
436+
end
437+
end
438+
@synchronize()
439+
# Now each thread works on one point again
440+
if particleidx <= n_particles_in_current_cell
441+
for local_neighbor in 1:n_particles_in_neighbor_cell
442+
@inbounds neighbor = local_points[local_neighbor]
443+
@inbounds neighbor_coords = extract_svector(local_neighbor_coords,
444+
Val(ndims(neighborhood_search)), local_neighbor)
445+
446+
pos_diff = point_coords - neighbor_coords
447+
distance2 = dot(pos_diff, pos_diff)
448+
449+
# TODO periodic
450+
451+
if distance2 <= search_radius^2
452+
# KernelAbstractions.@print("Point $point, neighbor $neighbor with distance2 $distance2\n")
453+
distance = sqrt(distance2) # TODO: eventuell fastmath
454+
455+
# Inline to avoid loss of performance
456+
# compared to not using `foreach_point_neighbor`.
457+
@inline f(point, neighbor, pos_diff, distance)
458+
end
459+
end
460+
end
461+
@synchronize()
462+
end
463+
end
464+
465+
@inline function foreach_point_neighbor_cell_blocks(f, system_coords, neighbor_coords,
466+
neighborhood_search)
467+
backend = KernelAbstractions.get_backend(system_coords)
468+
max_particles_per_cell = 64
469+
nhs_size = size(neighborhood_search.cell_list.linear_indices)
470+
cells = CartesianIndices(ntuple(i -> 2:(nhs_size[i] - 1), ndims(neighborhood_search)))
471+
ndrange = max_particles_per_cell * length(cells)
472+
kernel = foreach_neighbor_cell_blocks(backend, (max_particles_per_cell,))
473+
kernel(f, system_coords, neighbor_coords, neighborhood_search, Val(max_particles_per_cell); ndrange)
474+
475+
KernelAbstractions.synchronize(backend)
476+
477+
return nothing
478+
end
479+
480+
@kernel cpu=false function foreach_neighbor_cell_blocks(f::F, system_coords, neighbor_coords,
481+
neighborhood_search, ::Val{MAX}) where {F, MAX}
482+
cell_ = @index(Group)
483+
nhs_size = size(neighborhood_search.cell_list.linear_indices)
484+
@inbounds cells = CartesianIndices(ntuple(i -> 2:(nhs_size[i] - 1), ndims(neighborhood_search)))
485+
cell = @inbounds Tuple(cells[cell_])
486+
particleidx = @index(Local)
487+
@assert 1 <= particleidx <= MAX
488+
489+
pv = points_in_cell(cell, neighborhood_search)
490+
n_particles_in_current_cell = length(pv)
491+
if particleidx <= n_particles_in_current_cell
492+
point = @inbounds pv[particleidx]
493+
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
494+
point)
495+
# KernelAbstractions.@print("Point $point with coords ($(point_coords[1]), $(point_coords[2]))\n")
496+
497+
for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
498+
neighbor_cell = Tuple(neighbor_cell_)
499+
points_view = points_in_cell(neighbor_cell, neighborhood_search)
500+
501+
for neighbor in points_view
502+
@inbounds neighbor_coords = extract_svector(neighbor_system_coords,
503+
Val(ndims(neighborhood_search)), neighbor)
504+
505+
pos_diff = point_coords - neighbor_coords
506+
distance2 = dot(pos_diff, pos_diff)
507+
508+
# TODO periodic
509+
510+
if distance2 <= search_radius^2
511+
# KernelAbstractions.@print("Point $point, neighbor $neighbor with distance2 $distance2\n")
512+
distance = sqrt(distance2) # TODO: eventuell fastmath
513+
514+
# Inline to avoid loss of performance
515+
# compared to not using `foreach_point_neighbor`.
516+
@inline f(point, neighbor, pos_diff, distance)
517+
end
518+
end
519+
end
520+
end
521+
end
522+
373523
@inline function neighboring_cells(cell, neighborhood_search)
374524
NDIMS = ndims(neighborhood_search)
375525

0 commit comments

Comments
 (0)