Skip to content

Commit 1682f3f

Browse files
committed
DArray/stencil: Reuse HaloArray allocations
1 parent fd5f4d2 commit 1682f3f

1 file changed

Lines changed: 169 additions & 2 deletions

File tree

src/array/stencil.jl

Lines changed: 169 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ function load_neighbor_region(arr, region_code::NTuple{N,Int}, neigh_dist) where
5050
return move(task_processor(), collect(@view arr[start_idx:stop_idx]))
5151
end
5252

53+
# In-place variant: load region directly into a pre-allocated destination buffer.
54+
function load_neighbor_region_into!(dest, arr, region_code::NTuple{N,Int}, neigh_dist) where N
55+
validate_neigh_dist(neigh_dist, size(arr))
56+
start_idx = CartesianIndex(ntuple(N) do i
57+
region_code[i] == -1 ? lastindex(arr, i) - get_neigh_dist(neigh_dist, i) + 1 : firstindex(arr, i)
58+
end)
59+
stop_idx = CartesianIndex(ntuple(N) do i
60+
region_code[i] == +1 ? firstindex(arr, i) + get_neigh_dist(neigh_dist, i) - 1 : lastindex(arr, i)
61+
end)
62+
copyto!(dest, @view arr[start_idx:stop_idx])
63+
end
64+
5365
is_past_boundary(size, idx) = any(ntuple(i -> idx[i] < 1 || idx[i] > size[i], length(size)))
5466

5567
#############################################################################
@@ -123,6 +135,9 @@ boundary_transition(::Wrap, idx, size) =
123135
load_boundary_region(::Wrap, arr, region_code, neigh_dist, boundary_dims) =
124136
load_neighbor_region(arr, region_code, neigh_dist)
125137

138+
load_boundary_region_into!(dest, ::Wrap, arr, region_code, neigh_dist, boundary_dims) =
139+
load_neighbor_region_into!(dest, arr, region_code, neigh_dist)
140+
126141
function boundary_source_index(::Wrap, arr, rc, nd, idx_d, d)
127142
if rc == -1
128143
return lastindex(arr, d) - nd + idx_d
@@ -157,6 +172,9 @@ function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_d
157172
return move(task_processor(), fill(pad.padval, region_size))
158173
end
159174

175+
load_boundary_region_into!(dest, pad::Pad, arr, region_code, neigh_dist, boundary_dims) =
176+
fill!(dest, pad.padval)
177+
160178
# Use edge as source index (value will be overridden by apply_boundary_value)
161179
boundary_source_index(::Pad, arr, rc, nd, idx_d, d) =
162180
rc == -1 ? firstindex(arr, d) : (rc == +1 ? lastindex(arr, d) : idx_d)
@@ -221,6 +239,10 @@ function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_di
221239
return move(task_processor(), result)
222240
end
223241

242+
function load_boundary_region_into!(dest, ::Clamp, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N
243+
Kernel(load_boundary_region_kernel)(Clamp(), dest, arr, region_code, neigh_dist, boundary_dims; ndrange=length(dest))
244+
end
245+
224246
function boundary_source_index(::Clamp, arr, rc, nd, idx_d, d)
225247
if rc == -1
226248
return firstindex(arr, d)
@@ -332,6 +354,18 @@ function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region
332354
return move(task_processor(), result)
333355
end
334356

357+
function load_boundary_region_into!(dest, ::LinearExtrapolate, arr::AbstractArray{T}, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {T<:Real,N}
358+
extrap_dim = 0
359+
for d in 1:N
360+
if boundary_dims[d] && region_code[d] != 0
361+
extrap_dim = d
362+
break
363+
end
364+
end
365+
nd = get_neigh_dist(neigh_dist, extrap_dim)
366+
Kernel(load_boundary_region_kernel)(LinearExtrapolate(), dest, arr, region_code, neigh_dist, boundary_dims, Val(extrap_dim), Val(nd); ndrange=length(dest))
367+
end
368+
335369
# Use edge as source index (value will be computed by apply_boundary_value)
336370
boundary_source_index(::LinearExtrapolate, arr, rc, nd, idx_d, d) =
337371
rc == -1 ? firstindex(arr, d) : (rc == +1 ? lastindex(arr, d) : idx_d)
@@ -434,6 +468,41 @@ function load_boundary_region(::Reflect{Symm}, arr, region_code::NTuple{N,Int},
434468
return region
435469
end
436470

471+
function load_boundary_region_into!(dest, ::Reflect{Symm}, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {N, Symm}
472+
flipped_code = ntuple(N) do i
473+
(region_code[i] != 0 && boundary_dims[i]) ? -region_code[i] : region_code[i]
474+
end
475+
skip = Symm ? 0 : 1
476+
start_idx = CartesianIndex(ntuple(N) do i
477+
needs_skip = boundary_dims[i] && region_code[i] != 0
478+
actual_skip = needs_skip ? skip : 0
479+
if flipped_code[i] == -1
480+
lastindex(arr, i) - get_neigh_dist(neigh_dist, i) + 1 - actual_skip
481+
elseif flipped_code[i] == +1
482+
firstindex(arr, i) + actual_skip
483+
else
484+
firstindex(arr, i)
485+
end
486+
end)
487+
stop_idx = CartesianIndex(ntuple(N) do i
488+
needs_skip = boundary_dims[i] && region_code[i] != 0
489+
actual_skip = needs_skip ? skip : 0
490+
if flipped_code[i] == +1
491+
firstindex(arr, i) + get_neigh_dist(neigh_dist, i) - 1 + actual_skip
492+
elseif flipped_code[i] == -1
493+
lastindex(arr, i) - actual_skip
494+
else
495+
lastindex(arr, i)
496+
end
497+
end)
498+
copyto!(dest, @view arr[start_idx:stop_idx])
499+
for i in 1:N
500+
GPUArraysCore.@allowscalar if region_code[i] != 0 && boundary_dims[i]
501+
reverse!(dest, dims=i)
502+
end
503+
end
504+
end
505+
437506
function boundary_source_index(::Reflect{Symm}, arr, rc, nd, idx_d, d) where Symm
438507
skip = Symm ? 0 : 1
439508
if rc == -1
@@ -564,6 +633,10 @@ function load_boundary_region(boundary::Tuple, arr, region_code::NTuple{N,Int},
564633
return move(task_processor(), result)
565634
end
566635

636+
function load_boundary_region_into!(dest, boundary::Tuple, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N
637+
Kernel(load_boundary_region_kernel)(boundary, dest, arr, region_code, neigh_dist, boundary_dims; ndrange=length(dest))
638+
end
639+
567640
#############################################################################
568641
# Chunk Selection and Halo Building
569642
#############################################################################
@@ -615,6 +688,99 @@ function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
615688
return accesses
616689
end
617690

691+
# Returns (region_metadata, neighbor_chunk_dtasks) without spawning intermediate load tasks.
692+
# region_metadata: Vector of (region_code, is_boundary, boundary_dims).
693+
# neighbor_chunk_dtasks: Vector of raw chunk DTasks (resolved to arrays when build_halo_consolidated runs).
694+
function select_neighborhood_info(chunks, idx, neigh_dist, boundary)
695+
validate_neigh_dist(neigh_dist)
696+
N = ndims(chunks)
697+
chunk_dist = 1
698+
region_metadata = Tuple[]
699+
neighbor_chunks = Any[]
700+
701+
for i in 0:(3^N - 1)
702+
region_code = ntuple(N) do d
703+
((i ÷ 3^(d-1)) % 3) - 1
704+
end
705+
all(==(0), region_code) && continue
706+
707+
chunk_offset = CartesianIndex(ntuple(N) do d
708+
region_code[d] * chunk_dist
709+
end)
710+
new_idx = idx + chunk_offset
711+
712+
if is_past_boundary(size(chunks), new_idx)
713+
boundary_dims = ntuple(N) do d
714+
new_idx[d] < 1 || new_idx[d] > size(chunks)[d]
715+
end
716+
if boundary_has_transition(boundary)
717+
new_idx = boundary_transition(boundary, new_idx, size(chunks))
718+
else
719+
new_idx = idx
720+
end
721+
push!(region_metadata, (region_code, true, boundary_dims))
722+
else
723+
push!(region_metadata, (region_code, false, ntuple(_ -> false, N)))
724+
end
725+
push!(neighbor_chunks, chunks[new_idx])
726+
end
727+
728+
@assert length(region_metadata) == 3^N - 1
729+
return region_metadata, neighbor_chunks
730+
end
731+
732+
# Per-thread cache: IdDict{DArray, Dict{(chunk_idx, halo_width), HaloArray}}.
733+
# Using IdDict for the outer level ensures two DArrays with identical element types and
734+
# chunk shapes never share a buffer. Using chunk_idx as part of the inner key ensures that
735+
# within one DArray, every chunk has its own dedicated buffer — so if a single worker thread
736+
# processes multiple same-shaped chunks in the same iteration (sequentially), each gets a
737+
# distinct HaloArray and there is no aliasing with a concurrently running inner-stencil task.
738+
# Filling a cached buffer in-place is safe because spawn_datadeps blocks until all inner
739+
# tasks complete before the next iteration's build_halo_consolidated calls run.
740+
const HALO_ARRAY_CACHE = TaskLocalValue{IdDict{Any,Dict{Any,Any}}}(()->IdDict{Any,Dict{Any,Any}}())
741+
742+
# Consolidated halo builder: loads all neighbor regions directly into a HaloArray.
743+
# `read_darray` and `chunk_idx` are used solely for cache lookup — they are not DTask
744+
# arguments, so Dagger does not create extra data dependencies from them.
745+
# First call per (DArray, chunk_idx, halo_width) allocates and caches; subsequent calls
746+
# fill the cached HaloArray in-place — zero allocations on the hot path.
747+
function build_halo_consolidated(read_darray, chunk_idx, neigh_dist, boundary, center, region_metadata, neighbor_chunks...)
748+
N = ndims(center)
749+
expected_halos = length(region_metadata)
750+
@assert length(neighbor_chunks) == expected_halos
751+
validate_neigh_dist(neigh_dist, size(center))
752+
halo_width = ntuple(i -> get_neigh_dist(neigh_dist, i), N)
753+
754+
outer_cache = HALO_ARRAY_CACHE[]
755+
inner_cache = get!(outer_cache, read_darray) do; Dict{Any,Any}(); end
756+
cache_key = (chunk_idx, halo_width)
757+
758+
if haskey(inner_cache, cache_key)
759+
halo = inner_cache[cache_key]
760+
copyto!(halo.center, center)
761+
for i in 1:expected_halos
762+
region_code, is_boundary, boundary_dims = region_metadata[i]
763+
chunk = neighbor_chunks[i]
764+
if is_boundary
765+
load_boundary_region_into!(halo.halos[i], boundary, chunk, region_code, neigh_dist, boundary_dims)
766+
else
767+
load_neighbor_region_into!(halo.halos[i], chunk, region_code, neigh_dist)
768+
end
769+
end
770+
return halo
771+
else
772+
halos = ntuple(expected_halos) do i
773+
region_code, is_boundary, boundary_dims = region_metadata[i]
774+
chunk = neighbor_chunks[i]
775+
is_boundary ? load_boundary_region(boundary, chunk, region_code, neigh_dist, boundary_dims) :
776+
load_neighbor_region(chunk, region_code, neigh_dist)
777+
end
778+
halo = HaloArray(copy(center), halos, halo_width)
779+
inner_cache[cache_key] = halo
780+
return halo
781+
end
782+
end
783+
618784
function build_halo(neigh_dist, boundary, center, all_halos...)
619785
N = ndims(center)
620786
expected_halos = 3^N - 1
@@ -833,11 +999,12 @@ macro stencil(orig_ex)
833999
for read_var in read_vars
8341000
if read_var in keys(neighborhoods)
8351001
neigh_dist, boundary = neighborhoods[read_var]
836-
@gensym halo_tasks
1002+
@gensym halo_tasks region_meta neighbor_cks
8371003
push!(final_ex.args, :($halo_tasks = Array{$DTask}(undef, size($chunks($read_var)))))
8381004
push!(final_ex.args, quote
8391005
for $chunk_idx in $CartesianIndices($chunks($read_var))
840-
$halo_tasks[$chunk_idx] = Dagger.@spawn name="stencil_build_halo" $build_halo($neigh_dist, $boundary, $select_neighborhood_chunks($chunks($read_var), $chunk_idx, $neigh_dist, $boundary)...)
1006+
($region_meta, $neighbor_cks) = $select_neighborhood_info($chunks($read_var), $chunk_idx, $neigh_dist, $boundary)
1007+
$halo_tasks[$chunk_idx] = Dagger.@spawn name="stencil_build_halo" $build_halo_consolidated($read_var, $chunk_idx, $neigh_dist, $boundary, $chunks($read_var)[$chunk_idx], $region_meta, $neighbor_cks...)
8411008
end
8421009
end)
8431010
push!(final_ex.args, :($halo_tasks_map[$(QuoteNode(read_var))] = $halo_tasks))

0 commit comments

Comments
 (0)