Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ function executediff(D::Differential, arg::BasicSymbolic{VartypeT}; simplify=fal
summed_args = SymbolicUtils.ArgsT{VartypeT}()
sizehint!(summed_args, length(inner_args))
# We know `D.x` is in `arg`, so the derivative is not identically zero.
# `arg` cannot be `D.x` since, that would have also early exited.
# `arg` cannot be `D.x` since, that would have also early exited.
for (i, a) in enumerate(inner_args)
der = derivative_idx(arr, i)::Union{Nothing, SymbolicT}
if isequal(a, D.x)
Expand Down Expand Up @@ -685,14 +685,34 @@ an array of variable expressions.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
# Check if any variable in `varset` depends on `v`, either directly or
# because it is a function whose arguments contain `v`.
function _depends_on(varset::Set{SymbolicT}, v::SymbolicT)
v in varset && return true
for s in varset
if SymbolicUtils.iscall(s)
for arg in SymbolicUtils.arguments(s)
isequal(arg, v) && return true
end
end
end
return false
end

function jacobian(ops::AbstractVector, vars::AbstractVector{SymbolicT};
simplify=false, scalarize::Union{Val{true}, Val{false}}=Val(true), kwargs...)
if scalarize isa Val{true}
ops = Symbolics.scalarize(ops)
vars = Symbolics.scalarize(vars)
end
# Pre-compute variable sets to skip differentiating trivially zero Jacobian entries
op_varsets = Vector{Set{SymbolicT}}(undef, length(ops))
for i in eachindex(ops)
op_varsets[i] = SymbolicUtils.search_variables(ops[i])
end
result = fill(COMMON_ZERO, length(ops), length(vars))
for i in eachindex(ops), j in eachindex(vars)
_depends_on(op_varsets[i], vars[j]) || continue
result[i, j] = executediff(Differential(vars[j]), ops[i]; simplify, kwargs...)
end
return result
Expand Down Expand Up @@ -921,7 +941,7 @@ function hessian(O, vars::AbstractVector; simplify=false, kwargs...)
H
end

hessian(O, vars::Arr; kwargs...) = hessian(O, collect(vars); kwargs...)
hessian(O, vars::Arr; kwargs...) = hessian(O, collect(vars); kwargs...)

isidx(x) = unwrap_const(x) isa TermCombination

Expand Down
Loading