@@ -50,6 +50,18 @@ function load_neighbor_region(arr, region_code::NTuple{N,Int}, neigh_dist) where
5050 return move (task_processor (), collect (@view arr[start_idx: stop_idx]))
5151end
5252
53+ # In-place variant: load region directly into a pre-allocated destination buffer.
54+ function load_neighbor_region_into! (dest, arr, region_code:: NTuple{N,Int} , neigh_dist) where N
55+ validate_neigh_dist (neigh_dist, size (arr))
56+ start_idx = CartesianIndex (ntuple (N) do i
57+ region_code[i] == - 1 ? lastindex (arr, i) - get_neigh_dist (neigh_dist, i) + 1 : firstindex (arr, i)
58+ end )
59+ stop_idx = CartesianIndex (ntuple (N) do i
60+ region_code[i] == + 1 ? firstindex (arr, i) + get_neigh_dist (neigh_dist, i) - 1 : lastindex (arr, i)
61+ end )
62+ copyto! (dest, @view arr[start_idx: stop_idx])
63+ end
64+
5365is_past_boundary (size, idx) = any (ntuple (i -> idx[i] < 1 || idx[i] > size[i], length (size)))
5466
5567# ############################################################################
@@ -123,6 +135,9 @@ boundary_transition(::Wrap, idx, size) =
123135load_boundary_region (:: Wrap , arr, region_code, neigh_dist, boundary_dims) =
124136 load_neighbor_region (arr, region_code, neigh_dist)
125137
138+ load_boundary_region_into! (dest, :: Wrap , arr, region_code, neigh_dist, boundary_dims) =
139+ load_neighbor_region_into! (dest, arr, region_code, neigh_dist)
140+
126141function boundary_source_index (:: Wrap , arr, rc, nd, idx_d, d)
127142 if rc == - 1
128143 return lastindex (arr, d) - nd + idx_d
@@ -157,6 +172,9 @@ function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_d
157172 return move (task_processor (), fill (pad. padval, region_size))
158173end
159174
175+ load_boundary_region_into! (dest, pad:: Pad , arr, region_code, neigh_dist, boundary_dims) =
176+ fill! (dest, pad. padval)
177+
160178# Use edge as source index (value will be overridden by apply_boundary_value)
161179boundary_source_index (:: Pad , arr, rc, nd, idx_d, d) =
162180 rc == - 1 ? firstindex (arr, d) : (rc == + 1 ? lastindex (arr, d) : idx_d)
@@ -221,6 +239,10 @@ function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_di
221239 return move (task_processor (), result)
222240end
223241
242+ function load_boundary_region_into! (dest, :: Clamp , arr, region_code:: NTuple{N,Int} , neigh_dist, boundary_dims:: NTuple{N,Bool} ) where N
243+ Kernel (load_boundary_region_kernel)(Clamp (), dest, arr, region_code, neigh_dist, boundary_dims; ndrange= length (dest))
244+ end
245+
224246function boundary_source_index (:: Clamp , arr, rc, nd, idx_d, d)
225247 if rc == - 1
226248 return firstindex (arr, d)
@@ -332,6 +354,18 @@ function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region
332354 return move (task_processor (), result)
333355end
334356
357+ function load_boundary_region_into! (dest, :: LinearExtrapolate , arr:: AbstractArray{T} , region_code:: NTuple{N,Int} , neigh_dist, boundary_dims:: NTuple{N,Bool} ) where {T<: Real ,N}
358+ extrap_dim = 0
359+ for d in 1 : N
360+ if boundary_dims[d] && region_code[d] != 0
361+ extrap_dim = d
362+ break
363+ end
364+ end
365+ nd = get_neigh_dist (neigh_dist, extrap_dim)
366+ Kernel (load_boundary_region_kernel)(LinearExtrapolate (), dest, arr, region_code, neigh_dist, boundary_dims, Val (extrap_dim), Val (nd); ndrange= length (dest))
367+ end
368+
335369# Use edge as source index (value will be computed by apply_boundary_value)
336370boundary_source_index (:: LinearExtrapolate , arr, rc, nd, idx_d, d) =
337371 rc == - 1 ? firstindex (arr, d) : (rc == + 1 ? lastindex (arr, d) : idx_d)
@@ -434,6 +468,41 @@ function load_boundary_region(::Reflect{Symm}, arr, region_code::NTuple{N,Int},
434468 return region
435469end
436470
471+ function load_boundary_region_into! (dest, :: Reflect{Symm} , arr, region_code:: NTuple{N,Int} , neigh_dist, boundary_dims:: NTuple{N,Bool} ) where {N, Symm}
472+ flipped_code = ntuple (N) do i
473+ (region_code[i] != 0 && boundary_dims[i]) ? - region_code[i] : region_code[i]
474+ end
475+ skip = Symm ? 0 : 1
476+ start_idx = CartesianIndex (ntuple (N) do i
477+ needs_skip = boundary_dims[i] && region_code[i] != 0
478+ actual_skip = needs_skip ? skip : 0
479+ if flipped_code[i] == - 1
480+ lastindex (arr, i) - get_neigh_dist (neigh_dist, i) + 1 - actual_skip
481+ elseif flipped_code[i] == + 1
482+ firstindex (arr, i) + actual_skip
483+ else
484+ firstindex (arr, i)
485+ end
486+ end )
487+ stop_idx = CartesianIndex (ntuple (N) do i
488+ needs_skip = boundary_dims[i] && region_code[i] != 0
489+ actual_skip = needs_skip ? skip : 0
490+ if flipped_code[i] == + 1
491+ firstindex (arr, i) + get_neigh_dist (neigh_dist, i) - 1 + actual_skip
492+ elseif flipped_code[i] == - 1
493+ lastindex (arr, i) - actual_skip
494+ else
495+ lastindex (arr, i)
496+ end
497+ end )
498+ copyto! (dest, @view arr[start_idx: stop_idx])
499+ for i in 1 : N
500+ GPUArraysCore. @allowscalar if region_code[i] != 0 && boundary_dims[i]
501+ reverse! (dest, dims= i)
502+ end
503+ end
504+ end
505+
437506function boundary_source_index (:: Reflect{Symm} , arr, rc, nd, idx_d, d) where Symm
438507 skip = Symm ? 0 : 1
439508 if rc == - 1
@@ -564,6 +633,10 @@ function load_boundary_region(boundary::Tuple, arr, region_code::NTuple{N,Int},
564633 return move (task_processor (), result)
565634end
566635
636+ function load_boundary_region_into! (dest, boundary:: Tuple , arr, region_code:: NTuple{N,Int} , neigh_dist, boundary_dims:: NTuple{N,Bool} ) where N
637+ Kernel (load_boundary_region_kernel)(boundary, dest, arr, region_code, neigh_dist, boundary_dims; ndrange= length (dest))
638+ end
639+
567640# ############################################################################
568641# Chunk Selection and Halo Building
569642# ############################################################################
@@ -615,6 +688,99 @@ function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
615688 return accesses
616689end
617690
691+ # Returns (region_metadata, neighbor_chunk_dtasks) without spawning intermediate load tasks.
692+ # region_metadata: Vector of (region_code, is_boundary, boundary_dims).
693+ # neighbor_chunk_dtasks: Vector of raw chunk DTasks (resolved to arrays when build_halo_consolidated runs).
694+ function select_neighborhood_info (chunks, idx, neigh_dist, boundary)
695+ validate_neigh_dist (neigh_dist)
696+ N = ndims (chunks)
697+ chunk_dist = 1
698+ region_metadata = Tuple[]
699+ neighbor_chunks = Any[]
700+
701+ for i in 0 : (3 ^ N - 1 )
702+ region_code = ntuple (N) do d
703+ ((i ÷ 3 ^ (d- 1 )) % 3 ) - 1
704+ end
705+ all (== (0 ), region_code) && continue
706+
707+ chunk_offset = CartesianIndex (ntuple (N) do d
708+ region_code[d] * chunk_dist
709+ end )
710+ new_idx = idx + chunk_offset
711+
712+ if is_past_boundary (size (chunks), new_idx)
713+ boundary_dims = ntuple (N) do d
714+ new_idx[d] < 1 || new_idx[d] > size (chunks)[d]
715+ end
716+ if boundary_has_transition (boundary)
717+ new_idx = boundary_transition (boundary, new_idx, size (chunks))
718+ else
719+ new_idx = idx
720+ end
721+ push! (region_metadata, (region_code, true , boundary_dims))
722+ else
723+ push! (region_metadata, (region_code, false , ntuple (_ -> false , N)))
724+ end
725+ push! (neighbor_chunks, chunks[new_idx])
726+ end
727+
728+ @assert length (region_metadata) == 3 ^ N - 1
729+ return region_metadata, neighbor_chunks
730+ end
731+
732+ # Per-thread cache: IdDict{DArray, Dict{(chunk_idx, halo_width), HaloArray}}.
733+ # Using IdDict for the outer level ensures two DArrays with identical element types and
734+ # chunk shapes never share a buffer. Using chunk_idx as part of the inner key ensures that
735+ # within one DArray, every chunk has its own dedicated buffer — so if a single worker thread
736+ # processes multiple same-shaped chunks in the same iteration (sequentially), each gets a
737+ # distinct HaloArray and there is no aliasing with a concurrently running inner-stencil task.
738+ # Filling a cached buffer in-place is safe because spawn_datadeps blocks until all inner
739+ # tasks complete before the next iteration's build_halo_consolidated calls run.
740+ const HALO_ARRAY_CACHE = TaskLocalValue {IdDict{Any,Dict{Any,Any}}} (()-> IdDict {Any,Dict{Any,Any}} ())
741+
742+ # Consolidated halo builder: loads all neighbor regions directly into a HaloArray.
743+ # `read_darray` and `chunk_idx` are used solely for cache lookup — they are not DTask
744+ # arguments, so Dagger does not create extra data dependencies from them.
745+ # First call per (DArray, chunk_idx, halo_width) allocates and caches; subsequent calls
746+ # fill the cached HaloArray in-place — zero allocations on the hot path.
747+ function build_halo_consolidated (read_darray, chunk_idx, neigh_dist, boundary, center, region_metadata, neighbor_chunks... )
748+ N = ndims (center)
749+ expected_halos = length (region_metadata)
750+ @assert length (neighbor_chunks) == expected_halos
751+ validate_neigh_dist (neigh_dist, size (center))
752+ halo_width = ntuple (i -> get_neigh_dist (neigh_dist, i), N)
753+
754+ outer_cache = HALO_ARRAY_CACHE[]
755+ inner_cache = get! (outer_cache, read_darray) do ; Dict {Any,Any} (); end
756+ cache_key = (chunk_idx, halo_width)
757+
758+ if haskey (inner_cache, cache_key)
759+ halo = inner_cache[cache_key]
760+ copyto! (halo. center, center)
761+ for i in 1 : expected_halos
762+ region_code, is_boundary, boundary_dims = region_metadata[i]
763+ chunk = neighbor_chunks[i]
764+ if is_boundary
765+ load_boundary_region_into! (halo. halos[i], boundary, chunk, region_code, neigh_dist, boundary_dims)
766+ else
767+ load_neighbor_region_into! (halo. halos[i], chunk, region_code, neigh_dist)
768+ end
769+ end
770+ return halo
771+ else
772+ halos = ntuple (expected_halos) do i
773+ region_code, is_boundary, boundary_dims = region_metadata[i]
774+ chunk = neighbor_chunks[i]
775+ is_boundary ? load_boundary_region (boundary, chunk, region_code, neigh_dist, boundary_dims) :
776+ load_neighbor_region (chunk, region_code, neigh_dist)
777+ end
778+ halo = HaloArray (copy (center), halos, halo_width)
779+ inner_cache[cache_key] = halo
780+ return halo
781+ end
782+ end
783+
618784function build_halo (neigh_dist, boundary, center, all_halos... )
619785 N = ndims (center)
620786 expected_halos = 3 ^ N - 1
@@ -833,11 +999,12 @@ macro stencil(orig_ex)
833999 for read_var in read_vars
8341000 if read_var in keys (neighborhoods)
8351001 neigh_dist, boundary = neighborhoods[read_var]
836- @gensym halo_tasks
1002+ @gensym halo_tasks region_meta neighbor_cks
8371003 push! (final_ex. args, :($ halo_tasks = Array {$DTask} (undef, size ($ chunks ($ read_var)))))
8381004 push! (final_ex. args, quote
8391005 for $ chunk_idx in $ CartesianIndices ($ chunks ($ read_var))
840- $ halo_tasks[$ chunk_idx] = Dagger. @spawn name= " stencil_build_halo" $ build_halo ($ neigh_dist, $ boundary, $ select_neighborhood_chunks ($ chunks ($ read_var), $ chunk_idx, $ neigh_dist, $ boundary)... )
1006+ ($ region_meta, $ neighbor_cks) = $ select_neighborhood_info ($ chunks ($ read_var), $ chunk_idx, $ neigh_dist, $ boundary)
1007+ $ halo_tasks[$ chunk_idx] = Dagger. @spawn name= " stencil_build_halo" $ build_halo_consolidated ($ read_var, $ chunk_idx, $ neigh_dist, $ boundary, $ chunks ($ read_var)[$ chunk_idx], $ region_meta, $ neighbor_cks... )
8411008 end
8421009 end )
8431010 push! (final_ex. args, :($ halo_tasks_map[$ (QuoteNode (read_var))] = $ halo_tasks))
0 commit comments