Skip to content

Commit 7e5ece5

Browse files
Persist FractionalIndices with two separate arrays for i, j (#371)
* one missing allowscalar * small cleanup * bugfix * New way of constructing fractional indices * add topoology * Tuple of fractional indices * Bugfix * Bump ClimaSeaICe * Fix bug in atmos interpolation * Try to generalize to doubly-Flat grids * Bug in StateExchanger construcotr * Bump Oceananigans compat * Update Project.toml * Update Project.toml * fix test simulation * Fix interpolation for Flat topologies? * Ohhh * Fix test bug * Better name * Name mangle * Fix test * Tiny change in MixedLayerDepth diagnostic --------- Co-authored-by: Gregory Wagner <[email protected]> Co-authored-by: Gregory Wagner <[email protected]>
1 parent fa37f3c commit 7e5ece5

File tree

5 files changed

+68
-43
lines changed

5 files changed

+68
-43
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c"
3333
Adapt = "4"
3434
CFTime = "0.1"
3535
CUDA = "4, 5"
36-
ClimaSeaIce = "0.2.0"
36+
ClimaSeaIce = "0.2.1"
3737
CubicSplines = "0.2"
3838
DataDeps = "0.7"
3939
Downloads = "1.6"
@@ -42,7 +42,7 @@ JLD2 = "0.4, 0.5"
4242
KernelAbstractions = "0.9"
4343
MPI = "0.20"
4444
NCDatasets = "0.12, 0.13, 0.14"
45-
Oceananigans = "0.95.17 - 0.99"
45+
Oceananigans = "0.95.18 - 0.99"
4646
OffsetArrays = "1.14"
4747
OrthogonalSphericalShellGrids = "0.2.2"
4848
Scratch = "1"

src/OceanSeaIceModels/InterfaceComputations/component_interfaces.jl

+27-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using StaticArrays
22
using Thermodynamics
33
using SurfaceFluxes
44
using OffsetArrays
5-
using CUDA: @allowscalar
65

76
using ..OceanSeaIceModels: reference_density,
87
heat_capacity,
@@ -15,7 +14,7 @@ using ..OceanSeaIceModels: reference_density,
1514
using ClimaSeaIce: SeaIceModel
1615

1716
using Oceananigans: HydrostaticFreeSurfaceModel, architecture
18-
using Oceananigans.Grids: inactive_node, node
17+
using Oceananigans.Grids: inactive_node, node, topology
1918
using Oceananigans.BoundaryConditions: fill_halo_regions!
2019

2120
using Oceananigans.Fields: ConstantField, interpolate, FractionalIndices
@@ -60,6 +59,11 @@ struct StateExchanger{G, AST, AEX}
6059
atmosphere_exchanger :: AEX
6160
end
6261

62+
# Note that Field location can also affect fractional index type.
63+
# Here we assume that we know the location of Fields that will be interpolated.
64+
fractional_index_type(FT, Topo) = FT
65+
fractional_index_type(FT, ::Flat) = Nothing
66+
6367
function StateExchanger(ocean::Simulation, atmosphere)
6468
# TODO: generalize this
6569
exchange_grid = ocean.model.grid
@@ -80,13 +84,14 @@ function StateExchanger(ocean::Simulation, atmosphere)
8084
arch = architecture(exchange_grid)
8185
Nx, Ny, Nz = size(exchange_grid)
8286

83-
# Make an array of FractionalIndices
84-
kᴺ = size(exchange_grid, 3)
85-
@allowscalar X1 = _node(1, 1, kᴺ + 1, exchange_grid, c, c, f)
86-
i1 = FractionalIndices(X1, atmosphere.grid, c, c, nothing)
87-
frac_indices_data = [deepcopy(i1) for i=1:Nx+2, j=1:Ny+2, k=1:1]
88-
frac_indices = OffsetArray(frac_indices_data, -1, -1, 0)
89-
frac_indices = on_architecture(arch, frac_indices)
87+
# Make a NamedTuple of fractional indices
88+
# Note: we could use an array of FractionalIndices. Instead, for compatbility
89+
# with Reactant we construct FractionalIndices on the fly in `interpolate_atmospheric_state`.
90+
FT = eltype(atmos_grid)
91+
TX, TY, TZ = topology(exchange_grid)
92+
fi = TX() isa Flat ? nothing : Field{Center, Center, Nothing}(exchange_grid, FT)
93+
fj = TY() isa Flat ? nothing : Field{Center, Center, Nothing}(exchange_grid, FT)
94+
frac_indices = (i=fi, j=fj) # no k needed, only horizontal interpolation
9095

9196
kernel_parameters = interface_kernel_parameters(exchange_grid)
9297
launch!(arch, exchange_grid, kernel_parameters,
@@ -95,11 +100,22 @@ function StateExchanger(ocean::Simulation, atmosphere)
95100
return StateExchanger(ocean.model.grid, exchange_atmosphere_state, frac_indices)
96101
end
97102

98-
@kernel function _compute_fractional_indices!(frac_indices, exchange_grid, atmos_grid)
103+
@kernel function _compute_fractional_indices!(indices_tuple, exchange_grid, atmos_grid)
99104
i, j = @index(Global, NTuple)
100105
kᴺ = size(exchange_grid, 3) # index of the top ocean cell
101106
X = _node(i, j, kᴺ + 1, exchange_grid, c, c, f)
102-
@inbounds frac_indices[i, j, 1] = FractionalIndices(X, atmos_grid, c, c, nothing)
107+
fractional_indices_ij = FractionalIndices(X, atmos_grid, c, c, nothing)
108+
fi = indices_tuple.i
109+
fj = indices_tuple.j
110+
@inbounds begin
111+
if !isnothing(fi)
112+
fi[i, j, 1] = fractional_indices_ij.i
113+
end
114+
115+
if !isnothing(fj)
116+
fj[i, j, 1] = fractional_indices_ij.j
117+
end
118+
end
103119
end
104120

105121
const PATP = PrescribedAtmosphereThermodynamicsParameters

src/OceanSeaIceModels/InterfaceComputations/interpolate_atmospheric_state.jl

+26-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Oceananigans.Operators: intrinsic_vector
22
using Oceananigans.Grids: _node
3+
using Oceananigans.Fields: FractionalIndices
34
using Oceananigans.OutputReaders: TimeInterpolator
45

56
using ...OceanSimulations: forcing_barotropic_potential
@@ -109,6 +110,9 @@ function interpolate_atmospheric_state!(coupled_model)
109110
parent(barotropic_potential) .= parent(atmosphere_data.p) ./ ρₒ
110111
end
111112
end
113+
114+
@inline get_fractional_index(i, j, ::Nothing) = nothing
115+
@inline get_fractional_index(i, j, frac) = @inbounds frac[i, j, 1]
112116

113117
@kernel function _interpolate_primary_atmospheric_state!(surface_atmos_state,
114118
space_fractional_indices,
@@ -124,28 +128,33 @@ end
124128

125129
i, j = @index(Global, NTuple)
126130

127-
@inbounds begin
128-
x_itp = space_fractional_indices[i, j, 1]
129-
t_itp = time_interpolator
130-
atmos_args = (x_itp, t_itp, atmos_backend, atmos_time_indexing)
131+
ii = space_fractional_indices.i
132+
jj = space_fractional_indices.j
133+
fi = get_fractional_index(i, j, ii)
134+
fj = get_fractional_index(i, j, jj)
131135

132-
uₐ = interp_atmos_time_series(atmos_velocities.u, atmos_args...)
133-
vₐ = interp_atmos_time_series(atmos_velocities.v, atmos_args...)
134-
Tₐ = interp_atmos_time_series(atmos_tracers.T, atmos_args...)
135-
qₐ = interp_atmos_time_series(atmos_tracers.q, atmos_args...)
136-
pₐ = interp_atmos_time_series(atmos_pressure, atmos_args...)
136+
x_itp = FractionalIndices(fi, fj, nothing)
137+
t_itp = time_interpolator
138+
atmos_args = (x_itp, t_itp, atmos_backend, atmos_time_indexing)
137139

138-
Qs = interp_atmos_time_series(downwelling_radiation.shortwave, atmos_args...)
139-
Qℓ = interp_atmos_time_series(downwelling_radiation.longwave, atmos_args...)
140+
uₐ = interp_atmos_time_series(atmos_velocities.u, atmos_args...)
141+
vₐ = interp_atmos_time_series(atmos_velocities.v, atmos_args...)
142+
Tₐ = interp_atmos_time_series(atmos_tracers.T, atmos_args...)
143+
qₐ = interp_atmos_time_series(atmos_tracers.q, atmos_args...)
144+
pₐ = interp_atmos_time_series(atmos_pressure, atmos_args...)
140145

141-
# Usually precipitation
142-
Mh = interp_atmos_time_series(prescribed_freshwater_flux, atmos_args...)
146+
Qs = interp_atmos_time_series(downwelling_radiation.shortwave, atmos_args...)
147+
Qℓ = interp_atmos_time_series(downwelling_radiation.longwave, atmos_args...)
143148

144-
# Convert atmosphere velocities (usually defined on a latitude-longitude grid) to
145-
# the frame of reference of the native grid
146-
kᴺ = size(exchange_grid, 3) # index of the top ocean cell
147-
uₐ, vₐ = intrinsic_vector(i, j, kᴺ, exchange_grid, uₐ, vₐ)
149+
# Usually precipitation
150+
Mh = interp_atmos_time_series(prescribed_freshwater_flux, atmos_args...)
148151

152+
# Convert atmosphere velocities (usually defined on a latitude-longitude grid) to
153+
# the frame of reference of the native grid
154+
kᴺ = size(exchange_grid, 3) # index of the top ocean cell
155+
uₐ, vₐ = intrinsic_vector(i, j, kᴺ, exchange_grid, uₐ, vₐ)
156+
157+
@inbounds begin
149158
surface_atmos_state.u[i, j, 1] = uₐ
150159
surface_atmos_state.v[i, j, 1] = vₐ
151160
surface_atmos_state.T[i, j, 1] = Tₐ

test/test_diagnostics.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ using ClimaOcean.Diagnostics: MixedLayerDepthField, MixedLayerDepthOperand
4242
@test h.operand.buoyancy_perturbation isa KernelFunctionOperation
4343

4444
compute!(h)
45-
@test @allowscalar h[1, 1, 1] 16.255836 # m
45+
@test @allowscalar h[1, 1, 1] 16.2558363 # m
4646

4747
tracers = (T=Tt[2], S=St[2])
4848
h.operand.buoyancy_perturbation = buoyancy(sb, grid, tracers)
4949
compute!(h)
50-
@test @allowscalar h[1, 1, 1] 9.295890287 # m
50+
@test @allowscalar h[1, 1, 1] 9.2957298 # m
5151
end
5252
end

test/test_simulations.jl

+11-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ using OrthogonalSphericalShellGrids
55
using ClimaOcean.OceanSeaIceModels: above_freezing_ocean_temperature!
66
using ClimaSeaIce.SeaIceThermodynamics: melting_temperature
77

8-
@testset "GPU time stepping test" begin
8+
@inline kernel_melting_temperature(i, j, k, grid, liquidus, S) = @inbounds melting_temperature(liquidus, S[i, j, k])
9+
10+
@testset "Time stepping test" begin
911

1012
for arch in test_architectures
1113

@@ -24,9 +26,9 @@ using ClimaSeaIce.SeaIceThermodynamics: melting_temperature
2426
interpolation_passes = 20,
2527
major_basins = 1)
2628

27-
grid = ImmersedBoundaryGrid(grid, GridFittedBottom(bottom_height); active_cells_map = true)
29+
grid = ImmersedBoundaryGrid(grid, GridFittedBottom(bottom_height); active_cells_map=true)
2830

29-
free_surface = SplitExplicitFreeSurface(grid; substeps = 20)
31+
free_surface = SplitExplicitFreeSurface(grid; substeps=20)
3032
ocean = ocean_simulation(grid; free_surface)
3133

3234
backend = JRA55NetCDFBackend(4)
@@ -52,24 +54,22 @@ using ClimaSeaIce.SeaIceThermodynamics: melting_temperature
5254

5355
# Set the ocean temperature and salinity
5456
set!(ocean.model, T=temperature_metadata[1], S=salinity_metadata[1])
55-
5657
above_freezing_ocean_temperature!(ocean, sea_ice)
5758

5859
# Test that ocean temperatures are above freezing
59-
T = on_architecture(CPU(), ocean.model.T)
60-
S = on_architecture(CPU(), ocean.model.S)
61-
62-
@inline pointwise_melting_T(i, j, k, grid, liquidus, S) = @inbounds melting_temperature(liquidus, S[i, j, k])
60+
T = on_architecture(CPU(), ocean.model.tracers.T)
61+
S = on_architecture(CPU(), ocean.model.tracers.S)
6362

64-
Tm = KernelFunctionOperation{Center, Center, Center}(pointwise_melting_T, grid, S)
65-
66-
@test all(T .> Tm)
63+
Tm = KernelFunctionOperation{Center, Center, Center}(kernel_melting_temperature, grid, liquidus, S)
64+
@test all(T .>= Tm)
6765

6866
# Fluxes are computed when the model is constructed, so we just test that this works.
6967
# And that we can time step with sea ice
7068
@test begin
7169
coupled_model = OceanSeaIceModel(ocean, sea_ice; atmosphere, radiation)
70+
time_step!(coupled_model, 1)
7271
true
7372
end
7473
end
7574
end
75+

0 commit comments

Comments
 (0)