@@ -44,79 +44,107 @@ $(FIELDS)
4444 eq_options:: SymbolicRegression.Options = SymbolicRegression. Options ()
4545end
4646
47- struct SRResult{H, P, T, TE} <: AbstractDataDrivenResult
47+ struct SRResult{H, P, T} <: AbstractDataDrivenResult
48+ " The resulting basis"
49+ basis:: Basis
50+ " The Hall of Fame"
4851 halloffame:: H
52+ """ The Paretofrontier"""
4953 paretofrontier:: P
50- testerror:: T
51- trainerror:: TE
54+ # StatsBase results
55+ """ Residual sum of squares"""
56+ rss:: T
57+ """ Loglikelihood"""
58+ loglikelihood:: T
59+ """ Nullloglikelihood"""
60+ nullloglikelihood:: T
61+ """ Degrees of freedom"""
62+ dof:: Int
63+ """ Number of observations"""
64+ nobs:: Int
65+ """ Returncode"""
5266 retcode:: DDReturnCode
5367end
5468
55- is_success (k:: SRResult ) = getfield (k, :retcode ) == DDReturnCode (1 )
56- l2error (k:: SRResult ) = is_success (k) ? getfield (k, :testerror ) : Inf
57- function l2error (k:: SRResult{<:Any, <:Any, <:Any, Nothing} )
58- is_success (k) ? getfield (k, :traineerror ) : Inf
69+ function SRResult (prob, hof, paretos)
70+ @unpack basis, problem = prob
71+ bs = convert_to_basis (paretos, prob)
72+ ps = get_parameter_values (bs)
73+ problem = DataDrivenDiffEq. remake_problem (problem, p = ps)
74+ y = DataDrivenDiffEq. get_implicit_data (problem)
75+ rss = sum (abs2, y .- bs (problem))
76+ dof = length (ps)
77+ nobs = prod (size (y))
78+ ll = iszero (rss) ? convert (eltype (rss), Inf ) : - nobs / 2 * log (rss / nobs)
79+ ll0 = - nobs / 2 * log .(sum (abs2, y .- mean (y, dims = 2 )[:, 1 ]) / nobs)
80+ return SRResult (bs, hof, paretos,
81+ rss, ll, ll0, dof, nobs,
82+ DDReturnCode (1 ))
5983end
6084
61- # apply the algorithm on each dataset
62- function (x:: EQSearch )(ps:: InternalDataDrivenProblem{EQSearch} , X, Y)
63- @unpack problem, testdata, options = ps
64- @unpack maxiters, abstol = options
65- @unpack weights, eq_options, numprocs, procs, parallelism, runtests = x
85+ is_success (k:: SRResult ) = getfield (k, :retcode ) == DDReturnCode (1 )
6686
67- hofs = SymbolicRegression. EquationSearch (X, Y,
68- niterations = maxiters,
69- weights = weights,
70- options = eq_options,
71- numprocs = numprocs,
72- procs = procs, parallelism = parallelism,
73- runtests = runtests)
87+ # StatsBase Overload
88+ StatsBase. coef (x:: SRResult ) = getfield (x, :k )
7489
75- # We always want something which is a vector or tuple
76- hofs = ! isa (hofs, AbstractVector) ? [hofs] : hofs
90+ StatsBase. rss (x:: SRResult ) = getfield (x, :rss )
7791
78- # Evaluate over the full training data
79- paretos = map (enumerate (hofs)) do (i, hof)
80- SymbolicRegression. calculate_pareto_frontier (X, Y[i, :], hof, eq_options)
92+ StatsBase. dof (x:: SRResult ) = getfield (x, :dof )
93+
94+ StatsBase. nobs (x:: SRResult ) = getfield (x, :nobs )
95+
96+ StatsBase. loglikelihood (x:: SRResult ) = getfield (x, :loglikelihood )
97+
98+ StatsBase. nullloglikelihood (x:: SRResult ) = getfield (x, :nullloglikelihood )
99+
100+ StatsBase. r2 (x:: SRResult ) = r2 (x, :CoxSnell )
101+
102+ function collect_numerical_parameters (eq, options = DataDrivenCommonOptions ())
103+ ps = Any[]
104+ eqs = map (eq) do eqi
105+ _collect_numerical_parameters! (ps, eqi, options)
81106 end
107+ return eqs, ps
108+ end
82109
83- # Trainingerror
84- trainerror = mean (x -> x[end ]. loss, paretos)
85- # Testerror
86- X̃, Ỹ = testdata
87- if ! isempty (X̃)
88- testerror = mean (map (enumerate (hofs)) do (i, hof)
89- doms = SymbolicRegression. calculate_pareto_frontier (X̃,
90- Ỹ[i, :],
91- hof,
92- eq_options)
93- doms[end ]. loss
94- end )
95- retcode = testerror <= abstol ? DDReturnCode (1 ) : DDReturnCode (5 )
110+ function _collect_numerical_parameters! (ps:: AbstractVector , eq, options)
111+ if Symbolics. istree (eq)
112+ args_ = map (Symbolics. arguments (eq)) do (eqi)
113+ _collect_numerical_parameters! (ps, eqi, options)
114+ end
115+ return Symbolics. operation (eq)(args_... )
116+ elseif isa (eq, Number)
117+ pval = round (eq, options. roundingmode, digits = options. digits)
118+ # We do not collect zeros or ones
119+ iszero (pval) && return zero (eltype (pval))
120+ (abs (pval) ≈ 1 ) & return sign (pval) * one (eltype (pval))
121+ p_ = Symbolics. variable (:p , length (ps) + 1 )
122+ p_ = Symbolics. setdefaultval (p_, pval)
123+ p_ = ModelingToolkit. toparam (p_)
124+ push! (ps, p_)
125+ return p_
96126 else
97- testerror = nothing
98- retcode = trainerror <= abstol ? DDReturnCode (1 ) : DDReturnCode (5 )
127+ return eq
99128 end
100-
101- return SRResult (hofs, paretos, testerror, trainerror, retcode)
102129end
103130
104- function convert_to_basis (res:: SRResult , prob)
105- @unpack paretofrontier = res
131+ function convert_to_basis (paretofrontier, prob)
106132 @unpack alg, basis, problem, options = prob
107133 @unpack eq_options = alg
108134 @unpack maxiters, eval_expresssion, generate_symbolic_parameters, digits, roundingmode = options
109135
110- eqs_ = Num .( map (paretofrontier) do dom
111- node_to_symbolic (dom[end ]. tree, eq_options)
112- end )
136+ eqs_ = map (paretofrontier) do dom
137+ node_to_symbolic (dom[end ]. tree, eq_options)
138+ end
113139
114140 # Substitute with the basis elements
115141 atoms = map (xi -> xi. rhs, equations (basis))
116142
117143 subs = Dict ([SymbolicUtils. Sym {LiteralReal} (Symbol (" x$(i) " )) => x
118144 for (i, x) in enumerate (atoms)]. .. )
119- eqs = map (Base. Fix2 (substitute, subs), eqs_)
145+
146+ eqs, ps = collect_numerical_parameters (eqs_)
147+ eqs = map (Base. Fix2 (substitute, subs), eqs)
120148
121149 # Get the lhs
122150 causality, dt = DataDrivenDiffEq. assert_lhs (problem)
@@ -135,40 +163,61 @@ function convert_to_basis(res::SRResult, prob)
135163 eqs = [phi[i] ~ eq for (i, eq) in enumerate (eqs)]
136164 end
137165
138- ps = parameters (basis)
166+ ps_ = parameters (basis)
139167 @unpack p = problem
140168
141169 p_new = map (eachindex (p)) do i
142- DataDrivenDiffEq. _set_default_val (Num (ps [i]), p[i])
170+ DataDrivenDiffEq. _set_default_val (Num (ps_ [i]), p[i])
143171 end
144172
145173 Basis (eqs, states (basis),
146- parameters = p_new, iv = get_iv (basis),
174+ parameters = [ p_new; ps] , iv = get_iv (basis),
147175 controls = controls (basis), observed = observed (basis),
148176 implicits = implicit_variables (basis),
149177 name = gensym (:Basis ),
150178 eval_expression = eval_expresssion)
151179end
152180
181+ # apply the algorithm on each dataset
182+ function (x:: EQSearch )(ps:: InternalDataDrivenProblem{EQSearch} , X, Y)
183+ @unpack problem, testdata, options = ps
184+ @unpack maxiters, abstol = options
185+ @unpack weights, eq_options, numprocs, procs, parallelism, runtests = x
186+
187+ hofs = SymbolicRegression. EquationSearch (X, Y,
188+ niterations = maxiters,
189+ weights = weights,
190+ options = eq_options,
191+ numprocs = numprocs,
192+ procs = procs, parallelism = parallelism,
193+ runtests = runtests)
194+
195+ # We always want something which is a vector or tuple
196+ hofs = ! isa (hofs, AbstractVector) ? [hofs] : hofs
197+
198+ # Evaluate over the full training data
199+ paretos = map (enumerate (hofs)) do (i, hof)
200+ SymbolicRegression. calculate_pareto_frontier (X, Y[i, :], hof, eq_options)
201+ end
202+
203+ return SRResult (ps, hofs, paretos)
204+ end
205+
153206function CommonSolve. solve! (ps:: InternalDataDrivenProblem{EQSearch} )
154207 @unpack alg, basis, testdata, traindata, kwargs = ps
155208 @unpack weights, numprocs, procs, addprocs_function, parallelism, runtests, eq_options = alg
156209 @unpack traindata, testdata, basis, options = ps
157- @unpack maxiters, eval_expresssion, generate_symbolic_parameters, digits, roundingmode = options
210+ @unpack maxiters, eval_expresssion, generate_symbolic_parameters, digits, roundingmode, selector = options
158211 @unpack problem = ps
159212
160213 results = map (traindata) do (X, Y)
161214 alg (ps, X, Y)
162215 end
163216
164- # Get the best result based on test error, if applicable else use testerror
165- sort! (results, by = l2error)
166- # Convert to basis
167- best_res = first (results)
168-
169- new_basis = convert_to_basis (best_res, ps)
217+ idx = argmin (map (selector, results))
218+ best_res = results[idx]
170219
171- DataDrivenSolution (new_basis , problem, alg, results, ps, best_res. retcode)
220+ DataDrivenSolution (best_res . basis , problem, alg, results, ps, best_res. retcode)
172221end
173222
174223export EQSearch
0 commit comments