Skip to content

Commit 98e4536

Browse files
committed
move counter and progress to callback
1 parent 7b95e34 commit 98e4536

File tree

2 files changed

+27
-26
lines changed

2 files changed

+27
-26
lines changed

src/fit.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,44 @@ function fit(
8484
lbounds=lbounds,
8585
ubounds=ubounds,
8686
scale=scale,
87-
progress=progress,
8887
kwargs... # other arguments to sim
8988
)
9089

90+
parameters_names = first.(parameters_fitted)
91+
prog = ProgressUnknown(; desc ="Fit counter:", spinner=false, enabled=progress!=:silent, showspeed=true)
92+
numiters = 0
93+
estim_best = Inf
94+
function counter_callback(state, estim_obj)
95+
numiters +=1
96+
97+
x_unscaled = unscale_params.(state.u, scale)
98+
if !isnothing(estim_obj) && !isa(estim_obj, ForwardDiff.Dual) && (estim_obj < estim_best)
99+
estim_best = estim_obj
100+
end
101+
102+
values_to_display = [(:ESTIMATOR_BEST, round(estim_best; digits=2))]
103+
if progress == :full && !(eltype(x_unscaled) <: ForwardDiff.Dual)
104+
for i in eachindex(x_unscaled)
105+
push!(values_to_display, (parameters_names[i], round(x_unscaled[i], sigdigits=3)))
106+
end
107+
end
108+
109+
ProgressMeter.update!(prog, numiters, spinner="⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"; showvalues = values_to_display)
110+
111+
return false
112+
end
113+
91114
optsol = solve(optprob, fit_alg;
92115
reltol=ftol_rel,
93116
abstol=ftol_abs,
94117
maxiters=maxiters,
95-
maxtime=maxtime)
118+
maxtime=maxtime,
119+
callback=counter_callback)
96120

97121
minx = optsol.u
98122
minf = optsol.objective
99123
ret = Symbol(optsol.retcode)
100-
numiters = 0 # TODO callback to save iters
124+
101125
# to create pairs from Float64
102126
parameter_names = _extract_parameter_names(parameters_fitted)
103127
minx_pairs = [key=>value for (key, value) in zip(parameter_names, unscale_params.(minx, scale))]

src/optprob.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,9 @@ function generate_optimization_problem(
5353
lbounds = fill(0.0, length(parameters_fitted)),
5454
ubounds = fill(Inf, length(parameters_fitted)),
5555
scale = fill(:lin, length(parameters_fitted)),
56-
progress::Symbol = :minimal,
5756
kwargs... # other arguments to sim
5857
) where {C<:AbstractScenario, P<:Pair}
5958

60-
# names of parameters used in fitting and saved in parameters_fitted field of solution
61-
parameters_names = first.(parameters_fitted)
62-
6359
selected_scenario_pairs = Pair{Symbol,Scenario}[]
6460
for scenario_pair in scenario_pairs # iterate through scenarios names
6561
if isempty(last(scenario_pair).measurements)
@@ -82,28 +78,9 @@ function generate_optimization_problem(
8278
kwargs...
8379
)
8480

85-
# progress info
86-
prog = ProgressUnknown(; desc ="Fit counter:", spinner=false, enabled=progress!=:silent, showspeed=true)
87-
count = 0
88-
estim_best = Inf
8981
function obj_func(x, hyper_params)
90-
count+=1
91-
# try - catch is a tmp solution for NLopt
9282
x_unscaled = unscale_params.(x, scale)
9383
estim_obj = estim_fun(x_unscaled)
94-
95-
if !isnothing(estim_obj) && !isa(estim_obj, ForwardDiff.Dual) && (estim_obj < estim_best)
96-
estim_best = estim_obj
97-
end
98-
99-
values_to_display = [(:ESTIMATOR_BEST, round(estim_best; digits=2))]
100-
if progress == :full && !(eltype(x_unscaled) <: ForwardDiff.Dual)
101-
for i in 1:length(x)
102-
push!(values_to_display, (parameters_names[i], round(x_unscaled[i], sigdigits=3)))
103-
end
104-
end
105-
106-
ProgressMeter.update!(prog, count, spinner="⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"; showvalues = values_to_display)
10784
return estim_obj
10885
end
10986

0 commit comments

Comments
 (0)