@@ -37,13 +37,16 @@ function to_options(x::EQSearch)
3737 )
3838end
3939
40+
41+
4042function DiffEqBase. solve (prob:: AbstractDataDrivenProblem , alg:: EQSearch ;
4143 max_iter:: Int = 10 ,
4244 weights = nothing ,
4345 numprocs = nothing , procs = nothing ,
4446 multithreading = false ,
4547 runtests:: Bool = true ,
46- eval_expression = false
48+ eval_expression = false ,
49+ kwargs...
4750 )
4851
4952 opt = to_options (alg)
@@ -60,29 +63,63 @@ function DiffEqBase.solve(prob::AbstractDataDrivenProblem, alg::EQSearch;
6063 hof = SymbolicRegression. EquationSearch (X, Y, niterations = max_iter, weights = weights, options = opt,
6164 numprocs = numprocs, procs = procs, multithreading = multithreading,
6265 runtests = runtests)
63- # Sort the paretofront
64- doms = map (1 : size (Y, 1 )) do i
65- calculateParetoFrontier (X, Y[i, :], hof[i], opt)
66- end
6766
68- build_solution (prob, alg, doms ; eval_expression = eval_expression)
67+ build_solution (prob, alg, hof ; eval_expression = eval_expression)
6968end
7069
70+ function pareto_optimal_equations (hof:: HallOfFame , prob, alg)
71+ return pareto_optimal_equations ([hof], prob, alg)
72+ end
7173
72- function build_solution (prob:: AbstractDataDrivenProblem , alg:: EQSearch , doms; eval_expression = false )
7374
74- opt = to_options (alg)
75+ function pareto_optimal_equations (hof:: Vector{HallOfFame} , prob, alg)
76+
77+ opts = DataDrivenDiffEq. to_options (alg)
78+ y = DataDrivenDiffEq. get_target (prob)
79+ x, _, t, c = DataDrivenDiffEq. get_oop_args (prob)
80+ X = vcat ([x for x in (x, c, permutedims (t)) if ! isempty (x)]. .. )
81+
7582 @variables x[1 : size (prob. X, 1 )] u[1 : size (prob. U,1 )] t
7683 x = Symbolics. scalarize (x)
7784 u = Symbolics. scalarize (u)
78- x_ = [x;u;t]
85+ x_ = Num [x;u;t]
7986
8087 # Build a dict
8188 subs = Dict ([SymbolicUtils. Sym {Number} (Symbol (" x$(i) " )) => x_[i] for i in 1 : size (x_, 1 )]. .. )
82- # Create a variable
83- eqs = vcat (map (x-> node_to_symbolic (x[end ]. tree, opt), doms))
84- eqs = map (x-> substitute (x, subs), eqs)
8589
90+
91+ eqs = map (1 : size (hof, 1 )) do i
92+ @show i
93+ d = calculateParetoFrontier (X, y[i,:], hof[i], opts)
94+ isempty (d) && return Num (0 )
95+ eq_ = node_to_symbolic (last (d). tree, opts)
96+ substitute (eq_, subs)
97+ end
98+
99+ return eqs, x, u, t
100+ end
101+
102+
103+
104+ function build_solution (prob:: AbstractDataDrivenProblem , alg:: EQSearch , hof; eval_expression = false )
105+
106+ # opt = to_options(alg)
107+ #
108+ # @variables x[1:size(prob.X, 1)] u[1:size(prob.U,1)] t
109+ # x = Symbolics.scalarize(x)
110+ # u = Symbolics.scalarize(u)
111+ # x_ = [x;u;t]
112+
113+ # Build a dict
114+ # subs = Dict([SymbolicUtils.Sym{Number}(Symbol("x$(i)")) => x_[i] for i in 1:size(x_, 1)]...)
115+
116+
117+ # Create a variable
118+ # eqs = vcat(map(x->node_to_symbolic(x[end].tree, opt), doms))
119+ # eqs = map(x->substitute(x, subs), eqs)
120+
121+ eqs, x, u, t = pareto_optimal_equations (hof, prob, alg)
122+
86123 lhs, dt = assert_lhs (prob)
87124
88125
@@ -104,10 +141,11 @@ function build_solution(prob::AbstractDataDrivenProblem, alg::EQSearch, doms; ev
104141 Y = res_ (get_oop_args (prob)... )
105142
106143
144+
107145 error = sum (abs2, X- Y, dims = 2 )[:,1 ]
108146 retcode = :converged
109147
110148 return DataDrivenSolution (
111- false , res_, [], retcode, alg, doms , prob, error
149+ false , res_, [], retcode, alg, hof , prob, error
112150 )
113151end
0 commit comments