Skip to content

feat: compute jac_prototype for SDEFunction #3535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 8, 2025
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ export structural_simplify, expand_connections, linearize, linearization_functio
LinearizationProblem
export solve

export calculate_jacobian, generate_jacobian, generate_function, generate_custom_function
export calculate_jacobian, generate_jacobian, generate_function, generate_custom_function, generate_W
export calculate_control_jacobian, generate_control_jacobian
export calculate_tgrad, generate_tgrad
export calculate_gradient, generate_gradient
Expand Down
1 change: 0 additions & 1 deletion src/linearization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,6 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
if !iszero(Bs)
if !allow_input_derivatives
der_inds = findall(vec(any(!iszero, Bs, dims = 1)))
@show typeof(der_inds)
error("Input derivatives appeared in expressions (-g_z\\g_u != 0), the following inputs appeared differentiated: $(ModelingToolkit.inputs(sys)[der_inds]). Call `linearize_symbolic` with keyword argument `allow_input_derivatives = true` to allow this and have the returned `B` matrix be of double width ($(2nu)), where the last $nu inputs are the derivatives of the first $nu inputs.")
end
B = [B [zeros(nx, nu); Bs]]
Expand Down
4 changes: 4 additions & 0 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ function generate_derivative_variables!(
v_t = add_dd_variable!(structure, fullvars, x_t, dv)
# Add `D(x) - x_t ~ 0` to the graph
dummy_eq = add_dd_equation!(structure, neweqs, 0 ~ dx - x_t, dv, v_t)
# Update graph to say, all the equations featuring D(x) also feature x_t
for e in 𝑑neighbors(graph, dv)
add_edge!(graph, e, v_t)
end

# Update matching
push!(var_eq_matching, unassigned)
Expand Down
60 changes: 51 additions & 9 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ function calculate_jacobian(sys::AbstractODESystem;

if sparse
jac = sparsejacobian(rhs, dvs, simplify = simplify)
W_s = W_sparsity(sys)
(Is, Js, Vs) = findnz(W_s)
# Add nonzeros of W as non-structural zeros of the Jacobian (to ensure equal results for oop and iip Jacobian.)
for (i, j) in zip(Is, Js)
iszero(jac[i, j]) && begin
jac[i, j] = 1
jac[i, j] = 0
end
end
else
jac = jacobian(rhs, dvs, simplify = simplify)
end
Expand Down Expand Up @@ -126,6 +135,35 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
dvs,
p...,
get_iv(sys);
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity),
kwargs...)
end

function assert_jac_length_header(sys)
W = W_sparsity(sys)
identity, expr -> Func([expr.args...], [], LiteralExpr(quote
@assert $(findnz)($(expr.args[1]))[1:2] == $(findnz)($W)[1:2]
$(expr.body)
end))
end

function generate_W(sys::AbstractODESystem, γ = 1., dvs = unknowns(sys),
ps = parameters(sys; initial_parameters = true);
simplify = false, sparse = false, kwargs...)
@variables ˍ₋gamma
M = calculate_massmatrix(sys; simplify)
sparse && (M = SparseArrays.sparse(M))
J = calculate_jacobian(sys; simplify, sparse, dvs)
W = ˍ₋gamma*M + J

p = reorder_parameters(sys, ps)
return build_function_wrapper(sys, W,
dvs,
p...,
ˍ₋gamma,
get_iv(sys);
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity),
p_end = 1 + length(p),
kwargs...)
end

Expand Down Expand Up @@ -264,6 +302,14 @@ function jacobian_dae_sparsity(sys::AbstractODESystem)
J1 + J2
end

function W_sparsity(sys::AbstractODESystem)
jac_sparsity = jacobian_sparsity(sys)
(n, n) = size(jac_sparsity)
M = calculate_massmatrix(sys)
M_sparsity = M isa UniformScaling ? sparse(I(n)) : SparseMatrixCSC{Bool, Int64}((!iszero).(M))
jac_sparsity .| M_sparsity
end

function isautonomous(sys::AbstractODESystem)
tgrad = calculate_tgrad(sys; simplify = true)
all(iszero, tgrad)
Expand Down Expand Up @@ -368,15 +414,11 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
observedfun = ObservedFunctionCache(
sys; steady_state, eval_expression, eval_module, checkbounds, cse)

jac_prototype = if sparse
if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
if jac
similar(calculate_jacobian(sys, sparse = sparse), uElType)
else
similar(jacobian_sparsity(sys), uElType)
end
W_prototype = similar(W_sparsity(sys), uElType)
else
nothing
W_prototype = nothing
end

@set! sys.split_idxs = split_idxs
Expand All @@ -386,9 +428,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
jac = _jac === nothing ? nothing : _jac,
tgrad = _tgrad === nothing ? nothing : _tgrad,
mass_matrix = _M,
jac_prototype = jac_prototype,
jac_prototype = W_prototype,
observed = observedfun,
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
sparsity = sparsity ? W_sparsity(sys) : nothing,
analytic = analytic,
initialization_data)
end
Expand Down
1 change: 0 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
cons = get_constraintsystem(sys)
cons !== nothing && push!(conssystems, cons)
end
@show conssystems
@set! constraintsystem.systems = conssystems
end

Expand Down
38 changes: 30 additions & 8 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ struct SDESystem <: AbstractODESystem
"""
is_dde::Bool
isscheduled::Bool
tearing_state::Any

function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
Expand All @@ -173,7 +174,8 @@ struct SDESystem <: AbstractODESystem
metadata = nothing, gui_metadata = nothing, namespacing = true,
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
is_dde = false,
isscheduled = false;
isscheduled = false,
tearing_state = nothing;
checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckComponents) > 0
check_independent_variables([iv])
Expand All @@ -198,7 +200,7 @@ struct SDESystem <: AbstractODESystem
ctrl_jac, Wfact, Wfact_t, name, description, systems,
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
devents, parameter_dependencies, assertions, metadata, gui_metadata, namespacing,
complete, index_cache, parent, is_scalar_noise, is_dde, isscheduled)
complete, index_cache, parent, is_scalar_noise, is_dde, isscheduled, tearing_state)
end
end

Expand Down Expand Up @@ -593,6 +595,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
u0 = nothing;
version = nothing, tgrad = false, sparse = false,
jac = false, Wfact = false, eval_expression = false,
sparsity = false, analytic = nothing,
eval_module = @__MODULE__,
checkbounds = false, initialization_data = nothing,
cse = true, kwargs...) where {iip, specialize}
Expand Down Expand Up @@ -642,6 +645,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
end

M = calculate_massmatrix(sys)
if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
W_prototype = similar(W_sparsity(sys), uElType)
else
W_prototype = nothing
end

_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

observedfun = ObservedFunctionCache(
Expand All @@ -651,10 +661,14 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
sys = sys,
jac = _jac === nothing ? nothing : _jac,
tgrad = _tgrad === nothing ? nothing : _tgrad,
mass_matrix = _M,
jac_prototype = W_prototype,
observed = observedfun,
sparsity = sparsity ? W_sparsity(sys) : nothing,
analytic = analytic,
Wfact = _Wfact === nothing ? nothing : _Wfact,
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
mass_matrix = _M, initialization_data,
observed = observedfun)
initialization_data)
end

"""
Expand Down Expand Up @@ -724,6 +738,16 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
_jac = :nothing
end

M = calculate_massmatrix(sys)
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
W_prototype = similar(W_sparsity(sys), uElType)
else
W_prototype = nothing
end

if Wfact
tmp_Wfact, tmp_Wfact_t = generate_factorized_W(
sys, dvs, ps; expression = Val{true},
Expand All @@ -734,20 +758,18 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
_Wfact, _Wfact_t = :nothing, :nothing
end

M = calculate_massmatrix(sys)

_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

ex = quote
f = $f
g = $g
tgrad = $_tgrad
jac = $_jac
W_prototype = $W_prototype
Wfact = $_Wfact
Wfact_t = $_Wfact_t
M = $_M
SDEFunction{$iip}(f, g,
jac = jac,
jac_prototype = W_prototype,
tgrad = tgrad,
Wfact = Wfact,
Wfact_t = Wfact_t,
Expand Down
4 changes: 3 additions & 1 deletion src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,13 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
end

noise_eqs = StructuralTransformations.tearing_substitute_expr(ode_sys, noise_eqs)
return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs,
ssys = SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs,
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys);
name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys),
parameter_dependencies = parameter_dependencies(sys), assertions = assertions(sys),
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
@set! ssys.tearing_state = get_tearing_state(ode_sys)
return ssys
end
end

Expand Down
44 changes: 38 additions & 6 deletions test/jacobiansparsity.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
using OrdinaryDiffEq, ModelingToolkit, Test, SparseArrays
using ModelingToolkit, SparseArrays, OrdinaryDiffEq

N = 3
xyd_brusselator = range(0, stop = 1, length = N)
brusselator_f(x, y, t) = (((x - 0.3)^2 + (y - 0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.0
limit(a, N) = ModelingToolkit.ifelse(a == N + 1, 1, ModelingToolkit.ifelse(a == 0, N, a))
lim(a, N) = ModelingToolkit.ifelse(a == N + 1, 1, ModelingToolkit.ifelse(a == 0, N, a))
function brusselator_2d_loop(du, u, p, t)
A, B, alpha, dx = p
alpha = alpha / dx^2
@inbounds for I in CartesianIndices((N, N))
i, j = Tuple(I)
x, y = xyd_brusselator[I[1]], xyd_brusselator[I[2]]
ip1, im1, jp1, jm1 = limit(i + 1, N), limit(i - 1, N), limit(j + 1, N),
limit(j - 1, N)
ip1, im1, jp1, jm1 = lim(i + 1, N), lim(i - 1, N), lim(j + 1, N),
lim(j - 1, N)
du[i, j, 1] = alpha * (u[im1, j, 1] + u[ip1, j, 1] + u[i, jp1, 1] + u[i, jm1, 1] -
4u[i, j, 1]) +
B + u[i, j, 1]^2 * u[i, j, 2] - (A + 1) * u[i, j, 1] +
Expand Down Expand Up @@ -51,7 +51,7 @@ JP = prob.f.jac_prototype

# test sparse jacobian
prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = true)
@test_nowarn solve(prob, Rosenbrock23())
#@test_nowarn solve(prob, Rosenbrock23())
@test findnz(calculate_jacobian(sys, sparse = true))[1:2] ==
findnz(prob.f.jac_prototype)[1:2]

Expand All @@ -74,11 +74,43 @@ f = DiffEqBase.ODEFunction(sys, u0 = nothing, sparse = true, jac = false)
# test when u0 is not Float64
u0 = similar(init_brusselator_2d(xyd_brusselator), Float32)
prob_ode_brusselator_2d = ODEProblem(brusselator_2d_loop,
u0, (0.0, 11.5), p)
u0, (0.0, 11.5), p)
sys = complete(modelingtoolkitize(prob_ode_brusselator_2d))

prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = false)
@test eltype(prob.f.jac_prototype) == Float32

prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = true)
@test eltype(prob.f.jac_prototype) == Float32

@testset "W matrix sparsity" begin
t = ModelingToolkit.t_nounits
D = ModelingToolkit.D_nounits
@parameters g
@variables x(t) y(t) λ(t)
eqs = [D(D(x)) ~ λ * x
D(D(y)) ~ λ * y - g
x^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)

u0 = [x => 1, y => 0]
prob = ODEProblem(pend, u0, (0, 11.5), [g => 1], guesses = [λ => 1], sparse = true, jac = true)
jac, jac! = generate_jacobian(pend; expression = Val{false}, sparse = true)
jac_prototype = ModelingToolkit.jacobian_sparsity(pend)
W_prototype = ModelingToolkit.W_sparsity(pend)
@test nnz(W_prototype) == nnz(jac_prototype) + 2

# jac_prototype should be the same as W_prototype
@test findnz(prob.f.jac_prototype)[1:2] == findnz(W_prototype)[1:2]

u = zeros(5)
p = prob.p
t = 0.0
@test_throws AssertionError jac!(similar(jac_prototype, Float64), u, p, t)

W, W! = generate_W(pend; expression = Val{false}, sparse = true)
γ = .1
M = sparse(calculate_massmatrix(pend))
@test_throws AssertionError W!(similar(jac_prototype, Float64), u, p, γ, t)
@test W!(similar(W_prototype, Float64), u, p, γ, t) == 0.1 * M + jac!(similar(W_prototype, Float64), u, p, t)
end
2 changes: 1 addition & 1 deletion test/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ eqs = [0 ~ σ * (y - x) * h,
@test eval(toexpr(ns)) == ns
test_nlsys_inference("standard", ns, (x, y, z), (σ, ρ, β))
@test begin
f = eval(generate_function(ns, [x, y, z], [σ, ρ, β])[2])
f = generate_function(ns, [x, y, z], [σ, ρ, β], expression = Val{false})[2]
du = [0.0, 0.0, 0.0]
f(du, [1, 2, 3], [1, 2, 3])
du ≈ [1, -3, -7]
Expand Down
Loading