Skip to content

Commit 093123a

Browse files
Move nonfulldiagonal_sparse test to OrdinaryDiffEqDifferentiation
Add a separate "Sparse" test group for OrdinaryDiffEqDifferentiation that uses its own environment (test/sparse/Project.toml) to handle the ComponentArrays dependency. The autosparse_detection and sparsediff tests remain in the global suite because SparseConnectivityTracer 0.6 conflicts with PreallocationTools in the current registry compat. This is a known ecosystem issue that will resolve when PreallocationTools updates its SparseConnectivityTracer compat bounds. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9539ed3 commit 093123a

File tree

5 files changed

+193
-4
lines changed

5 files changed

+193
-4
lines changed
Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,31 @@
11
using SafeTestsets
2+
using Pkg
23

34
const TEST_GROUP = get(ENV, "ODEDIFFEQ_TEST_GROUP", "ALL")
45

6+
function activate_sparse_env()
7+
Pkg.activate(joinpath(@__DIR__, "sparse"))
8+
# Develop the top-level OrdinaryDiffEq package (which pulls all subpackages)
9+
Pkg.develop(PackageSpec(path = dirname(dirname(dirname(@__DIR__)))))
10+
return Pkg.instantiate()
11+
end
12+
513
# Run QA tests (JET, Aqua)
6-
if TEST_GROUP != "Core" && isempty(VERSION.prerelease)
14+
if TEST_GROUP ("Core", "Sparse") && isempty(VERSION.prerelease)
715
@time @safetestset "JET Tests" include("jet.jl")
816
@time @safetestset "Aqua" include("qa.jl")
917
end
1018

1119
# Run functional tests
12-
if TEST_GROUP != "QA"
20+
if TEST_GROUP ("QA", "Sparse")
1321
@time @safetestset "OOP J_t Tracking" include("oop_jt_tracking_test.jl")
1422
@time @safetestset "Differentiation Trait Tests" include("differentiation_traits_tests.jl")
1523
@time @safetestset "Autodiff Error Tests" include("autodiff_error_tests.jl")
1624
@time @safetestset "No Jac Tests" include("nojac_tests.jl")
1725
end
26+
27+
# Run sparse tests (separate environment due to SparseConnectivityTracer/ComponentArrays dep conflicts)
28+
if TEST_GROUP == "Sparse" || TEST_GROUP == "ALL"
29+
activate_sparse_env()
30+
@time @safetestset "Non-Full Diagonal Sparsity Tests" include("sparse/nonfulldiagonal_sparse_tests.jl")
31+
end
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[deps]
2+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
3+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
5+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
6+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
7+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8+
9+
[sources]
10+
OrdinaryDiffEq = {path = "../../../.."}
11+
12+
[compat]
13+
ComponentArrays = "0.15, 1"
14+
LinearSolve = "3"
15+
OrdinaryDiffEq = "6"
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
using OrdinaryDiffEq, SparseArrays, LinearSolve, LinearAlgebra
2+
using ComponentArrays
3+
4+
function enclosethetimedifferential(parameters::NamedTuple)::Function
5+
@info "Enclosing the time differential"
6+
7+
(; Δr, r_space, countorderapprox) = parameters.compute
8+
N = length(r_space)
9+
10+
function first_deriv(N)
11+
dx = 1 / (N + 1)
12+
du = -1 * ones(N - 1) # off diagonal
13+
du2 = ones(N - 1) # off diagonal
14+
diag = zeros(N)
15+
lower = spzeros(Float64, N)
16+
upper = spzeros(Float64, N)
17+
lower[1] = -1.0
18+
upper[end] = 1.0
19+
M = hcat(lower, sparse(diagm(-1 => du, 0 => diag, 1 => du2)), upper)
20+
return MatrixOperator(1 / dx * M)
21+
end
22+
23+
function second_deriv(N)
24+
dx = 1 / (N + 1)
25+
du = ones(N - 1) # off diagonal
26+
du2 = ones(N - 1) # off diagonal
27+
diag = -2 * ones(N)
28+
lower = spzeros(Float64, N)
29+
upper = spzeros(Float64, N)
30+
lower[1] = 1.0
31+
upper[end] = 1.0
32+
M = hcat(lower, sparse(diagm(-1 => du, 0 => diag, 1 => du2)), upper)
33+
return MatrixOperator(1 / dx^2 * M)
34+
end
35+
36+
function extender(N)
37+
dx = 1 / (N + 1)
38+
diag = ones(N)
39+
lower = spzeros(Float64, N)
40+
upper = spzeros(Float64, N)
41+
lower[1] = 1.0
42+
upper[end] = 1.0
43+
M = vcat(
44+
transpose(lower),
45+
sparse(diagm(diag)),
46+
transpose(upper)
47+
)
48+
return MatrixOperator(1 / dx^2 * M)
49+
end
50+
51+
bc_handler = extender(N)
52+
53+
= first_deriv(N) * bc_handler
54+
Δ = second_deriv(N) * bc_handler
55+
56+
bc_x = zeros(Real, N)
57+
bc_xx = zeros(Real, N)
58+
59+
function timedifferentialclosure!(du, u, p, t)
60+
(;
61+
α, D, v, k_p, V_c, Q_l, Q_r, V_b,
62+
S, Lm, Dm, V_v,
63+
) = p
64+
65+
c = u[1:(end - 3)]
66+
c_v = u[end - 2]
67+
c_c = u[end - 1]
68+
c_b = u[end]
69+
70+
J_B0 = (Dm / Lm) ** c_v - c[1])
71+
J_BL = (Dm / Lm) * (c[end] - α * c_c)
72+
grad_0 = (v ./ D) .* c[1] .- J_B0 ./ D
73+
grad_L = (v ./ D) .* c[end] .- J_BL ./ D
74+
75+
bc_x[1] = grad_0 / 2
76+
bc_x[end] = grad_L / 2
77+
grad_c =* c + bc_x
78+
79+
bc_xx[1] = -grad_0 / Δr
80+
bc_xx[end] = grad_L / Δr
81+
Lap_c = Δ * c + bc_xx
82+
83+
C = sum(Δr .* S * (k_p * (c .- c_b)))
84+
85+
dc_dt = D * Lap_c - v * grad_c .- k_p * (c .- c_b)
86+
du[1:(end - 3)] = dc_dt[1:end]
87+
88+
dcv_dt = -S * J_B0 / V_v - (Q_l / V_v) * c_v
89+
du[end - 2] = dcv_dt
90+
91+
dcc_dt = S * α * J_BL / V_c + (Q_l / V_c) * c_v - (Q_l / V_c) * c_c
92+
du[end - 1] = dcc_dt
93+
94+
dcb_dt = (Q_l / V_b) * c_c + C / V_b
95+
return du[end] = dcb_dt
96+
end
97+
98+
return timedifferentialclosure!
99+
end
100+
101+
prior = ComponentArray(;
102+
α = 0.2,
103+
D = 0.46,
104+
v = 0.0,
105+
k_p = 0.0,
106+
V_c = 18,
107+
Q_l = 20,
108+
Q_r = 3.6,
109+
V_b = 1490,
110+
S = 52,
111+
Lm = 0.05,
112+
Dm = 0.046,
113+
V_v = 18.0
114+
)
115+
116+
r_space = collect(range(0.0, 2.0, length = 15))
117+
computeparams = (
118+
Δr = r_space[2],
119+
r_space = r_space,
120+
countorderapprox = 2,
121+
)
122+
parameters = (
123+
prior = prior,
124+
compute = computeparams,
125+
)
126+
127+
dudt = enclosethetimedifferential(parameters)
128+
IC = ones(length(r_space) + 3)
129+
odeprob = ODEProblem(
130+
dudt,
131+
IC,
132+
(0, 2.1),
133+
parameters.prior
134+
);
135+
du0 = copy(odeprob.u0);
136+
# Hardcoded sparsity pattern for 15 spatial points + 3 state variables (18x18 matrix)
137+
I = [1, 2, 16, 18, 1, 2, 3, 18, 2, 3, 4, 18, 3, 4, 5, 18, 4, 5, 6, 18, 5, 6, 7, 18, 6, 7, 8, 18, 7, 8, 9, 18, 8, 9, 10, 18, 9, 10, 11, 18, 10, 11, 12, 18, 11, 12, 13, 18, 12, 13, 14, 18, 13, 14, 15, 18, 14, 15, 17, 18, 1, 16, 17, 15, 17, 18, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18]
138+
J = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18]
139+
jac_sparsity = sparse(I, J, ones(Bool, length(I)), 18, 18);
140+
f = ODEFunction(
141+
dudt;
142+
jac_prototype = float.(jac_sparsity)
143+
);
144+
sparseodeprob = ODEProblem(
145+
f,
146+
odeprob.u0,
147+
(0, 2.1),
148+
parameters.prior
149+
);
150+
151+
solve(odeprob, TRBDF2());
152+
solve(sparseodeprob, TRBDF2());
153+
solve(sparseodeprob, Rosenbrock23(linsolve = KLUFactorization()));
154+
solve(sparseodeprob, KenCarp47(linsolve = KrylovJL_GMRES()));
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[Core]
2+
versions = ["lts", "1.11", "1", "pre"]
3+
4+
[QA]
5+
versions = ["1"]
6+
7+
[Sparse]
8+
versions = ["1"]

test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ end
109109
@time @safetestset "No Index Tests" include("interface/noindex_tests.jl")
110110
@time @safetestset "Events + DAE addsteps Tests" include("interface/event_dae_addsteps.jl")
111111
@time @safetestset "Units Tests" include("interface/units_tests.jl")
112-
@time @safetestset "Non-Full Diagonal Sparsity Tests" include("interface/nonfulldiagonal_sparse.jl")
113112
@time @safetestset "DEVerbosity Tests" include("interface/verbosity.jl")
114113
end
115114

@@ -187,7 +186,6 @@ end
187186
activate_downstream_env()
188187
@time @safetestset "DelayDiffEq Tests" include("downstream/delaydiffeq.jl")
189188
@time @safetestset "Measurements Tests" include("downstream/measurements.jl")
190-
@time @safetestset "Sparse Diff Tests" include("downstream/sparsediff_tests.jl")
191189
@time @safetestset "Time derivative Tests" include("downstream/time_derivative_test.jl")
192190
end
193191

0 commit comments

Comments
 (0)