Skip to content

Commit 2c9deaf

Browse files
Merge pull request #1871 from KristofferC/kc/diff_zero
avoid differentiating expressions that are trivially zero
2 parents d0e7dc0 + a53dca8 commit 2c9deaf

1 file changed

Lines changed: 22 additions & 2 deletions

File tree

src/diff.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ function executediff(D::Differential, arg::BasicSymbolic{VartypeT}; simplify=fal
387387
summed_args = SymbolicUtils.ArgsT{VartypeT}()
388388
sizehint!(summed_args, length(inner_args))
389389
# We know `D.x` is in `arg`, so the derivative is not identically zero.
390-
# `arg` cannot be `D.x` since, that would have also early exited.
390+
# `arg` cannot be `D.x` since, that would have also early exited.
391391
for (i, a) in enumerate(inner_args)
392392
der = derivative_idx(arr, i)::Union{Nothing, SymbolicT}
393393
if isequal(a, D.x)
@@ -685,14 +685,34 @@ an array of variable expressions.
685685
686686
All other keyword arguments are forwarded to `expand_derivatives`.
687687
"""
688+
# Check if any variable in `varset` depends on `v`, either directly or
689+
# because it is a function whose arguments contain `v`.
690+
function _depends_on(varset::Set{SymbolicT}, v::SymbolicT)
691+
v in varset && return true
692+
for s in varset
693+
if SymbolicUtils.iscall(s)
694+
for arg in SymbolicUtils.arguments(s)
695+
isequal(arg, v) && return true
696+
end
697+
end
698+
end
699+
return false
700+
end
701+
688702
function jacobian(ops::AbstractVector, vars::AbstractVector{SymbolicT};
689703
simplify=false, scalarize::Union{Val{true}, Val{false}}=Val(true), kwargs...)
690704
if scalarize isa Val{true}
691705
ops = Symbolics.scalarize(ops)
692706
vars = Symbolics.scalarize(vars)
693707
end
708+
# Pre-compute variable sets to skip differentiating trivially zero Jacobian entries
709+
op_varsets = Vector{Set{SymbolicT}}(undef, length(ops))
710+
for i in eachindex(ops)
711+
op_varsets[i] = SymbolicUtils.search_variables(ops[i])
712+
end
694713
result = fill(COMMON_ZERO, length(ops), length(vars))
695714
for i in eachindex(ops), j in eachindex(vars)
715+
_depends_on(op_varsets[i], vars[j]) || continue
696716
result[i, j] = executediff(Differential(vars[j]), ops[i]; simplify, kwargs...)
697717
end
698718
return result
@@ -921,7 +941,7 @@ function hessian(O, vars::AbstractVector; simplify=false, kwargs...)
921941
H
922942
end
923943

924-
hessian(O, vars::Arr; kwargs...) = hessian(O, collect(vars); kwargs...)
944+
hessian(O, vars::Arr; kwargs...) = hessian(O, collect(vars); kwargs...)
925945

926946
isidx(x) = unwrap_const(x) isa TermCombination
927947

0 commit comments

Comments
 (0)