From a53dca8226ed60b8b3564cb8ad7a7be57dc80305 Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Tue, 19 May 2026 14:09:27 +0200 Subject: [PATCH] avoid differentiating expressions that are trivially zero --- src/diff.jl | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/diff.jl b/src/diff.jl index f7ef9cde7..beab93a8b 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -387,7 +387,7 @@ function executediff(D::Differential, arg::BasicSymbolic{VartypeT}; simplify=fal summed_args = SymbolicUtils.ArgsT{VartypeT}() sizehint!(summed_args, length(inner_args)) # We know `D.x` is in `arg`, so the derivative is not identically zero. - # `arg` cannot be `D.x` since, that would have also early exited. + # `arg` cannot be `D.x` since, that would have also early exited. for (i, a) in enumerate(inner_args) der = derivative_idx(arr, i)::Union{Nothing, SymbolicT} if isequal(a, D.x) @@ -685,14 +685,34 @@ an array of variable expressions. All other keyword arguments are forwarded to `expand_derivatives`. """ +# Check if any variable in `varset` depends on `v`, either directly or +# because it is a function whose arguments contain `v`. +function _depends_on(varset::Set{SymbolicT}, v::SymbolicT) + v in varset && return true + for s in varset + if SymbolicUtils.iscall(s) + for arg in SymbolicUtils.arguments(s) + isequal(arg, v) && return true + end + end + end + return false +end + function jacobian(ops::AbstractVector, vars::AbstractVector{SymbolicT}; simplify=false, scalarize::Union{Val{true}, Val{false}}=Val(true), kwargs...) if scalarize isa Val{true} ops = Symbolics.scalarize(ops) vars = Symbolics.scalarize(vars) end + # Pre-compute variable sets to skip differentiating trivially zero Jacobian entries + op_varsets = Vector{Set{SymbolicT}}(undef, length(ops)) + for i in eachindex(ops) + op_varsets[i] = SymbolicUtils.search_variables(ops[i]) + end result = fill(COMMON_ZERO, length(ops), length(vars)) for i in eachindex(ops), j in eachindex(vars) + _depends_on(op_varsets[i], vars[j]) || continue result[i, j] = executediff(Differential(vars[j]), ops[i]; simplify, kwargs...) end return result @@ -921,7 +941,7 @@ function hessian(O, vars::AbstractVector; simplify=false, kwargs...) H end -hessian(O, vars::Arr; kwargs...) = hessian(O, collect(vars); kwargs...) +hessian(O, vars::Arr; kwargs...) = hessian(O, collect(vars); kwargs...) isidx(x) = unwrap_const(x) isa TermCombination