Skip to content

Commit d5f3dc1

Browse files
Merge pull request #3535 from vyudu/sdejac
feat: compute jac_prototype for SDEFunction
2 parents cd558a2 + 66f5d59 commit d5f3dc1

File tree

7 files changed

+127
-26
lines changed

7 files changed

+127
-26
lines changed

Diff for: src/ModelingToolkit.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ export structural_simplify, expand_connections, linearize, linearization_functio
303303
LinearizationProblem
304304
export solve
305305

306-
export calculate_jacobian, generate_jacobian, generate_function, generate_custom_function
306+
export calculate_jacobian, generate_jacobian, generate_function, generate_custom_function, generate_W
307307
export calculate_control_jacobian, generate_control_jacobian
308308
export calculate_tgrad, generate_tgrad
309309
export calculate_gradient, generate_gradient

Diff for: src/linearization.jl

-1
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,6 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
535535
if !iszero(Bs)
536536
if !allow_input_derivatives
537537
der_inds = findall(vec(any(!iszero, Bs, dims = 1)))
538-
@show typeof(der_inds)
539538
error("Input derivatives appeared in expressions (-g_z\\g_u != 0), the following inputs appeared differentiated: $(ModelingToolkit.inputs(sys)[der_inds]). Call `linearize_symbolic` with keyword argument `allow_input_derivatives = true` to allow this and have the returned `B` matrix be of double width ($(2nu)), where the last $nu inputs are the derivatives of the first $nu inputs.")
540539
end
541540
B = [B [zeros(nx, nu); Bs]]

Diff for: src/structural_transformation/symbolics_tearing.jl

+4
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,10 @@ function generate_derivative_variables!(
403403
v_t = add_dd_variable!(structure, fullvars, x_t, dv)
404404
# Add `D(x) - x_t ~ 0` to the graph
405405
dummy_eq = add_dd_equation!(structure, neweqs, 0 ~ dx - x_t, dv, v_t)
406+
# Update graph to say, all the equations featuring D(x) also feature x_t
407+
for e in 𝑑neighbors(graph, dv)
408+
add_edge!(graph, e, v_t)
409+
end
406410

407411
# Update matching
408412
push!(var_eq_matching, unassigned)

Diff for: src/systems/diffeqs/abstractodesystem.jl

+51-9
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ function calculate_jacobian(sys::AbstractODESystem;
7373

7474
if sparse
7575
jac = sparsejacobian(rhs, dvs, simplify = simplify)
76+
W_s = W_sparsity(sys)
77+
(Is, Js, Vs) = findnz(W_s)
78+
# Add nonzeros of W as non-structural zeros of the Jacobian (to ensure equal results for oop and iip Jacobian.)
79+
for (i, j) in zip(Is, Js)
80+
iszero(jac[i, j]) && begin
81+
jac[i, j] = 1
82+
jac[i, j] = 0
83+
end
84+
end
7685
else
7786
jac = jacobian(rhs, dvs, simplify = simplify)
7887
end
@@ -126,6 +135,35 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
126135
dvs,
127136
p...,
128137
get_iv(sys);
138+
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity),
139+
kwargs...)
140+
end
141+
142+
function assert_jac_length_header(sys)
143+
W = W_sparsity(sys)
144+
identity, expr -> Func([expr.args...], [], LiteralExpr(quote
145+
@assert $(findnz)($(expr.args[1]))[1:2] == $(findnz)($W)[1:2]
146+
$(expr.body)
147+
end))
148+
end
149+
150+
function generate_W(sys::AbstractODESystem, γ = 1., dvs = unknowns(sys),
151+
ps = parameters(sys; initial_parameters = true);
152+
simplify = false, sparse = false, kwargs...)
153+
@variables ˍ₋gamma
154+
M = calculate_massmatrix(sys; simplify)
155+
sparse && (M = SparseArrays.sparse(M))
156+
J = calculate_jacobian(sys; simplify, sparse, dvs)
157+
W = ˍ₋gamma*M + J
158+
159+
p = reorder_parameters(sys, ps)
160+
return build_function_wrapper(sys, W,
161+
dvs,
162+
p...,
163+
ˍ₋gamma,
164+
get_iv(sys);
165+
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity),
166+
p_end = 1 + length(p),
129167
kwargs...)
130168
end
131169

@@ -264,6 +302,14 @@ function jacobian_dae_sparsity(sys::AbstractODESystem)
264302
J1 + J2
265303
end
266304

305+
function W_sparsity(sys::AbstractODESystem)
306+
jac_sparsity = jacobian_sparsity(sys)
307+
(n, n) = size(jac_sparsity)
308+
M = calculate_massmatrix(sys)
309+
M_sparsity = M isa UniformScaling ? sparse(I(n)) : SparseMatrixCSC{Bool, Int64}((!iszero).(M))
310+
jac_sparsity .| M_sparsity
311+
end
312+
267313
function isautonomous(sys::AbstractODESystem)
268314
tgrad = calculate_tgrad(sys; simplify = true)
269315
all(iszero, tgrad)
@@ -368,15 +414,11 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
368414
observedfun = ObservedFunctionCache(
369415
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
370416

371-
jac_prototype = if sparse
417+
if sparse
372418
uElType = u0 === nothing ? Float64 : eltype(u0)
373-
if jac
374-
similar(calculate_jacobian(sys, sparse = sparse), uElType)
375-
else
376-
similar(jacobian_sparsity(sys), uElType)
377-
end
419+
W_prototype = similar(W_sparsity(sys), uElType)
378420
else
379-
nothing
421+
W_prototype = nothing
380422
end
381423

382424
@set! sys.split_idxs = split_idxs
@@ -386,9 +428,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
386428
jac = _jac === nothing ? nothing : _jac,
387429
tgrad = _tgrad === nothing ? nothing : _tgrad,
388430
mass_matrix = _M,
389-
jac_prototype = jac_prototype,
431+
jac_prototype = W_prototype,
390432
observed = observedfun,
391-
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
433+
sparsity = sparsity ? W_sparsity(sys) : nothing,
392434
analytic = analytic,
393435
initialization_data)
394436
end

Diff for: src/systems/diffeqs/sdesystem.jl

+30-8
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ struct SDESystem <: AbstractODESystem
164164
"""
165165
is_dde::Bool
166166
isscheduled::Bool
167+
tearing_state::Any
167168

168169
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
169170
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
@@ -173,7 +174,8 @@ struct SDESystem <: AbstractODESystem
173174
metadata = nothing, gui_metadata = nothing, namespacing = true,
174175
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
175176
is_dde = false,
176-
isscheduled = false;
177+
isscheduled = false,
178+
tearing_state = nothing;
177179
checks::Union{Bool, Int} = true)
178180
if checks == true || (checks & CheckComponents) > 0
179181
check_independent_variables([iv])
@@ -198,7 +200,7 @@ struct SDESystem <: AbstractODESystem
198200
ctrl_jac, Wfact, Wfact_t, name, description, systems,
199201
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
200202
devents, parameter_dependencies, assertions, metadata, gui_metadata, namespacing,
201-
complete, index_cache, parent, is_scalar_noise, is_dde, isscheduled)
203+
complete, index_cache, parent, is_scalar_noise, is_dde, isscheduled, tearing_state)
202204
end
203205
end
204206

@@ -593,6 +595,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
593595
u0 = nothing;
594596
version = nothing, tgrad = false, sparse = false,
595597
jac = false, Wfact = false, eval_expression = false,
598+
sparsity = false, analytic = nothing,
596599
eval_module = @__MODULE__,
597600
checkbounds = false, initialization_data = nothing,
598601
cse = true, kwargs...) where {iip, specialize}
@@ -642,6 +645,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
642645
end
643646

644647
M = calculate_massmatrix(sys)
648+
if sparse
649+
uElType = u0 === nothing ? Float64 : eltype(u0)
650+
W_prototype = similar(W_sparsity(sys), uElType)
651+
else
652+
W_prototype = nothing
653+
end
654+
645655
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
646656

647657
observedfun = ObservedFunctionCache(
@@ -651,10 +661,14 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
651661
sys = sys,
652662
jac = _jac === nothing ? nothing : _jac,
653663
tgrad = _tgrad === nothing ? nothing : _tgrad,
664+
mass_matrix = _M,
665+
jac_prototype = W_prototype,
666+
observed = observedfun,
667+
sparsity = sparsity ? W_sparsity(sys) : nothing,
668+
analytic = analytic,
654669
Wfact = _Wfact === nothing ? nothing : _Wfact,
655670
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
656-
mass_matrix = _M, initialization_data,
657-
observed = observedfun)
671+
initialization_data)
658672
end
659673

660674
"""
@@ -724,6 +738,16 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
724738
_jac = :nothing
725739
end
726740

741+
M = calculate_massmatrix(sys)
742+
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
743+
744+
if sparse
745+
uElType = u0 === nothing ? Float64 : eltype(u0)
746+
W_prototype = similar(W_sparsity(sys), uElType)
747+
else
748+
W_prototype = nothing
749+
end
750+
727751
if Wfact
728752
tmp_Wfact, tmp_Wfact_t = generate_factorized_W(
729753
sys, dvs, ps; expression = Val{true},
@@ -734,20 +758,18 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
734758
_Wfact, _Wfact_t = :nothing, :nothing
735759
end
736760

737-
M = calculate_massmatrix(sys)
738-
739-
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
740-
741761
ex = quote
742762
f = $f
743763
g = $g
744764
tgrad = $_tgrad
745765
jac = $_jac
766+
W_prototype = $W_prototype
746767
Wfact = $_Wfact
747768
Wfact_t = $_Wfact_t
748769
M = $_M
749770
SDEFunction{$iip}(f, g,
750771
jac = jac,
772+
jac_prototype = W_prototype,
751773
tgrad = tgrad,
752774
Wfact = Wfact,
753775
Wfact_t = Wfact_t,

Diff for: src/systems/systems.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,13 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
151151
end
152152

153153
noise_eqs = StructuralTransformations.tearing_substitute_expr(ode_sys, noise_eqs)
154-
return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs,
154+
ssys = SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs,
155155
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys);
156156
name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys),
157157
parameter_dependencies = parameter_dependencies(sys), assertions = assertions(sys),
158158
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
159+
@set! ssys.tearing_state = get_tearing_state(ode_sys)
160+
return ssys
159161
end
160162
end
161163

Diff for: test/jacobiansparsity.jl

+38-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
using OrdinaryDiffEq, ModelingToolkit, Test, SparseArrays
1+
using ModelingToolkit, SparseArrays, OrdinaryDiffEq
22

33
N = 3
44
xyd_brusselator = range(0, stop = 1, length = N)
55
brusselator_f(x, y, t) = (((x - 0.3)^2 + (y - 0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.0
6-
limit(a, N) = ModelingToolkit.ifelse(a == N + 1, 1, ModelingToolkit.ifelse(a == 0, N, a))
6+
lim(a, N) = ModelingToolkit.ifelse(a == N + 1, 1, ModelingToolkit.ifelse(a == 0, N, a))
77
function brusselator_2d_loop(du, u, p, t)
88
A, B, alpha, dx = p
99
alpha = alpha / dx^2
1010
@inbounds for I in CartesianIndices((N, N))
1111
i, j = Tuple(I)
1212
x, y = xyd_brusselator[I[1]], xyd_brusselator[I[2]]
13-
ip1, im1, jp1, jm1 = limit(i + 1, N), limit(i - 1, N), limit(j + 1, N),
14-
limit(j - 1, N)
13+
ip1, im1, jp1, jm1 = lim(i + 1, N), lim(i - 1, N), lim(j + 1, N),
14+
lim(j - 1, N)
1515
du[i, j, 1] = alpha * (u[im1, j, 1] + u[ip1, j, 1] + u[i, jp1, 1] + u[i, jm1, 1] -
1616
4u[i, j, 1]) +
1717
B + u[i, j, 1]^2 * u[i, j, 2] - (A + 1) * u[i, j, 1] +
@@ -51,7 +51,7 @@ JP = prob.f.jac_prototype
5151

5252
# test sparse jacobian
5353
prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = true)
54-
@test_nowarn solve(prob, Rosenbrock23())
54+
#@test_nowarn solve(prob, Rosenbrock23())
5555
@test findnz(calculate_jacobian(sys, sparse = true))[1:2] ==
5656
findnz(prob.f.jac_prototype)[1:2]
5757

@@ -74,11 +74,43 @@ f = DiffEqBase.ODEFunction(sys, u0 = nothing, sparse = true, jac = false)
7474
# test when u0 is not Float64
7575
u0 = similar(init_brusselator_2d(xyd_brusselator), Float32)
7676
prob_ode_brusselator_2d = ODEProblem(brusselator_2d_loop,
77-
u0, (0.0, 11.5), p)
77+
u0, (0.0, 11.5), p)
7878
sys = complete(modelingtoolkitize(prob_ode_brusselator_2d))
7979

8080
prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = false)
8181
@test eltype(prob.f.jac_prototype) == Float32
8282

8383
prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = true)
8484
@test eltype(prob.f.jac_prototype) == Float32
85+
86+
@testset "W matrix sparsity" begin
87+
t = ModelingToolkit.t_nounits
88+
D = ModelingToolkit.D_nounits
89+
@parameters g
90+
@variables x(t) y(t) λ(t)
91+
eqs = [D(D(x)) ~ λ * x
92+
D(D(y)) ~ λ * y - g
93+
x^2 + y^2 ~ 1]
94+
@mtkbuild pend = ODESystem(eqs, t)
95+
96+
u0 = [x => 1, y => 0]
97+
prob = ODEProblem(pend, u0, (0, 11.5), [g => 1], guesses ==> 1], sparse = true, jac = true)
98+
jac, jac! = generate_jacobian(pend; expression = Val{false}, sparse = true)
99+
jac_prototype = ModelingToolkit.jacobian_sparsity(pend)
100+
W_prototype = ModelingToolkit.W_sparsity(pend)
101+
@test nnz(W_prototype) == nnz(jac_prototype) + 2
102+
103+
# jac_prototype should be the same as W_prototype
104+
@test findnz(prob.f.jac_prototype)[1:2] == findnz(W_prototype)[1:2]
105+
106+
u = zeros(5)
107+
p = prob.p
108+
t = 0.0
109+
@test_throws AssertionError jac!(similar(jac_prototype, Float64), u, p, t)
110+
111+
W, W! = generate_W(pend; expression = Val{false}, sparse = true)
112+
γ = .1
113+
M = sparse(calculate_massmatrix(pend))
114+
@test_throws AssertionError W!(similar(jac_prototype, Float64), u, p, γ, t)
115+
@test W!(similar(W_prototype, Float64), u, p, γ, t) == 0.1 * M + jac!(similar(W_prototype, Float64), u, p, t)
116+
end

0 commit comments

Comments
 (0)