Skip to content

Commit 9a48907

Browse files
authored
Minor fixed before release (#224)
* Quickfix dispatch * Adapt cartpole tests like suggested * Adapt Readme * Update julia version * Add the print
1 parent 8f74b08 commit 9a48907

File tree

4 files changed

+26
-17
lines changed

4 files changed

+26
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ QuadGK = "2.4"
2929
Reexport = "1.0"
3030
StatsBase = "0.32.0, 0.33"
3131
Symbolics = "0.1"
32-
julia = "^1.5.0"
32+
julia = "^1.6.0"
3333

3434
[extras]
3535
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"

README.md

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,35 +38,45 @@ end
3838

3939
u0 = [1.0;0.0;0.0]
4040
tspan = (0.0,100.0)
41-
dt = 0.005
41+
dt = 0.1
4242
prob = ODEProblem(lorenz,u0,tspan)
4343
sol = solve(prob, Tsit5(), saveat = dt, progress = true)
4444

45-
# Differential data from equations
46-
X = Array(sol)
47-
DX = similar(X)
48-
for (i, xi) in enumerate(eachcol(X))
49-
DX[:,i] = lorenz(xi, [], 0.0)
50-
end
5145

5246
## Start the automatic discovery
53-
ddprob = ContinuousDataDrivenProblem(X, sol.t, DX = DX)
47+
ddprob = ContinuousDataDrivenProblem(sol)
5448

5549
@variables t x(t) y(t) z(t)
5650
u = [x;y;z]
5751
basis = Basis(polynomial_basis(u, 5), u, iv = t)
5852
opt = STLSQ(exp10.(-5:0.1:-1))
5953
ddsol = solve(ddprob, basis, opt, normalize = true)
60-
system = result(ddsol)
54+
print(ddsol, Val{true})
6155
```
6256

6357
```
64-
Model ##Basis#350 with 3 equations
65-
x(t) y(t) z(t)
58+
Explicit Result
59+
Solution with 3 equations and 7 parameters.
60+
Returncode: sucess
61+
Sparsity: 7.0
62+
L2 Norm Error: 26.7343984476783
63+
AICC: 1.0013570199499398
64+
65+
Model ##Basis#366 with 3 equations
66+
States : x(t) y(t) z(t)
6667
Parameters : 7
6768
Independent variable: t
6869
Equations
6970
Differential(t)(x(t)) = p₁*x(t) + p₂*y(t)
7071
Differential(t)(y(t)) = p₃*x(t) + p₄*y(t) + p₅*x(t)*z(t)
7172
Differential(t)(z(t)) = p₇*z(t) + p₆*x(t)*y(t)
73+
74+
Parameters:
75+
p₁ : -10.0
76+
p₂ : 10.0
77+
p₃ : 28.0
78+
p₄ : -1.0
79+
p₅ : -1.0
80+
p₆ : 1.0
81+
p₇ : -2.7
7282
```

src/solution.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ function Base.summary(io::IO, r::DataDrivenSolution)
9595
end
9696

9797

98-
function Base.print(io::IO, r::DataDrivenSolution, fullview::DataType = Val{false})
98+
function Base.print(io::IO, r::DataDrivenSolution, fullview::DataType)
9999

100-
fullview == Val{false} && return summary(io, r)
100+
fullview != Val{true} && return summary(io, r)
101101

102102
is_implicit(r) ? println(io,"Implicit Result") : println(io,"Explicit Result")
103103
println(io, "Solution with $(length(r.res.eqs)) equations and $(length(r.ps)) parameters.")

test/sindy/cartpole.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
function cart_pole(u, p, t)
32
du = similar(u)
43
F = -0.2 + 0.5*sin(6*t) # the input
@@ -40,7 +39,7 @@ push!(polys, sin.(u[1]).*u[3:4].^2...)
4039
push!(polys, sin.(u[1]).*cos.(u[1])...)
4140
push!(polys, sin.(u[1]).*cos.(u[1]).*u[3:4]...)
4241
push!(polys, sin.(u[1]).*cos.(u[1]).*u[3:4].^2...)
43-
implicits = [du; du .* cos(u[1]); du .* cos(u[1])^2; polys]
42+
implicits = [du; du[1] .* u; du[2] .* u; du .* cos(u[1]); du .* cos(u[1])^2; polys]
4443
push!(implicits, x...)
4544
push!(implicits, x[1]*cos(u[1]))
4645
push!(implicits, x[1]*sin(u[1]))
@@ -60,4 +59,4 @@ m = metrics(res)
6059

6160
@test m.Sparsity == 10
6261
@test m.AICC < 180.0
63-
@test m.Error < 100.0
62+
@test m.Error < 1.0

0 commit comments

Comments
 (0)