Skip to content

Commit 12f9a44

Browse files
Merge pull request #469 from SciML/symbolicsv5
Complete Symbolics v5 update
2 parents 9615177 + 3831358 commit 12f9a44

File tree

6 files changed

+23
-20
lines changed

6 files changed

+23
-20
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ jobs:
1010
test:
1111
runs-on: ubuntu-latest
1212
strategy:
13+
fail-fast: false
1314
matrix:
1415
group:
1516
- Core

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2020
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2121
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2222
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
23+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
24+
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
2325

2426
[compat]
2527
CommonSolve = "0.2"
@@ -35,6 +37,8 @@ RecipesBase = "1"
3537
Reexport = "1.0"
3638
Setfield = "1"
3739
StatsBase = "0.32.0, 0.33"
40+
Symbolics = "5"
41+
SymbolicUtils = "1"
3842
julia = "1.6"
3943

4044
[extras]

lib/DataDrivenLux/src/custom_priors.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,38 +53,38 @@ function Base.summary(io::IO, d::ObservedDistribution{fixed, D, E}) where {fixed
5353
end
5454

5555
get_init(d::ObservedDistribution) = d.latent_scale
56-
get_scale(d::ObservedDistribution) = transform(d.scale_transformation, d.latent_scale)
56+
get_scale(d::ObservedDistribution) = TransformVariables.transform(d.scale_transformation, d.latent_scale)
5757
get_dist(d::ObservedDistribution{<:Any, D}) where {D} = D
5858

5959
Base.show(io::IO, d::ObservedDistribution) = summary(io, d)
6060

6161
function Distributions.logpdf(d::ObservedDistribution{false}, x::X, x̂::Y,
6262
scale::S) where {X, Y, S <: Number}
63-
sum(map(xs -> d.errormodel(get_dist(d), xs..., transform(d.scale_transformation, scale)),
63+
sum(map(xs -> d.errormodel(get_dist(d), xs..., TransformVariables.transform(d.scale_transformation, scale)),
6464
zip(x, x̂)))
6565
end
6666

6767
function Distributions.logpdf(d::ObservedDistribution{true}, x::X, x̂::Y,
6868
scale::S) where {X, Y, S <: Number}
6969
sum(map(xs -> d.errormodel(get_dist(d), xs...,
70-
transform(d.scale_transformation, d.latent_scale)),
70+
TransformVariables.transform(d.scale_transformation, d.latent_scale)),
7171
zip(x, x̂)))
7272
end
7373

7474
function Distributions.logpdf(d::ObservedDistribution{false}, x::X, x̂::Number,
7575
scale::S) where {X, S <: Number}
7676
sum(map(xs -> d.errormodel(get_dist(d), xs, x̂,
77-
transform(d.scale_transformation, scale)), x))
77+
TransformVariables.transform(d.scale_transformation, scale)), x))
7878
end
7979

8080
function Distributions.logpdf(d::ObservedDistribution{true}, x::X, x̂::Number,
8181
scale::S) where {X, S <: Number}
8282
sum(map(xs -> d.errormodel(get_dist(d), xs, x̂,
83-
transform(d.scale_transformation, d.latent_scale)), x))
83+
TransformVariables.transform(d.scale_transformation, d.latent_scale)), x))
8484
end
8585

8686
function transform_scales(d::ObservedDistribution, scale::T) where {T <: Number}
87-
transform(d.scale_transformation, scale)
87+
TransformVariables.transform(d.scale_transformation, scale)
8888
end
8989

9090
"""
@@ -159,7 +159,7 @@ Base.show(io::IO, p::ParameterDistribution) = summary(io, p)
159159

160160
get_init(p::ParameterDistribution) = p.init
161161
function transform_parameter(p::ParameterDistribution, pval::T) where {T <: Number}
162-
transform(p.transformation, pval)
162+
TransformVariables.transform(p.transformation, pval)
163163
end
164164
get_interval(p::ParameterDistribution) = p.interval
165165

lib/DataDrivenSR/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
1111
[compat]
1212
Reexport = "1.2"
1313
DataDrivenDiffEq = "1"
14-
SymbolicRegression = "0.14"
14+
SymbolicRegression = "0.17"
1515
julia = "1.6"
1616

1717
[extras]

src/DataDrivenDiffEq.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,11 @@ using Setfield
1414

1515
@reexport using ModelingToolkit
1616
using ModelingToolkit: AbstractSystem
17-
using ModelingToolkit: value, operation, arguments, istree, get_observed
18-
using ModelingToolkit.Symbolics
19-
using ModelingToolkit.SymbolicUtils
20-
using ModelingToolkit.Symbolics: scalarize, variable
17+
using SymbolicUtils: operation, arguments, istree, issym
18+
using Symbolics
19+
using Symbolics: scalarize, variable, value
2120
@reexport using ModelingToolkit: states, parameters, independent_variable, observed,
22-
controls, get_iv
21+
controls, get_iv, get_observed
2322

2423
using Random
2524
using QuadGK

src/basis/utils.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
## Create linear independent basis
22
count_operation(x::Number, op::Function, nested::Bool = true) = 0
3-
count_operation(x::Sym, op::Function, nested::Bool = true) = 0
4-
count_operation(x::SymbolicUtils.BasicSymbolic, op::Function, nested::Bool = true) = 0
5-
function count_operation(x::Num, op::Function, nested::Bool = true)
6-
count_operation(value(x), op, nested)
7-
end
8-
9-
function count_operation(x, op::Function, nested::Bool = true)
3+
function count_operation(x::SymbolicUtils.BasicSymbolic, op::Function, nested::Bool = true)
4+
issym(x) && return 0
105
if operation(x) == op
116
if is_unary(op)
127
# Handles sin, cos and stuff
@@ -23,6 +18,10 @@ function count_operation(x, op::Function, nested::Bool = true)
2318
return 0
2419
end
2520

21+
function count_operation(x::Num, op::Function, nested::Bool = true)
22+
count_operation(value(x), op, nested)
23+
end
24+
2625
function count_operation(x, ops::AbstractArray, nested::Bool = true)
2726
return sum([count_operation(x, op, nested) for op in ops])
2827
end

0 commit comments

Comments
 (0)