Skip to content

Commit ff3e074

Browse files
committed
update Lorenz spatial dep example
1 parent 5120a33 commit ff3e074

File tree

3 files changed

+172
-13
lines changed

3 files changed

+172
-13
lines changed

examples/Lorenz/calibrate_spatial_dep.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,12 @@ x0 = x_spun_up[:, end] #last element of the run is the initial condition for cr
169169

170170

171171
#Creating sythetic data
172-
T = 50.0
172+
T = 24.0
173173
ny = nx * 2 #number of data points
174174
lorenz_config_settings = LorenzConfig(dt, T)
175175

176176
# construct how we compute Observations
177-
T_start = T - 25.0
177+
T_start = T - 20.0
178178
T_end = T
179179
observation_config = ObservationConfig(T_start, T_end)
180180

@@ -196,9 +196,10 @@ y_ens = hcat(
196196
obs_noise_cov = cov(y_ens, dims = 2) + 1e-2 * I
197197
y = y_ens[:, 1]
198198

199+
#Prior covariance
200+
199201
pl = 4.0
200202
psig = 5.0
201-
#Prior covariance
202203
B = zeros(nx, nx)
203204
for ii in 1:nx
204205
for jj in 1:nx
@@ -207,6 +208,12 @@ for ii in 1:nx
207208
end
208209
B_sqrt = sqrt(B)
209210

211+
#=
212+
psig = 5.0
213+
B = psig^2*I
214+
B_sqrt = sqrt(B)
215+
=#
216+
210217
#Prior mean
211218
mu = 4.0 * ones(nx)
212219

examples/Lorenz/emulate_sample_spatial_dep.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@ function main()
2828
]
2929

3030
#### CHOOSE YOUR CASE:
31-
mask = [1]# 1:1 # e.g. 1:2 or [2]
31+
mask = [2]# 1:1 # e.g. 1:2 or [2]
3232
for (case) in cases[mask]
3333

3434

3535
println("case: ", case)
3636
min_iter = 1
37-
max_iter = 7 # number of EKP iterations to use data from is at most this
37+
skip_iter = 1
38+
max_iter = 8 # number of EKP iterations to use data from is at most this
3839

3940
####
4041

@@ -71,7 +72,7 @@ function main()
7172

7273
# Emulate-sample settings
7374
# choice of machine-learning tool in the emulation stage
74-
nugget = 1e-3
75+
nugget = 1e-8
7576
if case == "GP"
7677
gppackage = Emulators.GPJL()
7778
pred_type = Emulators.YType()
@@ -106,7 +107,7 @@ function main()
106107
# Get training points from the EKP iteration number in the second input term
107108
N_iter = min(max_iter, length(get_u(ekpobj)) - 1) # number of paired iterations taken from EKP
108109
min_iter = min(max_iter, max(1, min_iter))
109-
input_output_pairs = Utilities.get_training_points(ekpobj, min_iter:(N_iter - 1))
110+
input_output_pairs = Utilities.get_training_points(ekpobj, min_iter:skip_iter:(N_iter - 1))
110111
input_output_pairs_test = Utilities.get_training_points(ekpobj, N_iter:(length(get_u(ekpobj)) - 1)) # "next" iterations
111112
# Save data
112113
@save joinpath(data_save_directory, "input_output_pairs.jld2") input_output_pairs

examples/Lorenz/plot_spatial_dep.jl

Lines changed: 157 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# some context calues from emulate_sample_spatial_dep.jl
44
cases = ["GP", "RF-scalar"]
5-
case = cases[1]
5+
case = cases[2]
66

77
# load packages
88
# CES
@@ -19,6 +19,9 @@ using CalibrateEmulateSample.EnsembleKalmanProcesses
1919
using CalibrateEmulateSample.EnsembleKalmanProcesses.ParameterDistributions
2020
const EKP = EnsembleKalmanProcesses
2121

22+
cases = ["GP", "RF-scalar"]
23+
case = cases[2]
24+
2225
# load
2326
homedir = pwd()
2427
println(homedir)
@@ -29,6 +32,8 @@ truth_params = load(data_file)["truth_params_constrained"]
2932
final_params = load(data_file)["final_params_constrained"]
3033
posterior = load(data_file)["posterior"]
3134

35+
prior_file = joinpath(data_save_directory, "priors_spatial_dep.jld2")
36+
prior = load(prior_file)["prior"]
3237

3338
ekp_file = joinpath(data_save_directory, "ekp_spatial_dep.jld2")
3439
ekpobj = load(ekp_file)["ekpobj"]
@@ -49,9 +54,9 @@ quantiles = reduce(hcat, [quantile(row, [0.05, 0.5, 0.95]) for row in eachrow(co
4954
gr(size = (2 * 1.6 * 600, 600), guidefontsize = 18, tickfontsize = 16, legendfontsize = 16)
5055
p1 = plot(
5156
range(0, nx - 1, step = 1),
52-
[truth_params final_params],
53-
label = ["solution" "EKI"],
54-
color = [:black :lightgreen],
57+
truth_params,
58+
label = "solution",
59+
color = :black,
5560
linewidth = 4,
5661
xlabel = "Spatial index",
5762
ylabel = "Forcing (input)",
@@ -72,14 +77,77 @@ p2 = plot(
7277
bottom_margin = 15mm,
7378
xticks = (Int.(0:10:ny), [0, 10, 20, 30, (40, 0), 10, 20, 30, 40]),
7479
)
75-
plot!(p2, 1:length(y), get_g_mean_final(ekpobj), label = "mean-final-output", color = :lightgreen, linewidth = 4)
80+
81+
l = @layout [a b]
82+
plt = plot(p1, p2, layout = l)
83+
84+
savefig(plt, figure_save_directory * "data_spatial_dep.png")
85+
savefig(plt, figure_save_directory * "data_spatial_dep.pdf")
86+
87+
p3 = deepcopy(p1)
88+
p4 = deepcopy(p2)
89+
90+
plot!(p1,
91+
range(0, nx - 1, step = 1),
92+
final_params,
93+
label = "mean ensemble input",
94+
color = :lightgreen,
95+
linewidth = 4,
96+
)
97+
plot!(p2, 1:length(y), get_g_mean_final(ekpobj), label = "mean ensemble output", color = :lightgreen, linewidth = 4)
7698

7799
l = @layout [a b]
78100
plt = plot(p1, p2, layout = l)
79101

80102
savefig(plt, figure_save_directory * "solution_spatial_dep_ens$(N_ens).png")
81103
savefig(plt, figure_save_directory * "solution_spatial_dep_ens$(N_ens).pdf")
82104

105+
#
106+
plot!(p3,
107+
range(0, nx - 1, step = 1),
108+
get_ϕ_final(prior, ekpobj)[:,1],
109+
label = "ensemble inputs",
110+
color = :lightgreen,
111+
linewidth = 4,
112+
linealpha=0.1,
113+
)
114+
115+
plot!(p3,
116+
range(0, nx - 1, step = 1),
117+
get_ϕ_final(prior, ekpobj)[:,2:end],
118+
label = "",
119+
color = :lightgreen,
120+
linewidth = 4,
121+
linealpha=0.1,
122+
)
123+
plot!(p3,
124+
range(0, nx - 1, step = 1),
125+
get_ϕ(prior, ekpobj,1)[:,1],
126+
label = "",
127+
color = :lightgreen,
128+
linewidth = 4,
129+
linealpha=0.1,
130+
)
131+
132+
plot!(p3,
133+
range(0, nx - 1, step = 1),
134+
get_ϕ(prior, ekpobj,1)[:,2:end],
135+
label = "",
136+
color = :lightgreen,
137+
linewidth = 4,
138+
linealpha=0.1,
139+
)
140+
141+
plot!(p4, 1:length(y), get_g_final(ekpobj)[:,1], color = :lightgreen, label = "ensemble outputs", linewidth = 4, linealpha=0.1)
142+
plot!(p4, 1:length(y), get_g_final(ekpobj)[:,2:end], color = :lightgreen, label = "", linewidth = 4, linealpha=0.1)
143+
plot!(p4, 1:length(y), get_g(ekpobj,1)[:,1], color = :lightgreen, label = "", linewidth = 4, linealpha=0.1)
144+
plot!(p4, 1:length(y), get_g(ekpobj,1)[:,2:end], color = :lightgreen, label = "", linewidth = 4, linealpha=0.1)
145+
146+
147+
plt = plot(p3, p4, layout = l)
148+
149+
savefig(plt, figure_save_directory * "solution_spatial_dep_full_ens$(N_ens).png")
150+
savefig(plt, figure_save_directory * "solution_spatial_dep_full_ens$(N_ens).pdf")
83151

84152
# plot - UQ results
85153

@@ -107,7 +175,90 @@ plot!(
107175
fillalpha = 0.1,
108176
)
109177

110-
111178
figpath = joinpath(figure_save_directory, "posterior_ribbons_" * case)
112179
savefig(figpath * ".png")
113180
savefig(figpath * ".pdf")
181+
182+
##########################
183+
cases = ["GP", "RF-scalar"]
184+
185+
# load
186+
homedir = pwd()
187+
println(homedir)
188+
data_file_GP = joinpath(data_save_directory, "posterior_$(cases[1]).jld2")
189+
data_file_RF = joinpath(data_save_directory, "posterior_$(cases[2]).jld2")
190+
truth_params_GP = load(data_file_GP)["truth_params_constrained"]
191+
final_params_GP = load(data_file_GP)["final_params_constrained"]
192+
posterior_GP = load(data_file_GP)["posterior"]
193+
truth_params_RF = load(data_file_RF)["truth_params_constrained"]
194+
final_params_RF = load(data_file_RF)["final_params_constrained"]
195+
posterior_RF = load(data_file_RF)["posterior"]
196+
197+
ekp_file = joinpath(data_save_directory, "ekp_spatial_dep.jld2")
198+
ekpobj = load(ekp_file)["ekpobj"]
199+
N_ens = get_N_ens(ekpobj)
200+
nx = length(truth_params)
201+
y = get_obs(ekpobj)
202+
ny = length(y)
203+
# get samples and quantiles from posterior
204+
param_names_GP = get_name(posterior_GP)
205+
posterior_samples_GP = vcat([get_distribution(posterior_GP)[name] for name in get_name(posterior_GP)]...) #samples are columns
206+
constrained_posterior_samples_GP =
207+
mapslices(x -> transform_unconstrained_to_constrained(posterior_GP, x), posterior_samples_GP, dims = 1)
208+
209+
quantiles_GP = reduce(hcat, [quantile(row, [0.05, 0.5, 0.95]) for row in eachrow(constrained_posterior_samples_GP)])' # rows are quantiles for row of posterior samples
210+
211+
param_names_RF = get_name(posterior_RF)
212+
posterior_samples_RF = vcat([get_distribution(posterior_RF)[name] for name in get_name(posterior_RF)]...) #samples are columns
213+
constrained_posterior_samples_RF =
214+
mapslices(x -> transform_unconstrained_to_constrained(posterior_RF, x), posterior_samples_RF, dims = 1)
215+
216+
quantiles_RF = reduce(hcat, [quantile(row, [0.05, 0.5, 0.95]) for row in eachrow(constrained_posterior_samples_RF)])' # rows are quantiles for row of posterior samples
217+
218+
# plot - UQ results - both
219+
220+
gr(size = (1.6 * 600, 600), guidefontsize = 18, tickfontsize = 16, legendfontsize = 16)
221+
p1 = plot(
222+
range(0, nx - 1, step = 1),
223+
[truth_params final_params],
224+
label = ["solution" "EKI-opt"],
225+
color = [:black :lightgreen],
226+
linewidth = 4,
227+
xlabel = "Spatial index",
228+
ylabel = "Forcing (input)",
229+
left_margin = 15mm,
230+
bottom_margin = 15mm,
231+
)
232+
233+
plot!(
234+
p1,
235+
range(0, nx - 1, step = 1),
236+
quantiles_GP[:, 2], # median of all vals
237+
color = :blue,
238+
label = "GP posterior",
239+
linewidth = 4,
240+
ribbon = [quantiles_GP[:, 2] - quantiles_GP[:, 1] quantiles_GP[:, 3] - quantiles_GP[:, 2]],
241+
linealpha=0.5,
242+
fillalpha = 0.1,
243+
)
244+
245+
plot!(
246+
p1,
247+
range(0, nx - 1, step = 1),
248+
quantiles_RF[:, 2], # median of all vals
249+
color = :red,
250+
label = "posterior_RF",
251+
linewidth = 4,
252+
ribbon = [quantiles_RF[:, 2] - quantiles_RF[:, 1] quantiles_RF[:, 3] - quantiles_RF[:, 2]],
253+
linealpha=0.5,
254+
fillalpha = 0.1,
255+
)
256+
257+
258+
figpath = joinpath(figure_save_directory, "posterior_ribbons_both")
259+
savefig(figpath * ".png")
260+
savefig(figpath * ".pdf")
261+
262+
263+
264+

0 commit comments

Comments
 (0)