Skip to content

Commit 7a1416b

Browse files
committed
DArray/stencil: Reduce memory allocations
1 parent 0af8688 commit 7a1416b

7 files changed

Lines changed: 161 additions & 25 deletions

File tree

ext/CUDAExt.jl

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -370,11 +370,13 @@ CuArray(H::Dagger.HaloArray) = convert(CuArray, H)
370370
Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:CuArray} =
371371
Dagger.HaloArray(C(H.center),
372372
C.(H.halos),
373-
H.halo_width)
373+
H.halo_width;
374+
own_center=H.own_center)
374375
Adapt.adapt_structure(to::CUDA.KernelAdaptor, H::Dagger.HaloArray) =
375376
Dagger.HaloArray(adapt(to, H.center),
376377
adapt.(Ref(to), H.halos),
377-
H.halo_width)
378+
H.halo_width;
379+
own_center=H.own_center)
378380
function Dagger.inner_stencil_proc!(::CuArrayDeviceProc, f, output, read_vars)
379381
Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output))
380382
return

ext/IntelExt.jl

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -322,11 +322,13 @@ oneArray(H::Dagger.HaloArray) = convert(oneArray, H)
322322
Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:oneArray} =
323323
Dagger.HaloArray(C(H.center),
324324
C.(H.halos),
325-
H.halo_width)
325+
H.halo_width;
326+
own_center=H.own_center)
326327
Adapt.adapt_structure(to::oneAPI.KernelAdaptor, H::Dagger.HaloArray) =
327328
Dagger.HaloArray(adapt(to, H.center),
328329
adapt.(Ref(to), H.halos),
329-
H.halo_width)
330+
H.halo_width;
331+
own_center=H.own_center)
330332
function Dagger.inner_stencil_proc!(::oneArrayDeviceProc, f, output, read_vars)
331333
Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output))
332334
return

ext/MetalExt.jl

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -346,11 +346,13 @@ MtlArray(H::Dagger.HaloArray) = convert(MtlArray, H)
346346
Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:MtlArray} =
347347
Dagger.HaloArray(C(H.center),
348348
C.(H.halos),
349-
H.halo_width)
349+
H.halo_width;
350+
own_center=H.own_center)
350351
Adapt.adapt_structure(to::Metal.Adaptor, H::Dagger.HaloArray) =
351352
Dagger.HaloArray(adapt(to, H.center),
352353
adapt.(Ref(to), H.halos),
353-
H.halo_width)
354+
H.halo_width;
355+
own_center=H.own_center)
354356
function Dagger.inner_stencil_proc!(::MtlArrayDeviceProc, f, output, read_vars)
355357
Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output))
356358
return

ext/OpenCLExt.jl

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -320,11 +320,13 @@ CLArray(H::Dagger.HaloArray) = convert(CLArray, H)
320320
Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:CLArray} =
321321
Dagger.HaloArray(C(H.center),
322322
C.(H.halos),
323-
H.halo_width)
323+
H.halo_width;
324+
own_center=H.own_center)
324325
Adapt.adapt_structure(to::OpenCL.KernelAdaptor, H::Dagger.HaloArray) =
325326
Dagger.HaloArray(adapt(to, H.center),
326327
adapt.(Ref(to), H.halos),
327-
H.halo_width)
328+
H.halo_width;
329+
own_center=H.own_center)
328330
function Dagger.inner_stencil_proc!(::CLArrayDeviceProc, f, output, read_vars)
329331
Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output))
330332
return

ext/ROCExt.jl

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -343,11 +343,13 @@ ROCArray(H::Dagger.HaloArray) = convert(ROCArray, H)
343343
Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:ROCArray} =
344344
Dagger.HaloArray(C(H.center),
345345
C.(H.halos),
346-
H.halo_width)
346+
H.halo_width;
347+
own_center=H.own_center)
347348
Adapt.adapt_structure(to::AMDGPU.Runtime.Adaptor, H::Dagger.HaloArray) =
348349
Dagger.HaloArray(adapt(to, H.center),
349350
adapt.(Ref(to), H.halos),
350-
H.halo_width)
351+
H.halo_width;
352+
own_center=H.own_center)
351353
function Dagger.inner_stencil_proc!(::ROCArrayDeviceProc, f, output, read_vars)
352354
Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output))
353355
return

src/array/stencil.jl

Lines changed: 127 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -54,8 +54,7 @@ function load_neighbor_region(arr, region_code::NTuple{N,Int}, neigh_dist) where
5454
lastindex(arr, i)
5555
end
5656
end)
57-
# FIXME: Don't collect
58-
return move(task_processor(), collect(@view arr[start_idx:stop_idx]))
57+
return move(task_processor(), copy(@view arr[start_idx:stop_idx]))
5958
end
6059

6160
# In-place variant: load region directly into a pre-allocated destination buffer.
@@ -176,8 +175,9 @@ function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_d
176175
region_size = ntuple(N) do i
177176
region_code[i] == 0 ? size(arr, i) : get_neigh_dist(neigh_dist, i)
178177
end
179-
# FIXME: return Fill(pad.padval, region_size)
180-
return move(task_processor(), fill(pad.padval, region_size))
178+
result = similar(arr, region_size...)
179+
fill!(result, pad.padval)
180+
return move(task_processor(), result)
181181
end
182182

183183
load_boundary_region_into!(dest, pad::Pad, arr, region_code, neigh_dist, boundary_dims) =
@@ -462,7 +462,7 @@ function load_boundary_region(::Reflect{Symm}, arr, region_code::NTuple{N,Int},
462462
end
463463
end)
464464

465-
region = move(task_processor(), collect(@view arr[start_idx:stop_idx]))
465+
region = move(task_processor(), copy(@view arr[start_idx:stop_idx]))
466466

467467
# Reverse only along dimensions that are actually being reflected
468468
# (both non-zero in region_code AND past boundary)
@@ -649,6 +649,123 @@ end
649649
# Chunk Selection and Halo Building
650650
#############################################################################
651651

652+
"""
    load_neighborhood_halos(chunks, idx, neigh_dist, boundary)

Load the `3^N - 1` regions surrounding `chunks[idx]`, where `N = ndims(chunks)`, and
return them as a tuple.

Region codes are tuples with entries in `-1:1`, enumerated with the first dimension
varying fastest; the all-zero code (the center itself) is skipped. For each code the
neighboring index `idx + CartesianIndex(region_code)` is formed:

- If `is_past_boundary(size(chunks), neighbor)` holds, the region comes from
  `load_boundary_region`. The source chunk is the one selected by `boundary_transition`
  when `boundary_has_transition(boundary)` is true, and `chunks[idx]` otherwise.
- Otherwise the region comes from `load_neighbor_region` applied to the neighboring
  chunk.

`neigh_dist` is checked with `validate_neigh_dist` before any region is loaded.
"""
function load_neighborhood_halos(chunks, idx, neigh_dist, boundary)
    validate_neigh_dist(neigh_dist)

    N = ndims(chunks)
    grid_size = size(chunks)
    chunk_dist = 1  # only directly adjacent chunks are visited
    nhalos = 3^N - 1
    halos = Any[]

    for code in 0:(3^N - 1)
        # Decode `code` as N base-3 digits, shifted from 0:2 to -1:1.
        region_code = ntuple(d -> (code ÷ 3^(d - 1)) % 3 - 1, N)
        if all(iszero, region_code)
            continue  # the center is not a halo
        end

        neighbor_idx = idx + CartesianIndex(map(c -> c * chunk_dist, region_code))

        region = if is_past_boundary(grid_size, neighbor_idx)
            # Record which dimensions fall outside the grid *before* the index is
            # remapped by the boundary condition.
            boundary_dims = ntuple(N) do d
                neighbor_idx[d] < 1 || neighbor_idx[d] > grid_size[d]
            end
            source_idx = if boundary_has_transition(boundary)
                boundary_transition(boundary, neighbor_idx, grid_size)
            else
                idx
            end
            load_boundary_region(
                boundary, chunks[source_idx], region_code, neigh_dist, boundary_dims
            )
        else
            load_neighbor_region(chunks[neighbor_idx], region_code, neigh_dist)
        end
        push!(halos, region)
    end

    @assert length(halos) == nhalos
    return Tuple(halos)
end
693+
694+
"""
    load_neighborhood_halos_from_deps(deps, idx, chunk_size, neigh_dist, boundary)

Build the tuple of `3^N - 1` halo regions from already-selected source chunks, where
`N = length(chunk_size)`.

`deps[1]` is not read here; the source for the `h`-th non-center region code is
`deps[h + 1]`. Region codes have entries in `-1:1` and are enumerated with the first
dimension varying fastest, skipping the all-zero code. For each code,
`is_past_boundary(chunk_size, idx + CartesianIndex(region_code))` decides whether the
region is produced by `load_boundary_region` or by `load_neighbor_region`.

NOTE(review): `chunk_size` is passed to `is_past_boundary` and compared against
neighbor indices, so it is presumably the extent of the chunk grid rather than the size
of one chunk — confirm against the caller.
"""
function load_neighborhood_halos_from_deps(deps, idx, chunk_size, neigh_dist, boundary)
    validate_neigh_dist(neigh_dist)

    N = length(chunk_size)
    chunk_dist = 1  # only directly adjacent chunks are visited
    nhalos = 3^N - 1
    halos = Any[]
    dep_pos = 1  # position 1 is skipped; neighbor sources start at position 2

    for code in 0:(3^N - 1)
        # Decode `code` as N base-3 digits, shifted from 0:2 to -1:1.
        region_code = ntuple(d -> (code ÷ 3^(d - 1)) % 3 - 1, N)
        if all(iszero, region_code)
            continue  # the center is not a halo
        end
        dep_pos += 1

        neighbor_idx = idx + CartesianIndex(map(c -> c * chunk_dist, region_code))
        chunk = deps[dep_pos]

        region = if is_past_boundary(chunk_size, neighbor_idx)
            boundary_dims = ntuple(N) do d
                neighbor_idx[d] < 1 || neighbor_idx[d] > chunk_size[d]
            end
            load_boundary_region(boundary, chunk, region_code, neigh_dist, boundary_dims)
        else
            load_neighbor_region(chunk, region_code, neigh_dist)
        end
        push!(halos, region)
    end

    @assert length(halos) == nhalos
    return Tuple(halos)
end
729+
730+
"""
    select_neighborhood_chunk_deps(chunks, idx, neigh_dist, boundary)

Return a `Vector{Any}` of length `3^N` (`N = ndims(chunks)`) holding the source chunks
for the neighborhood of `idx`.

The first entry is `chunks[idx]`. The remaining entries follow the non-center region
codes (entries in `-1:1`, first dimension varying fastest). For a neighbor index that
`is_past_boundary` reports as outside `size(chunks)`, the source is the chunk chosen by
`boundary_transition` when `boundary_has_transition(boundary)` is true, and
`chunks[idx]` otherwise.

`neigh_dist` is only checked with `validate_neigh_dist`; it does not influence which
chunks are selected.
"""
function select_neighborhood_chunk_deps(chunks, idx, neigh_dist, boundary)
    validate_neigh_dist(neigh_dist)

    N = ndims(chunks)
    grid_size = size(chunks)
    chunk_dist = 1  # only directly adjacent chunks are visited

    accesses = Vector{Any}(undef, 3^N)
    accesses[1] = chunks[idx]
    filled = 1

    for code in 0:(3^N - 1)
        # Decode `code` as N base-3 digits, shifted from 0:2 to -1:1.
        region_code = ntuple(d -> (code ÷ 3^(d - 1)) % 3 - 1, N)
        if all(iszero, region_code)
            continue  # the center is already stored in the first slot
        end

        neighbor_idx = idx + CartesianIndex(map(c -> c * chunk_dist, region_code))

        source_idx = if !is_past_boundary(grid_size, neighbor_idx)
            neighbor_idx
        elseif boundary_has_transition(boundary)
            boundary_transition(boundary, neighbor_idx, grid_size)
        else
            idx
        end

        filled += 1
        accesses[filled] = chunks[source_idx]
    end

    @assert filled == 3^N
    return accesses
end
762+
763+
"""
    build_chunk_halo(neigh_dist, boundary, idx, chunk_size, own_center::Bool, read_deps...)

Assemble the halo-wrapped array for the chunk at `idx` from `read_deps`.

`read_deps[1]` is used as the center. All of `read_deps` is handed to
`load_neighborhood_halos_from_deps` to produce the surrounding regions, and the result
is combined by `build_halo`, with `own_center` forwarded as a keyword argument.
"""
function build_chunk_halo(neigh_dist, boundary, idx, chunk_size, own_center::Bool, read_deps...)
    center_chunk = read_deps[begin]
    neighbor_regions = load_neighborhood_halos_from_deps(
        read_deps, idx, chunk_size, neigh_dist, boundary
    )
    return build_halo(neigh_dist, boundary, center_chunk, neighbor_regions...; own_center)
end
768+
652769
function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
653770
validate_neigh_dist(neigh_dist)
654771

@@ -698,7 +815,7 @@ end
698815

699816
# Returns (region_metadata, neighbor_chunk_dtasks) without spawning intermediate load tasks.
700817
# region_metadata: Vector of (region_code, is_boundary, boundary_dims).
701-
# neighbor_chunk_dtasks: Vector of raw chunk DTasks (resolved to arrays when build_halo_consolidated runs).
818+
# neighbor_chunk_dtasks: Vector of raw chunk DTasks (resolved to arrays when build_halo_new runs).
702819
function select_neighborhood_info(chunks, idx, neigh_dist, boundary)
703820
validate_neigh_dist(neigh_dist)
704821
N = ndims(chunks)
@@ -782,7 +899,7 @@ function build_halo_new(neigh_dist, boundary, center, region_metadata, neighbor_
782899
is_boundary ? load_boundary_region(boundary, chunk, region_code, neigh_dist, boundary_dims) :
783900
load_neighbor_region(chunk, region_code, neigh_dist)
784901
end
785-
return HaloArray(copy(center), halos, halo_width)
902+
return HaloArray(copy(center), halos, halo_width; own_center=true)
786903
end
787904

788905
# Cache-hit path: fill an existing HaloArray in-place and return it. No cache operations —
@@ -803,11 +920,12 @@ function fill_halo_inplace!(halo::HaloArray, neigh_dist, boundary, center, regio
803920
return halo
804921
end
805922

806-
function build_halo(neigh_dist, boundary, center, all_halos...)
923+
function build_halo(neigh_dist, boundary, center, all_halos...; own_center::Bool=false)
807924
N = ndims(center)
808925
expected_halos = 3^N - 1
809926
@assert length(all_halos) == expected_halos "Halo mismatch: N=$N expected $expected_halos halos, got $(length(all_halos))"
810-
return HaloArray(copy(center), (all_halos...,), ntuple(i->get_neigh_dist(neigh_dist, i), N))
927+
center_data = own_center ? copy(center) : center
928+
return HaloArray(center_data, (all_halos...,), ntuple(i->get_neigh_dist(neigh_dist, i), N); own_center)
811929
end
812930

813931
function load_neighborhood(arr::HaloArray{T,N}, idx) where {T,N}

src/utils/haloarray.jl

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,12 @@ struct HaloArray{T,N,A<:AbstractArray{T,N},H<:Tuple} <: AbstractArray{T,N}
99
center::A
1010
halos::H # Tuple of 3^N - 1 arrays in canonical order
1111
halo_width::NTuple{N,Int}
12+
own_center::Bool
13+
end
14+
15+
"""
    HaloArray(center, halos::Tuple, halo_width::NTuple{N,Int}; own_center::Bool=false)

Construct a `HaloArray` whose type parameters are derived from the arguments: the
element type is `eltype(center)`, the dimensionality is `N`, and the array and halo
parameters are `typeof(center)` and `typeof(halos)`.

`own_center` is stored in the `own_center` field and defaults to `false`.
"""
function HaloArray(
    center, halos::Tuple, halo_width::NTuple{N,Int}; own_center::Bool=false
) where {N}
    A = typeof(center)
    H = typeof(halos)
    return HaloArray{eltype(center),N,A,H}(center, halos, halo_width, own_center)
end
1319

1420
# Number of halo regions for N dimensions
@@ -63,7 +69,7 @@ function HaloArray{T,N}(center_size::NTuple{N,Int}, halo_width::NTuple{N,Int}) w
6369
Array{T,N}(undef, region_size...)
6470
end
6571

66-
return HaloArray{T,N,typeof(center),typeof(halos)}(center, halos, halo_width)
72+
return HaloArray(center, halos, halo_width; own_center=true)
6773
end
6874

6975
Base.size(tile::HaloArray) = size(tile.center) .+ 2 .* tile.halo_width
@@ -83,7 +89,7 @@ function Base.copy(tile::HaloArray{T,N}) where {T,N}
8389
center = copy(tile.center)
8490
halos = ntuple(i -> copy(tile.halos[i]), length(tile.halos))
8591
halo_width = tile.halo_width
86-
return HaloArray(center, halos, halo_width)
92+
return HaloArray(center, halos, halo_width; own_center=true)
8793
end
8894

8995
# Compute the region code for a given index
@@ -182,7 +188,8 @@ end
182188
Adapt.adapt_structure(to, H::Dagger.HaloArray) =
183189
HaloArray(Adapt.adapt(to, H.center),
184190
Adapt.adapt.(Ref(to), H.halos),
185-
H.halo_width)
191+
H.halo_width;
192+
own_center=H.own_center)
186193

187194
function aliasing(A::HaloArray)
188195
return CombinedAliasing([aliasing(A.center), map(aliasing, A.halos)...])
@@ -193,16 +200,17 @@ function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::P
193200
center_chunk = move_rewrap(cache, from_proc, to_proc, from_space, to_space, A.center)
194201
halo_chunks = ntuple(i -> move_rewrap(cache, from_proc, to_proc, from_space, to_space, A.halos[i]), length(A.halos))
195202
halo_width = A.halo_width
203+
own_center = A.own_center
196204
to_w = root_worker_id(to_proc)
197-
return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, center_chunk, halo_chunks, halo_width) do from_proc, to_proc, from_space, to_space, center_chunk, halo_chunks, halo_width
205+
return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, center_chunk, halo_chunks, halo_width, own_center) do from_proc, to_proc, from_space, to_space, center_chunk, halo_chunks, halo_width, own_center
198206
center_new = unwrap(center_chunk)
199207
halos_new = ntuple(i -> unwrap(halo_chunks[i]), length(halo_chunks))
200-
return tochunk(HaloArray(center_new, halos_new, halo_width), to_proc)
208+
return tochunk(HaloArray(center_new, halos_new, halo_width; own_center=own_center), to_proc)
201209
end
202210
end
203211

204212
function Dagger.unsafe_free!(A::HaloArray)
205-
unsafe_free!(A.center)
213+
A.own_center && unsafe_free!(A.center)
206214
foreach(unsafe_free!, A.halos)
207215
end
208216

0 commit comments

Comments
 (0)