Skip to content

Add distributed reductions #4497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions src/DistributedComputations/distributed_fields.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
using Oceananigans.Grids: topology
using Oceananigans.Fields: validate_field_data, indices, validate_boundary_conditions
using Oceananigans.Fields: validate_indices, set_to_array!, set_to_field!
using CUDA: @allowscalar

using Oceananigans.Fields: ReducedAbstractField,
get_neutral_mask,
condition_operand,
initialize_reduced_field!,
filltype,
reduced_dimensions,
reduced_location

import Oceananigans.Fields: Field, location, set!
import Oceananigans.BoundaryConditions: fill_halo_regions!
Expand Down Expand Up @@ -94,3 +103,97 @@ function reconstruct_global_field(field::DistributedField)

return global_field
end

"""
    partition_dimensions(arch::Distributed)
    partition_dimensions(f::DistributedField)

Return the partitioned dimensions of a distributed field or architecture.
"""
function partition_dimensions(arch::Distributed)
    R = ranks(arch)

    # A dimension is "partitioned" when it is split across more than one rank.
    # Use a concretely-typed accumulator (`Int[]`, not `[]`) so we never build
    # a `Vector{Any}`.
    dims = Int[]
    for r in eachindex(R)
        if R[r] > 1
            push!(dims, r)
        end
    end

    return tuple(dims...)
end

# A field is partitioned exactly as its underlying architecture is.
function partition_dimensions(f::DistributedField)
    return partition_dimensions(architecture(f))
end

# Communicate a reduced field across ranks only when necessary: if none of the
# reduced dimensions is partitioned, every rank already holds the full result
# of its local reduction and no MPI traffic is needed.
function maybe_all_reduce!(op, f::ReducedAbstractField)
    reduced_dims = reduced_dimensions(f)
    partition_dims = partition_dimensions(f)

    # Predicate form of `any` avoids materializing a temporary Vector{Bool}
    # (the original used `any([dim ∈ ... for dim in ...])`).
    if any(dim -> dim ∈ partition_dims, reduced_dims)
        all_reduce!(op, parent(f), architecture(f))
    end

    return f
end

# Allocating and in-place reductions.
#
# For each Base reduction we generate three methods specialized on
# `DistributedField`:
#   1. in-place with a mapping function, e.g. `sum!(f, r, a)`,
#   2. in-place without one, e.g. `sum!(r, a)`,
#   3. allocating, e.g. `sum(f, a; dims)`.
# Each method performs the local (per-rank) reduction first, then calls
# `maybe_all_reduce!` with the matching MPI operator (`sum` ↔ `+`,
# `maximum` ↔ `max`, `all` ↔ `&`, ...) to combine partial results across
# ranks when a reduced dimension is partitioned.
for (reduction, all_reduce_op) in zip((:sum, :maximum, :minimum, :all, :any, :prod),
                                      (:+, :max, :min, :&, :|, :*))

    # `:sum` -> `:sum!`, the in-place variant's name.
    reduction! = Symbol(reduction, '!')

    @eval begin
        # In-place, with a mapping function `f` applied element-wise before reducing.
        function Base.$(reduction!)(f::Function,
                                    r::ReducedAbstractField,
                                    a::DistributedField;
                                    condition = nothing,
                                    mask = get_neutral_mask(Base.$(reduction!)),
                                    kwargs...)

            # Wrap `a` so masked-out entries contribute the reduction's
            # neutral element.
            operand = condition_operand(f, a, condition, mask)

            # Local reduction over the interior (halo excluded) of `r`.
            Base.$(reduction!)(identity,
                               interior(r),
                               operand;
                               kwargs...)

            # Combine partial results across ranks if a reduced dimension
            # is partitioned.
            return maybe_all_reduce!($(all_reduce_op), r)
        end

        # In-place, without a mapping function.
        function Base.$(reduction!)(r::ReducedAbstractField,
                                    a::DistributedField;
                                    condition = nothing,
                                    mask = get_neutral_mask(Base.$(reduction!)),
                                    kwargs...)

            Base.$(reduction!)(identity,
                               interior(r),
                               condition_operand(a, condition, mask);
                               kwargs...)

            return maybe_all_reduce!($(all_reduce_op), r)
        end

        # Allocating variant: builds the reduced destination field, reduces
        # into it, and all-reduces across ranks.
        function Base.$(reduction)(f::Function,
                                   c::DistributedField;
                                   condition = nothing,
                                   mask = get_neutral_mask(Base.$(reduction!)),
                                   dims = :)

            conditioned_c = condition_operand(f, c, condition, mask)
            # Element type of the result, e.g. Bool for `all`/`any`.
            T = filltype(Base.$(reduction!), c)
            loc = reduced_location(location(c); dims)
            r = Field(loc, c.grid, T; indices=indices(c))
            # Seed `r` with the reduction's neutral value, then reduce without
            # re-initializing (`init=false`).
            initialize_reduced_field!(Base.$(reduction!), identity, r, conditioned_c)
            Base.$(reduction!)(identity, interior(r), conditioned_c, init=false)

            maybe_all_reduce!($(all_reduce_op), r)

            # Reducing over every dimension yields a scalar; `@allowscalar`
            # permits the single-element read on GPU backends.
            if dims isa Colon
                return @allowscalar first(r)
            else
                return r
            end
        end
    end
end
2 changes: 1 addition & 1 deletion src/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ for reduction in (:sum, :maximum, :minimum, :all, :any, :prod)
end
end

# Improve me! We should compute both extrema in a single reduction instead of two
function Base.extrema(c::AbstractField; kwargs...)
    lo = minimum(c; kwargs...)
    hi = maximum(c; kwargs...)
    return (lo, hi)
end

function Base.extrema(f, c::AbstractField; kwargs...)
    lo = minimum(f, c; kwargs...)
    hi = maximum(f, c; kwargs...)
    return (lo, hi)
end

Expand Down
35 changes: 33 additions & 2 deletions test/test_distributed_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ include_corners ? view(f.data, :, :, right_halo_indices(instantiate(LZ), instant
interior_indices(instantiate(LY), instantiate(topology(f, 2)), f.grid.Ny),
right_halo_indices(instantiate(LZ), instantiate(topology(f, 3)), f.grid.Nz, f.grid.Hz))


function southwest_halo(f::AbstractField)
Nx, Ny, _ = size(f.grid)
Hx, Hy, _ = halo_size(f.grid)
Expand Down Expand Up @@ -409,7 +408,6 @@ end

@testset "Distributed MPI Oceananigans" begin
@info "Testing distributed MPI Oceananigans..."

@testset "Multi architectures rank connectivity" begin
@info " Testing multi architecture rank connectivity..."
test_triply_periodic_rank_connectivity_with_411_ranks()
Expand Down Expand Up @@ -464,6 +462,39 @@ end
end
end

@testset "Distributed reductions" begin
    child_arch = get(ENV, "TEST_ARCHITECTURE", "CPU") == "GPU" ? GPU() : CPU()

    for partition in [Partition(1, 4), Partition(2, 2), Partition(4, 1)]
        # Fixed copy-pasted message: this testset exercises reductions, not
        # time stepping of a NonhydrostaticModel.
        @info "Testing distributed reductions with partition $partition..."
        arch = Distributed(child_arch; partition)
        grid = RectilinearGrid(arch, topology=(Periodic, Periodic, Periodic), size=(8, 8, 1), extent=(1, 2, 3))
        c = CenterField(grid)

        # Each rank fills its local slab with `local_rank + 1`, so the global
        # sum over 4 ranks is (1 + 2 + 3 + 4) * N where N is the local size.
        set!(c, arch.local_rank + 1)

        c_reduced = Field{Nothing, Nothing, Nothing}(grid)

        N = grid.Nx * grid.Ny # local rank grid size
        @test sum(c) == 1*N + 2*N + 3*N + 4*N

        sum!(c_reduced, c)
        @test CUDA.@allowscalar c_reduced[1, 1, 1] == 1*N + 2*N + 3*N + 4*N

        cbool = CenterField(grid, Bool)
        cbool_reduced = Field{Nothing, Nothing, Nothing}(grid, Bool)

        # The comparison already yields a Bool; no `? true : false` needed.
        # Only rank 0 holds `true`, so `any` is true globally and `all` false.
        bool_val = arch.local_rank == 0
        set!(cbool, bool_val)

        @test any(cbool)
        @test !all(cbool)

        any!(cbool_reduced, cbool)
        @test CUDA.@allowscalar cbool_reduced[1, 1, 1] == true

        all!(cbool_reduced, cbool)
        @test CUDA.@allowscalar cbool_reduced[1, 1, 1] == false
    end
end

# Only test on CPU because we do not have a GPU pressure solver yet
@testset "Time stepping NonhydrostaticModel" begin
Expand Down