Skip to content

Commit b2426e0

Browse files
glwagner, simone-silvestri, giordano
authored
Pad Center fields on BoundedTopology so that all Fields have identical horizontal size (#4144)
* Pad Center fields on BoundedTopology to even out all fields * Change abstraction * Better hack * sharded grid * add sharding length * no need for this * Update ext/OceananigansReactantExt/Grids/sharded_grids.jl * using instantiate * remove the show * Update src/Grids/grid_utils.jl * Extend total_size * Update src/Grids/new_data.jl * sharding_total_length -> reactant_total_length (#4329) * Fix tests now that padding differs between reactant and vanilla --------- Co-authored-by: Simone Silvestri <[email protected]> Co-authored-by: Mosè Giordano <[email protected]>
1 parent fe5f37b commit b2426e0

File tree

5 files changed

+111
-72
lines changed

5 files changed

+111
-72
lines changed

ext/OceananigansReactantExt/Grids/Grids.jl

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@ module Grids
33
export constant_with_arch
44

55
using Reactant
6+
using OffsetArrays
67

78
using Oceananigans
89
using Oceananigans: Distributed
910
using Oceananigans.Architectures: ReactantState, CPU
1011
using Oceananigans.Grids: AbstractGrid, AbstractUnderlyingGrid, StaticVerticalDiscretization, MutableVerticalDiscretization
11-
using Oceananigans.Grids: Center, Face, RightConnected, LeftConnected, Periodic, Bounded, Flat
12+
using Oceananigans.Grids: Center, Face, RightConnected, LeftConnected, Periodic, Bounded, Flat, BoundedTopology
1213
using Oceananigans.Fields: Field
1314
using Oceananigans.ImmersedBoundaries: GridFittedBottom, AbstractImmersedBoundary
1415

1516
import ..OceananigansReactantExt: deconcretize
1617
import Oceananigans.Grids: LatitudeLongitudeGrid, RectilinearGrid, OrthogonalSphericalShellGrid
18+
import Oceananigans.Grids: total_length, offset_data
1719
import Oceananigans.OrthogonalSphericalShellGrids: RotatedLatitudeLongitudeGrid, TripolarGrid
1820
import Oceananigans.ImmersedBoundaries: ImmersedBoundaryGrid, materialize_immersed_boundary
1921

@@ -39,5 +41,53 @@ const ShardedGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ShardedDistri
3941
include("serial_grids.jl")
4042
include("sharded_grids.jl")
4143

44+
function total_size(grid::ReactantGrid, loc, indices)
45+
sz = size(grid)
46+
halo_sz = halo_size(grid)
47+
topo = topology(grid)
48+
return reactant_total_size(loc, topo, sz, halo_sz, indices)
49+
end
50+
51+
function reactant_total_size(loc, topo, sz, halo_sz, indices=default_indices(Val(length(loc))))
52+
D = length(loc)
53+
return Tuple(reactant_total_length(instantiate(loc[d]), instantiate(topo[d]), sz[d], halo_sz[d], indices[d]) for d = 1:D)
54+
end
55+
56+
reactant_total_length(loc, topo, N, H, ::Colon) = reactant_total_length(loc, topo, N, H)
57+
reactant_total_length(loc, topo, N, H, ind::AbstractUnitRange) = min(reactant_total_length(loc, topo, N, H), length(ind))
58+
reactant_total_length(loc, topo, N, H) = Oceananigans.Grids.total_length(loc, topo, N, H)
59+
reactant_total_length(::Face, ::BoundedTopology, N, H=0) = N + 2H
60+
61+
reactant_offset_indices(loc, topo, N, H=0) = 1 - H : N + H
62+
reactant_offset_indices(::Nothing, topo, N, H=0) = 1:1
63+
reactant_offset_indices(ℓ, topo, N, H, ::Colon) = reactant_offset_indices(ℓ, topo, N, H)
64+
reactant_offset_indices(ℓ, topo, N, H, r::AbstractUnitRange) = r
65+
reactant_offset_indices(::Nothing, topo, N, H, ::AbstractUnitRange) = 1:1
66+
67+
function Oceananigans.Grids.new_data(FT::DataType, arch::Union{ReactantState, ShardedDistributed},
68+
loc, topo, sz, halo_sz, indices=default_indices(length(loc)))
69+
70+
Tsz = reactant_total_size(loc, topo, sz, halo_sz, indices)
71+
underlying_data = zeros(arch, FT, Tsz...)
72+
indices = validate_indices(indices, loc, topo, sz, halo_sz)
73+
74+
return offset_data(underlying_data, loc, topo, sz, halo_sz, indices)
75+
end
76+
77+
# The type parameter for indices helps / encourages the compiler to fully type infer `offset_data`
78+
function offset_data(underlying_data::ConcreteRArray, loc, topo, N, H, indices::T=default_indices(length(loc))) where T
79+
loc = map(instantiate, loc)
80+
topo = map(instantiate, topo)
81+
ii = map(reactant_offset_indices, loc, topo, N, H, indices)
82+
# Add extra indices for arrays of higher dimension than loc, topo, etc.
83+
# Use the "`ntuple` trick" so the compiler can infer the type of `extra_ii`
84+
extra_ii = ntuple(Val(ndims(underlying_data)-length(ii))) do i
85+
Base.@_inline_meta
86+
axes(underlying_data, i+length(ii))
87+
end
88+
89+
return OffsetArray(underlying_data, ii..., extra_ii...)
90+
end
91+
4292
end # module
4393

ext/OceananigansReactantExt/Grids/sharded_grids.jl

Lines changed: 23 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ using Oceananigans.Architectures: architecture
22
using Oceananigans.Grids: AbstractGrid
33
using Oceananigans.OrthogonalSphericalShellGrids
44
using Oceananigans.Grids: R_Earth, validate_lat_lon_grid_args, generate_coordinate, with_precomputed_metrics, validate_rectilinear_grid_args
5+
using Oceananigans.Grids: default_indices, validate_indices, offset_data, instantiate, halo_size, topology
56

6-
import Oceananigans.Grids: zeros, StaticVerticalDiscretization
7+
import Oceananigans.Grids: zeros, StaticVerticalDiscretization, total_size
78
import Oceananigans.Architectures: child_architecture
89

910
import Oceananigans.DistributedComputations:
@@ -214,64 +215,38 @@ function TripolarGrid(arch::ShardedDistributed,
214215

215216
# Needed for partial array assembly
216217
# device_to_array_slices = Reactant.sharding_to_array_slices(sharding, global_size)
217-
irange = Colon()
218-
jrange = Colon()
219218
FT = eltype(global_grid)
220219

221-
# Partitioning the Coordinates
222-
λᶠᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :λᶠᶠᵃ, irange, jrange)
223-
φᶠᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :φᶠᶠᵃ, irange, jrange)
224-
λᶠᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :λᶠᶜᵃ, irange, jrange)
225-
φᶠᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :φᶠᶜᵃ, irange, jrange)
226-
λᶜᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :λᶜᶠᵃ, irange, jrange)
227-
φᶜᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :φᶜᶠᵃ, irange, jrange)
228-
λᶜᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :λᶜᶜᵃ, irange, jrange)
229-
φᶜᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :φᶜᶜᵃ, irange, jrange)
230-
231-
# # Partitioning the Metrics
232-
Δxᶜᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Δxᶜᶜᵃ, irange, jrange)
233-
Δxᶠᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Δxᶠᶜᵃ, irange, jrange)
234-
Δxᶜᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Δxᶜᶠᵃ, irange, jrange)
235-
Δxᶠᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Δxᶠᶠᵃ, irange, jrange)
236-
Δyᶜᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Δyᶜᶜᵃ, irange, jrange)
237-
Δyᶠᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Δyᶠᶜᵃ, irange, jrange)
238-
Δyᶜᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Δyᶜᶠᵃ, irange, jrange)
239-
Δyᶠᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Δyᶠᶠᵃ, irange, jrange)
240-
Azᶜᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Azᶜᶜᵃ, irange, jrange)
241-
Azᶠᶜᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Azᶠᶜᵃ, irange, jrange)
242-
Azᶜᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Azᶜᶠᵃ, irange, jrange)
243-
Azᶠᶠᵃ = OrthogonalSphericalShellGrids.partition_tripolar_metric(global_grid, :Azᶠᶠᵃ, irange, jrange)
244-
245220
# Copying the z coordinate to all the devices: we pass a NamedSharding of `nothing`s
246221
# (a NamedSharding of nothings represents a copy to all devices)
247222
# ``1'' here is the maximum number of dimensions of the fields of ``z''
248223
replicate = Sharding.NamedSharding(arch.connectivity, ntuple(Returns(nothing), 1))
249224

250-
grid = OrthogonalSphericalShellGrid{Periodic,RightConnected,Bounded}(arch,
225+
grid = OrthogonalSphericalShellGrid{Periodic, RightConnected, Bounded}(arch,
251226
global_size...,
252227
halo...,
253228
convert(FT, global_grid.Lz),
254-
Reactant.to_rarray(λᶜᶜᵃ; sharding),
255-
Reactant.to_rarray(λᶠᶜᵃ; sharding),
256-
Reactant.to_rarray(λᶜᶠᵃ; sharding),
257-
Reactant.to_rarray(λᶠᶠᵃ; sharding),
258-
Reactant.to_rarray(φᶜᶜᵃ; sharding),
259-
Reactant.to_rarray(φᶠᶜᵃ; sharding),
260-
Reactant.to_rarray(φᶜᶠᵃ; sharding),
261-
Reactant.to_rarray(φᶠᶠᵃ; sharding),
229+
Reactant.to_rarray(global_grid.λᶜᶜᵃ; sharding),
230+
Reactant.to_rarray(global_grid.λᶠᶜᵃ; sharding),
231+
Reactant.to_rarray(global_grid.λᶜᶠᵃ; sharding),
232+
Reactant.to_rarray(global_grid.λᶠᶠᵃ; sharding),
233+
Reactant.to_rarray(global_grid.φᶜᶜᵃ; sharding),
234+
Reactant.to_rarray(global_grid.φᶠᶜᵃ; sharding),
235+
Reactant.to_rarray(global_grid.φᶜᶠᵃ; sharding),
236+
Reactant.to_rarray(global_grid.φᶠᶠᵃ; sharding),
262237
sharded_z_direction(global_grid.z; sharding=replicate), # Replicating on all devices
263-
Reactant.to_rarray(Δxᶜᶜᵃ; sharding),
264-
Reactant.to_rarray(Δxᶠᶜᵃ; sharding),
265-
Reactant.to_rarray(Δxᶜᶠᵃ; sharding),
266-
Reactant.to_rarray(Δxᶠᶠᵃ; sharding),
267-
Reactant.to_rarray(Δyᶜᶜᵃ; sharding),
268-
Reactant.to_rarray(Δyᶠᶜᵃ; sharding),
269-
Reactant.to_rarray(Δyᶜᶠᵃ; sharding),
270-
Reactant.to_rarray(Δyᶠᶠᵃ; sharding),
271-
Reactant.to_rarray(Azᶜᶜᵃ; sharding),
272-
Reactant.to_rarray(Azᶠᶜᵃ; sharding),
273-
Reactant.to_rarray(Azᶜᶠᵃ; sharding),
274-
Reactant.to_rarray(Azᶠᶠᵃ; sharding),
238+
Reactant.to_rarray(global_grid.Δxᶜᶜᵃ; sharding),
239+
Reactant.to_rarray(global_grid.Δxᶠᶜᵃ; sharding),
240+
Reactant.to_rarray(global_grid.Δxᶜᶠᵃ; sharding),
241+
Reactant.to_rarray(global_grid.Δxᶠᶠᵃ; sharding),
242+
Reactant.to_rarray(global_grid.Δyᶜᶜᵃ; sharding),
243+
Reactant.to_rarray(global_grid.Δyᶠᶜᵃ; sharding),
244+
Reactant.to_rarray(global_grid.Δyᶜᶠᵃ; sharding),
245+
Reactant.to_rarray(global_grid.Δyᶠᶠᵃ; sharding),
246+
Reactant.to_rarray(global_grid.Azᶜᶜᵃ; sharding),
247+
Reactant.to_rarray(global_grid.Azᶠᶜᵃ; sharding),
248+
Reactant.to_rarray(global_grid.Azᶜᶠᵃ; sharding),
249+
Reactant.to_rarray(global_grid.Azᶠᶠᵃ; sharding),
275250
convert(FT, global_grid.radius),
276251
global_grid.conformal_mapping)
277252

src/Grids/abstract_grid.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ const XYZFlatGrid = AbstractGrid{<:Any, Flat, Flat, Flat}
4242
isrectilinear(grid) = false
4343

4444
# Fallback
45-
@inline get_active_column_map(::AbstractGrid) = nothing
45+
@inline get_active_column_map(::AbstractGrid) = nothing
4646
@inline get_active_cells_map(::AbstractGrid, any_map_type) = nothing
4747

4848
"""

src/Grids/vertical_discretization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ MutableVerticalDiscretization(r_faces) = MutableVerticalDiscretization(r_faces,
5858
####
5959

6060
const RegularStaticVerticalDiscretization = StaticVerticalDiscretization{<:Any, <:Any, <:Number}
61-
const RegularMutableVerticalDiscretization = MutableVerticalDiscretization{<:Any, <:Any, <:Number}
61+
const RegularMutableVerticalDiscretization = MutableVerticalDiscretization{<:Any, <:Any, <:Number}
6262

6363
const RegularVerticalCoordinate = Union{RegularStaticVerticalDiscretization, RegularMutableVerticalDiscretization}
6464

test/reactant_test_utils.jl

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,29 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw;
6060
u, v, w = model.velocities
6161
ru, rv, rw = r_model.velocities
6262

63-
# They will not be equal because r_model halos are not
64-
# filled during set!
65-
@test !(parent(u) ≈ parent(ru))
66-
@test !(parent(v) ≈ parent(rv))
67-
@test !(parent(w) ≈ parent(rw))
68-
63+
# Note that r_model halos are not filled during set!
64+
# It's complicated to test this currently because the halo
65+
# regions have different paddings, so we don't do it.
66+
6967
Oceananigans.TimeSteppers.update_state!(r_model)
7068

7169
# Test that fields were set correctly
7270
@info " After setting an initial condition:"
73-
@show maximum(abs.(parent(u) .- parent(ru)))
74-
@show maximum(abs.(parent(v) .- parent(rv)))
75-
@show maximum(abs.(parent(w) .- parent(rw)))
71+
rui = Array(interior(ru))
72+
rvi = Array(interior(rv))
73+
rwi = Array(interior(rw))
74+
75+
ui = Array(interior(u))
76+
vi = Array(interior(v))
77+
wi = Array(interior(w))
78+
79+
@show maximum(abs.(ui .- rui))
80+
@show maximum(abs.(vi .- rvi))
81+
@show maximum(abs.(wi .- rwi))
7682

77-
@test parent(u) ≈ parent(ru)
78-
@test parent(v) ≈ parent(rv)
79-
@test parent(w) ≈ parent(rw)
83+
@test ui ≈ rui
84+
@test vi ≈ rvi
85+
@test wi ≈ rwi
8086

8187
# Deduce a stable time-step
8288
Δx = minimum_xspacing(grid)
@@ -116,17 +122,25 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw;
116122
@test iteration(r_simulation) == iteration(simulation)
117123
@test time(r_simulation) == time(simulation)
118124

119-
@show maximum(abs, parent(u))
120-
@show maximum(abs, parent(v))
121-
@show maximum(abs, parent(w))
125+
rui = Array(interior(ru))
126+
rvi = Array(interior(rv))
127+
rwi = Array(interior(rw))
128+
129+
ui = Array(interior(u))
130+
vi = Array(interior(v))
131+
wi = Array(interior(w))
132+
133+
@show maximum(abs, ui)
134+
@show maximum(abs, vi)
135+
@show maximum(abs, wi)
122136

123-
@show maximum(abs.(parent(u) .- parent(ru)))
124-
@show maximum(abs.(parent(v) .- parent(rv)))
125-
@show maximum(abs.(parent(w) .- parent(rw)))
137+
@show maximum(abs.(ui .- rui))
138+
@show maximum(abs.(vi .- rvi))
139+
@show maximum(abs.(wi .- rwi))
126140

127-
@test parent(u) ≈ parent(ru)
128-
@test parent(v) ≈ parent(rv)
129-
@test parent(w) ≈ parent(rw)
141+
@test ui ≈ rui
142+
@test vi ≈ rvi
143+
@test wi ≈ rwi
130144

131145
# Running a few more time-steps works too:
132146
r_simulation.stop_iteration += 2

0 commit comments

Comments
 (0)