Skip to content

Make get_differential_vars type stable #2698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions docs/src/massmatrixdae/BDF.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ CollapsedDocStrings = true
Multistep BDF methods, good for large stiff systems.

```julia
using LinearAlgebra: Diagonal
function rober(du, u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
Expand All @@ -15,9 +16,7 @@ function rober(du, u, p, t)
du[3] = y₁ + y₂ + y₃ - 1
nothing
end
M = [1.0 0 0
0 1.0 0
0 0 0]
M = Diagonal([1.0, 1.0, 0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AayushSabharwal does MTK use Diagonal?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it generates a Matrix{Float64}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's update that

f = ODEFunction(rober, mass_matrix = M)
prob_mm = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))
sol = solve(prob_mm, FBDF(), reltol = 1e-8, abstol = 1e-8)
Expand Down
5 changes: 2 additions & 3 deletions docs/src/massmatrixdae/Rosenbrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ For larger systems look at multistep methods.
## Example usage

```julia
using LinearAlgebra: Diagonal
function rober(du, u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
Expand All @@ -28,9 +29,7 @@ function rober(du, u, p, t)
du[3] = y₁ + y₂ + y₃ - 1
nothing
end
M = [1.0 0 0
0 1.0 0
0 0 0]
M = Diagonal([1.0, 1.0, 0])
f = ODEFunction(rober, mass_matrix = M)
prob_mm = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))
sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
Expand Down
25 changes: 18 additions & 7 deletions lib/OrdinaryDiffEqBDF/test/dae_ad_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using OrdinaryDiffEqBDF, LinearAlgebra, ForwardDiff, Test
using OrdinaryDiffEqNonlinearSolve: BrownFullBasicInit, ShampineCollocationInit
using ADTypes: AutoForwardDiff, AutoFiniteDiff

afd_cs3 = AutoForwardDiff(chunksize=3)

function f(out, du, u, p, t)
out[1] = -p[1] * u[1] + p[3] * u[2] * u[3] - du[1]
Expand All @@ -16,22 +19,30 @@ u₀ = [1.0, 0, 0]
du₀ = [-0.04, 0.04, 0.0]
tspan = (0.0, 100000.0)
differential_vars = [true, true, false]
M = Diagonal([1.0, 1.0, 0.0])
prob = DAEProblem(f, du₀, u₀, tspan, p, differential_vars = differential_vars)
prob_oop = DAEProblem{false}(f, du₀, u₀, tspan, p, differential_vars = differential_vars)
sol1 = solve(prob, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
sol2 = solve(prob_oop, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
f_mm = ODEFunction{true, SciMLBase.AutoSpecialize}(f, mass_matrix = M)
prob_mm = ODEProblem(f_mm, u₀, tspan, p)
@test_broken sol1 = @inferred solve(prob, DFBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
@test_broken sol2 = @inferred solve(prob_oop, DFBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
@test_broken sol3 = @inferred solve(prob_mm, FBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)

# These tests flex differentiation of the solver and through the initialization
# To only test the solver part and isolate potential issues, set the initialization to consistent
@testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [
prob, prob_oop],
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [true, false]
@testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [
prob, prob_oop, prob_mm],
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [afd_cs3, AutoFiniteDiff()]

alg = DFBDF(; autodiff)
alg = (_prob isa DAEProblem) ? DFBDF(; autodiff) : FBDF(; autodiff)
function f(p)
sol = solve(remake(_prob, p = p), alg, abstol = 1e-14,
reltol = 1e-14, initializealg = initalg)
sum(sol)
end
@test ForwardDiff.gradient(f, [0.04, 3e7, 1e4])≈[0, 0, 0] atol=1e-8
if _prob isa DAEProblem
@test ForwardDiff.gradient(f, [0.04, 3e7, 1e4])≈[0, 0, 0] atol=1e-8
else
@test_broken ForwardDiff.gradient(f, [0.04, 3e7, 1e4])≈[0, 0, 0] atol=1e-8
end
end
9 changes: 6 additions & 3 deletions lib/OrdinaryDiffEqCore/src/misc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,21 @@ are differential variables. Returns `DifferentialVarsUndefined` if it cannot
be determined (i.e. the mass matrix is not diagonal).
"""
function get_differential_vars(f, u)
differential_vars = nothing
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
mm = mm isa MatrixOperator ? mm.A : mm

if mm isa UniformScaling || all(!iszero, mm)
if mm isa UniformScaling
return nothing
elseif all(!iszero, mm)
return trues(size(mm, 1))
elseif !(mm isa SciMLOperators.AbstractSciMLOperator) && isdiag(mm)
differential_vars = reshape(diag(mm) .!= 0, size(u))
return reshape(diag(mm) .!= 0, size(u))
else
return DifferentialVarsUndefined()
end
else
return nothing
end
end

Expand Down
19 changes: 10 additions & 9 deletions lib/OrdinaryDiffEqRosenbrock/test/dae_rosenbrock_ad_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using OrdinaryDiffEqRosenbrock, LinearAlgebra, ForwardDiff, Test
using OrdinaryDiffEqNonlinearSolve: BrownFullBasicInit, ShampineCollocationInit
using ADTypes: AutoForwardDiff, AutoFiniteDiff

afd_cs3 = AutoForwardDiff(chunksize=3)
function rober(du, u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
Expand All @@ -16,25 +18,24 @@ function rober(u, p, t)
k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2,
y₁ + y₂ + y₃ - 1]
end
M = [1.0 0 0
0 1.0 0
0 0 0]
roberf = ODEFunction(rober, mass_matrix = M)
roberf_oop = ODEFunction{false}(rober, mass_matrix = M)
M = Diagonal([1.0, 1.0, 0.0])
roberf = ODEFunction{true, SciMLBase.AutoSpecialize}(rober, mass_matrix = M)
roberf_oop = ODEFunction{false, SciMLBase.AutoSpecialize}(rober, mass_matrix = M)
prob_mm = ODEProblem(roberf, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
prob_mm_oop = ODEProblem(roberf_oop, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
sol = solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
sol = solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
# Both should be inferrable so long as AutoSpecialize is used...
@test_broken sol = @inferred solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
sol = @inferred solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8)

# These tests flex differentiation of the solver and through the initialization
# To only test the solver part and isolate potential issues, set the initialization to consistent
@testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [
prob_mm, prob_mm_oop],
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [true, false]
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [AutoForwardDiff(chunksize=3), AutoFiniteDiff()]

alg = Rodas5P(; autodiff)
function f(p)
sol = solve(remake(_prob, p = p), alg, abstol = 1e-14,
sol = @inferred solve(remake(_prob, p = p), alg, abstol = 1e-14,
reltol = 1e-14, initializealg = initalg)
sum(sol)
end
Expand Down
22 changes: 12 additions & 10 deletions test/interface/mass_matrix_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using OrdinaryDiffEq, Test, LinearAlgebra, Statistics
using OrdinaryDiffEqCore
using OrdinaryDiffEqNonlinearSolve: NLFunctional, NLAnderson, NLNewton
using LinearAlgebra: Diagonal
using ADTypes: AutoForwardDiff

# create mass matrix problems
function make_mm_probs(mm_A, ::Val{iip}) where {iip}
Expand Down Expand Up @@ -194,11 +196,10 @@ end
u0 = [0.0, 1.0]
tspan = (0.0, 1.0)

M = fill(0.0, 2, 2)
M[1, 1] = 1.0
M = Diagonal([1.0, 0.0])

m_ode_prob = ODEProblem(ODEFunction(f!; mass_matrix = M), u0, tspan)
@test_nowarn sol = solve(m_ode_prob, Rosenbrock23())
@test_nowarn sol = @inferred solve(m_ode_prob, Rosenbrock23(autodiff=AutoForwardDiff(chunksize=2)))

M = [0.637947 0.637947
0.637947 0.637947]
Expand Down Expand Up @@ -323,14 +324,15 @@ function dynamics(u, p, t)
end

x0 = zeros(n, n)
M = zeros(n * n) |> Diagonal |> Matrix
M = zeros(n * n) |> Diagonal
M[1, 1] = true # zero mass matrix breaks rosenbrock
f = ODEFunction(dynamics!, mass_matrix = M)
f = ODEFunction{true, SciMLBase.AutoSpecialize}(dynamics!, mass_matrix = M)
tspan = (0, 10.0)
adalg = AutoForwardDiff(chunksize=n)
prob = ODEProblem(f, x0, tspan)
foop = ODEFunction(dynamics, mass_matrix = M)
foop = ODEFunction{false, SciMLBase.AutoSpecialize}(dynamics, mass_matrix = M)
proboop = ODEProblem(f, x0, tspan)
sol = solve(prob, Rosenbrock23())
sol = solve(prob, Rodas4(), initializealg = ShampineCollocationInit())
sol = solve(proboop, Rodas5())
sol = solve(proboop, Rodas4(), initializealg = ShampineCollocationInit())
@test_broken sol = @inferred solve(prob, Rosenbrock23(autodiff=adalg))
@test_broken sol = @inferred solve(prob, Rodas4(autodiff=adalg), initializealg = ShampineCollocationInit())
@test_broken sol = @inferred solve(proboop, Rodas5())
@test_broken sol = @inferred solve(proboop, Rodas4(), initializealg = ShampineCollocationInit())
Loading