Skip to content

Commit 75e4648

Browse files
Merge pull request #1681 from JuliaSymbolics/as/fix-bugs
fix: respect new dependent variable semantics in a few places
2 parents 6473d60 + 9432543 commit 75e4648

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ SymPy = "2.2"
114114
SymPyPythonCall = "0.5"
115115
SymbolicIndexingInterface = "0.3.14"
116116
SymbolicLimits = "0.2.2"
117-
SymbolicUtils = "4.3"
117+
SymbolicUtils = "4.4"
118118
TermInterface = "2"
119119
julia = "1.10"
120120

src/diff.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,24 @@ function substitute_in_deriv(ex, rules; kw...)
256256
substitute(ex, rules; kw..., filterer = symdiff_substitute_filter)
257257
end
258258

259+
function deriv_and_depvar_substitute_filter(ex::SymbolicT)
260+
SymbolicUtils.default_substitute_filter(ex) || @match ex begin
261+
BSImpl.Term(; f) && if f isa Differential end => true
262+
BSImpl.Term(; f) && if f isa SymbolicT end => true
263+
_ => false
264+
end
265+
end
266+
267+
"""
268+
$(TYPEDSIGNATURES)
269+
270+
Identical to `substitute` except it also substitutes inside `Differential` operator
271+
applications and dependent variables.
272+
"""
273+
function substitute_in_deriv_and_depvar(ex, rules; kw...)
274+
substitute(ex, rules; kw..., filterer = deriv_and_depvar_substitute_filter)
275+
end
276+
259277
function chain_diff(D::Differential, arg::BasicSymbolic{VartypeT}, inner_args::SymbolicUtils.ROArgsT{VartypeT}; kw...)
260278
any(isequal(D.x), inner_args) && return D(arg)
261279

@@ -358,12 +376,12 @@ function executediff(D::Differential, arg::BasicSymbolic{VartypeT}; simplify=fal
358376
summed_args = SymbolicUtils.ArgsT{VartypeT}()
359377
inner_function = arguments(arg)[1]
360378
if iscall(a) || isequal(a, D.x)
361-
t1 = substitute_in_deriv(inner_function, Dict(domainvars => a))
379+
t1 = substitute_in_deriv_and_depvar(inner_function, Dict(domainvars => a))
362380
t2 = executediff(D, a; simplify, throw_no_derivative)
363381
push!(summed_args, -t1*t2)
364382
end
365383
if iscall(b) || isequal(b, D.x)
366-
t1 = substitute_in_deriv(inner_function, Dict(domainvars => b))
384+
t1 = substitute_in_deriv_and_depvar(inner_function, Dict(domainvars => b))
367385
t2 = executediff(D, b; simplify, throw_no_derivative)
368386
push!(summed_args, t1*t2)
369387
end

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ if GROUP == "All" || GROUP == "Core"
6161
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
6262
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
6363
@safetestset "Show Test" begin include("show.jl") end
64-
@safetestset "Utility Function Test" begin include("utils.jl") end
6564
@safetestset "RootFinding solver" begin include("solver.jl") end
6665
@safetestset "Function inverses test" begin include("inverse.jl") end
6766
@safetestset "Taylor Series Test" begin include("taylor.jl") end

test/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262
end
6363

6464
@testset "Issue#1342 substitute working on called symbolics" begin
65-
@variables p(..) x y
65+
@variables p(::Real) x y
6666
arg = unwrap(substitute(p(x), [p => identity]))
6767
@test iscall(arg) && operation(arg) == identity && isequal(only(arguments(arg)), x)
6868
@test unwrap_const(unwrap(substitute(p(x), [p => sqrt, x => 4.0]; fold = Val(true)))) 2.0

0 commit comments

Comments
 (0)