Skip to content

Commit a688533

Browse files
Merge pull request SciML#2941 from ChrisRackauckas-Claude/claude/rebase-hw-dae-gpu-pr2911
fix GPU compat of BrownFullBasicInit (rebased from SciML#2911)
2 parents 909e799 + b600432 commit a688533

4 files changed

Lines changed: 76 additions & 3 deletions

File tree

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,9 @@ function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::OD
396396
M = integrator.f.mass_matrix
397397
M isa UniformScaling && return
398398
update_coefficients!(M, u, p, t)
399-
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
400-
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
399+
algebraic_vars = vec(all(iszero, M, dims = 1))
400+
algebraic_eqs = vec(all(iszero, M, dims = 2))
401+
401402
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
402403
tmp = get_tmp_cache(integrator)[1]
403404

@@ -456,7 +457,7 @@ function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::OD
456457
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u, nlprob, isAD)
457458

458459
nlsol = solve(nlprob, nlsolve; abstol = alg.abstol, reltol = integrator.opts.reltol)
459-
alg_u .= nlsol
460+
alg_u .= nlsol.u
460461

461462
recursivecopy!(integrator.uprev, integrator.u)
462463
if alg_extrapolates(integrator.alg)

test/gpu/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
77
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
88
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
99
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
10+
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
1011
OrdinaryDiffEqRKIP = "a4daff8c-1d43-4ff3-8eff-f78720aeecdc"
12+
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
1113
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1214
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1315
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
1416

1517
[sources]
18+
OrdinaryDiffEqNonlinearSolve = {path = "../../lib/OrdinaryDiffEqNonlinearSolve"}
1619
OrdinaryDiffEqRKIP = {path = "../../lib/OrdinaryDiffEqRKIP"}
20+
OrdinaryDiffEqRosenbrock = {path = "../../lib/OrdinaryDiffEqRosenbrock"}
1721

1822
[compat]
1923
Adapt = "4"
@@ -24,7 +28,9 @@ FastBroadcast = "0.3"
2428
FFTW = "1.8"
2529
FillArrays = "1"
2630
OrdinaryDiffEq = "6"
31+
OrdinaryDiffEqNonlinearSolve = "1"
2732
OrdinaryDiffEqRKIP = "1"
33+
OrdinaryDiffEqRosenbrock = "1"
2834
RecursiveArrayTools = "3"
2935
SciMLBase = "2.99"
3036
SciMLOperators = "1.3"

test/gpu/simple_dae.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
using OrdinaryDiffEqRosenbrock
2+
using OrdinaryDiffEqNonlinearSolve
3+
using CUDA
4+
using LinearAlgebra
5+
using Adapt
6+
using SparseArrays
7+
using Test
8+
9+
#=
10+
du[1] = -u[1]
11+
du[2] = -0.5*u[2]
12+
0 = u[1] + u[2] - u[3]
13+
0 = -u[1] + u[2] - u[4]
14+
=#
15+
16+
function dae!(du, u, p, t)
17+
mul!(du, p, u)
18+
end
19+
20+
p = [-1 0 0 0
21+
1 -0.5 0 0
22+
1 1 -1 0
23+
-1 1 0 -1]
24+
25+
# mass_matrix = [1 0 0 0
26+
# 0 1 0 0
27+
# 0 0 0 0
28+
# 0 0 0 0]
29+
mass_matrix = Diagonal([1, 1, 0, 0])
30+
jac_prototype = sparse(map(x -> iszero(x) ? 0.0 : 1.0, p))
31+
32+
u0 = [1.0, 1.0, 0.5, 0.5] # force init
33+
odef = ODEFunction(dae!, mass_matrix = mass_matrix, jac_prototype = jac_prototype)
34+
35+
tspan = (0.0, 5.0)
36+
prob = ODEProblem(odef, u0, tspan, p)
37+
sol = solve(prob, Rodas5P())
38+
39+
# gpu version
40+
mass_matrix_d = adapt(CuArray, mass_matrix)
41+
42+
# TODO: jac_prototype fails
43+
# jac_prototype_d = adapt(CuArray, jac_prototype)
44+
# jac_prototype_d = CUDA.CUSPARSE.CuSparseMatrixCSR(jac_prototype)
45+
jac_prototype_d = nothing
46+
47+
u0_d = adapt(CuArray, u0)
48+
p_d = adapt(CuArray, p)
49+
odef_d = ODEFunction(dae!, mass_matrix = mass_matrix_d, jac_prototype = jac_prototype_d)
50+
prob_d = ODEProblem(odef_d, u0_d, tspan, p_d)
51+
sol_d = solve(prob_d, Rodas5P())
52+
53+
@testset "Test constraints in GPU sol" begin
54+
for t in sol_d.t
55+
u = Vector(sol_d(t))
56+
@test isapprox(u[1] + u[2], u[3]; atol = 1e-6)
57+
@test isapprox(-u[1] + u[2], u[4]; atol = 1e-6)
58+
end
59+
end
60+
61+
@testset "Compare GPU to CPU solution" begin
62+
for t in tspan[begin]:0.1:tspan[end]
63+
@test Vector(sol_d(t)) sol(t)
64+
end
65+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ end
206206
@time @safetestset "Reaction-Diffusion Stiff Solver GPU" include("gpu/reaction_diffusion_stiff.jl")
207207
@time @safetestset "Scalar indexing bug bypass" include("gpu/hermite_test.jl")
208208
@time @safetestset "RKIP Semilinear PDE GPU" include("gpu/rkip_semilinear_pde.jl")
209+
@time @safetestset "simple dae on GPU" include("gpu/simple_dae.jl")
209210
end
210211

211212
if !is_APPVEYOR && GROUP == "QA"

0 commit comments

Comments
 (0)