11using Distributed
22
3+ import EnsembleKalmanProcesses as EKP
4+
35export get_backend, calibrate, model_run
46
57abstract type AbstractBackend end
@@ -82,10 +84,15 @@ function calibrate(
8284 :: Type{JuliaBackend} ,
8385 config:: ExperimentConfig ;
8486 reruns = 0 ,
87+ ekp = nothing ,
8588 ekp_kwargs... ,
8689)
8790 (; n_iterations, output_dir, ensemble_size) = config
88- eki = initialize(config; ekp_kwargs... )
91+ ekp = if ekp isa EKP. EnsembleKalmanProcess
92+ initialize(ekp, prior, output_dir)
93+ else
94+ initialize(config; ekp_kwargs... )
95+ end
8996 on_error(e:: InterruptException ) = rethrow(e)
9097 on_error(e) =
9198 @error " Single ensemble member has errored. See stacktrace" exception =
@@ -101,14 +108,15 @@ function calibrate(
101108 terminate = update_ensemble(config, i)
102109 ! isnothing(terminate) && break
103110 iter_path = path_to_iteration(output_dir, i + 1 )
104- eki = JLD2. load_object(joinpath(iter_path, " eki_file.jld2" ))
111+ ekp = JLD2. load_object(joinpath(iter_path, " eki_file.jld2" ))
105112 end
106- return eki
113+ return ekp
107114end
108115
109116"""
110117 calibrate(::Type{AbstractBackend}, config::ExperimentConfig; kwargs...)
111118 calibrate(::Type{AbstractBackend}, experiment_dir; kwargs...)
119+ calibrate(::Type{AbstractBackend}, ekp::EnsembleKalmanProcess, experiment_dir; kwargs...)
112120
113121Run a full calibration, scheduling the forward model runs on Caltech's HPC cluster.
114122
@@ -160,13 +168,18 @@ function calibrate(
160168 ),
161169 verbose = false ,
162170 reruns = 1 ,
171+ ekp = nothing ,
163172 hpc_kwargs,
164173 ekp_kwargs... ,
165174)
166- (; n_iterations, output_dir, ensemble_size) = config
175+ (; n_iterations, output_dir, prior, ensemble_size) = config
167176 @info " Initializing calibration" n_iterations ensemble_size output_dir
168177
169- eki = initialize(config; ekp_kwargs... )
178+ ekp = if ekp isa EKP. EnsembleKalmanProcess
179+ initialize(ekp, prior, output_dir)
180+ else
181+ initialize(config; ekp_kwargs... )
182+ end
170183 module_load_str = module_load_string(b)
171184 for i in 0 : (n_iterations - 1 )
172185 @info " Iteration $i "
@@ -201,9 +214,9 @@ function calibrate(
201214 terminate = update_ensemble(config, i)
202215 ! isnothing(terminate) && break
203216 iter_path = path_to_iteration(output_dir, i + 1 )
204- eki = JLD2. load_object(joinpath(iter_path, " eki_file.jld2" ))
217+ ekp = JLD2. load_object(joinpath(iter_path, " eki_file.jld2" ))
205218 end
206- return eki
219+ return ekp
207220end
208221
209222# Dispatch on backend type to unify `calibrate` for all HPCBackends
0 commit comments