Skip to content

Commit 8eec083

Browse files
maximilian-gelbrechtwsmoses
authored andcommitted
updated example, simplified, hopefully faster
1 parent 5e1a632 commit 8eec083

File tree

2 files changed

+10
-28
lines changed

2 files changed

+10
-28
lines changed

test/integration/SpeedyWeather/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
[deps]
2-
Checkpointing = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
32
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
43
SpeedyWeather = "9e226e20-d153-4fed-8a5b-493def4f21a9"
54
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
65

76
[compat]
8-
Checkpointing = "0.11"
97
SpeedyWeather = "0.17.4"
108

119
[sources.Enzyme]

test/integration/SpeedyWeather/runtests.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
11
# SpeedyWeather.jl integration example
2-
# Sensitivity Analysis of Temperature at a single grid point (one-hot seed)
3-
# over the full integration of the PrimitiveWetModel over N timesteps
4-
# Note: reducing N, or reducing trunc will not reduce compile time of the gradient
5-
# we could reduce model complexity a bit by excluding some parameterizations
2+
# Sensitivity Analysis of a single time step of the PrimitiveWetModel
63
#
7-
# For the test itself, we test that Enzyme doesn't error and gradients are nonzero
4+
# For the test itself, we test that Enzyme doesn't error and gradients are nonzero and make some physical sense
85

9-
using SpeedyWeather, Enzyme, Checkpointing, Test
10-
11-
# Parse command line argument for N (number of timesteps)
12-
const N = length(ARGS) >= 1 ? parse(Int, ARGS[1]) : 5
6+
using SpeedyWeather, Enzyme, Test
137

148
spectral_grid = SpectralGrid(trunc = 32, nlayers = 8) # define resolution
159
model = PrimitiveWetModel(; spectral_grid, physics=false) # construct model
@@ -28,28 +22,18 @@ diagn = diagnostic_variables
2822
# do the scaling again because we need it for the timestepping when calling it manually
2923
SpeedyWeather.scale!(progn, diagn, model.planet.radius)
3024

31-
function checkpointed_timesteps!(progn::PrognosticVariables, diagn, model, N_steps, checkpoint_scheme::Scheme, lf1 = 2, lf2 = 2)
32-
33-
@ad_checkpoint checkpoint_scheme for _ in 1:N_steps
34-
SpeedyWeather.timestep!(progn, diagn, 2 * model.time_stepping.Δt, model, lf1, lf2)
35-
end
36-
37-
return nothing
38-
end
39-
40-
checkpoint_scheme = Revolve(N)
25+
dprogn = zero(progn)
26+
ddiag = make_zero(diagn)
27+
dmodel = make_zero(model)
4128

4229
# Temperature One-Hot
43-
d_progn = zero(progn)
44-
d_model = make_zero(model)
45-
d_diag = make_zero(diagn)
4630
seed_point = 443 # seed point
47-
d_diag.grid.temp_grid[seed_point, 8] = 1
31+
ddiag.grid.temp_grid[seed_point, 8] = 1
4832

49-
# Sensitivity Analysis of Temperature at a single grid point (one-hot seed)
50-
autodiff(Enzyme.Reverse, checkpointed_timesteps!, Const, Duplicated(progn, d_progn), Duplicated(diagn, d_diag), Duplicated(model, d_model), Const(N), Const(checkpoint_scheme))
33+
# Sensitivity Analysis of Temperature at a single grid point (one-hot seed) for a single timestep
34+
autodiff(Enzyme.Reverse, SpeedyWeather.timestep!, Const, Duplicated(progn, dprogn), Duplicated(diagn, ddiag), Const(dt), Duplicated(model, dmodel))
5135

52-
vor_grid = transform(d_progn.vor[:, :, 2], model.spectral_transform)
36+
vor_grid = transform(dprogn.vor[:, :, 2], model.spectral_transform)
5337

5438
# nonzero
5539
@test sum(abs, vor_grid) > 0

0 commit comments

Comments
 (0)