Skip to content

Commit 110f3b7

Browse files
committed
Add ekp kwarg to pass EKP struct into calibrate
1 parent 5dd782a commit 110f3b7

File tree

7 files changed

+52
-16
lines changed

7 files changed

+52
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaCalibrate"
22
uuid = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
33
authors = ["Climate Modeling Alliance"]
4-
version = "0.0.4"
4+
version = "0.0.5"
55

66
[deps]
77
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

src/backends.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Distributed
22

3+
import EnsembleKalmanProcesses as EKP
4+
35
export get_backend, calibrate, model_run
46

57
abstract type AbstractBackend end
@@ -82,10 +84,15 @@ function calibrate(
8284
::Type{JuliaBackend},
8385
config::ExperimentConfig;
8486
reruns = 0,
87+
ekp = nothing,
8588
ekp_kwargs...,
8689
)
8790
(; n_iterations, output_dir, ensemble_size) = config
88-
eki = initialize(config; ekp_kwargs...)
91+
ekp = if ekp isa EKP.EnsembleKalmanProcess
92+
initialize(ekp, prior, output_dir)
93+
else
94+
initialize(config; ekp_kwargs...)
95+
end
8996
on_error(e::InterruptException) = rethrow(e)
9097
on_error(e) =
9198
@error "Single ensemble member has errored. See stacktrace" exception =
@@ -101,14 +108,15 @@ function calibrate(
101108
terminate = update_ensemble(config, i)
102109
!isnothing(terminate) && break
103110
iter_path = path_to_iteration(output_dir, i + 1)
104-
eki = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
111+
ekp = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
105112
end
106-
return eki
113+
return ekp
107114
end
108115

109116
"""
110117
calibrate(::Type{AbstractBackend}, config::ExperimentConfig; kwargs...)
111118
calibrate(::Type{AbstractBackend}, experiment_dir; kwargs...)
119+
calibrate(::Type{AbstractBackend}, ekp::EnsembleKalmanProcess, experiment_dir; kwargs...)
112120
113121
Run a full calibration, scheduling the forward model runs on Caltech's HPC cluster.
114122
@@ -160,13 +168,18 @@ function calibrate(
160168
),
161169
verbose = false,
162170
reruns = 1,
171+
ekp = nothing,
163172
hpc_kwargs,
164173
ekp_kwargs...,
165174
)
166-
(; n_iterations, output_dir, ensemble_size) = config
175+
(; n_iterations, output_dir, prior, ensemble_size) = config
167176
@info "Initializing calibration" n_iterations ensemble_size output_dir
168177

169-
eki = initialize(config; ekp_kwargs...)
178+
ekp = if ekp isa EKP.EnsembleKalmanProcess
179+
initialize(ekp, prior, output_dir)
180+
else
181+
initialize(config; ekp_kwargs...)
182+
end
170183
module_load_str = module_load_string(b)
171184
for i in 0:(n_iterations - 1)
172185
@info "Iteration $i"
@@ -201,9 +214,9 @@ function calibrate(
201214
terminate = update_ensemble(config, i)
202215
!isnothing(terminate) && break
203216
iter_path = path_to_iteration(output_dir, i + 1)
204-
eki = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
217+
ekp = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
205218
end
206-
return eki
219+
return ekp
207220
end
208221

209222
# Dispatch on backend type to unify `calibrate` for all HPCBackends

src/ekp_interface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,11 @@ function _initialize(
268268
initial_ensemble =
269269
EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size)
270270

271+
ekp_str_kwargs = Dict([string(k) => v for (k, v) in ekp_kwargs])
271272
eki_constructor =
272273
(args...) -> EKP.EnsembleKalmanProcess(
273274
args...,
274-
Dict(EKP.default_options_dict(EKP.Inversion())..., ekp_kwargs...);
275+
merge(EKP.default_options_dict(EKP.Inversion()), ekp_str_kwargs);
275276
rng = rng_ekp,
276277
)
277278

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3-
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
43
ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
54
ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"
65
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
@@ -12,3 +11,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1211
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1312
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1413
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
14+
15+
[compat]
16+
EnsembleKalmanProcesses = "2"

test/ekp_interface.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,19 @@ config = CAL.ExperimentConfig(
3131
output_dir,
3232
)
3333

34-
eki = CAL.initialize(config)
34+
35+
user_initial_ensemble =
36+
EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size)
37+
user_constructed_eki = EKP.EnsembleKalmanProcess(
38+
user_initial_ensemble,
39+
observations,
40+
noise,
41+
EKP.Inversion(),
42+
EKP.default_options_dict(EKP.Inversion());
43+
rng = rng_ekp,
44+
)
45+
46+
eki = CAL.initialize(config; rng_seed)
3547
eki_with_kwargs = CAL.initialize(
3648
config;
3749
scheduler = EKP.MutableScheduler(2),
@@ -46,6 +58,14 @@ eki_with_kwargs = CAL.initialize(
4658
@test eki_with_kwargs.accelerator isa EKP.NesterovAccelerator
4759
end
4860

61+
@testset "Test that a user-constructed EKP obj is same as initialized one" begin
62+
for prop in propertynames(eki)
63+
prop in [:u, :accelerator, :localizer] && continue
64+
@test getproperty(eki, prop) == getproperty(user_constructed_eki, prop)
65+
end
66+
@test eki.u[1].stored_data == user_constructed_eki.u[1].stored_data
67+
end
68+
4969
override_file = joinpath(
5070
config.output_dir,
5171
"iteration_000",
@@ -80,11 +100,10 @@ end
80100
joinpath(output_dir, "iteration_000", "member_001", "parameters.toml")
81101
td = CP.create_toml_dict(FT; override_file)
82102
params = CP.get_parameter_values(td, param_names)
83-
@test params.one == 2.513110562120818
84-
@test params.two == 4.614950047803855
103+
@test params.one == 3.1313341622997677
104+
@test params.two == 5.063035177034372
85105
end
86106

87-
88107
@testset "Environment variables" begin
89108
@test_throws ErrorException(
90109
"Experiment dir not found in environment. Ensure that env variable \"CALIBRATION_EXPERIMENT_DIR\" is set.",

test/pure_julia_e2e.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ ekp = calibrate(JuliaBackend, experiment_config)
7777
parameter_values =
7878
[EKP.get_ϕ_mean(prior, ekp, it) for it in 1:(n_iterations + 1)]
7979
@test parameter_values[1][1] 8.507 rtol = 0.01
80-
@test parameter_values[end][1] 19.0124 rtol = 0.01
80+
@test parameter_values[end][1] 11.852161842745355 rtol = 0.01
8181
end
8282

8383
rm(output_dir; recursive = true)

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Test
22

33
include("ekp_interface.jl")
44
include("model_interface.jl")
5-
include("emulate_sample.jl")
5+
# Disabled since we use EKP 2.0 in testing, CES is still incompatible with EKP 2.0
6+
# include("emulate_sample.jl")
67
include("pure_julia_e2e.jl")
78
include("aqua.jl")

0 commit comments

Comments
 (0)