Skip to content

Commit 8a19a49

Browse files
fix: copy initials to u0 if u0 not provided to remake
1 parent 0a9134e commit 8a19a49

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

src/systems/nonlinear/initializesystem.jl

+13
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ function SciMLBase.remake_initialization_data(
543543
if u0 === missing && p === missing
544544
return odefn.initialization_data
545545
end
546+
546547
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
547548
oldinitdata = odefn.initialization_data
548549
oldinitdata === nothing && return nothing
@@ -658,6 +659,18 @@ function SciMLBase.late_binding_update_u0_p(
658659
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
659660
supports_initialization(sys) || return newu0, newp
660661
u0 === missing && return newu0, (p === missing ? copy(newp) : newp)
662+
# If the user passes `p` to `remake` but not `u0` and `u0` isn't empty,
663+
# and if the system supports initialization (so it has initial parameters),
664+
# and if the initialization solves for `u0`,
665+
# THEN copy the values of `Initial`s to `newu0`.
666+
if u0 === missing && newu0 !== nothing && p !== missing && supports_initialization(sys) && prob.f.initialization_data !== nothing && prob.f.initialization_data.initializeprobmap !== nothing
667+
if ArrayInterface.ismutable(newu0)
668+
copyto!(newu0, getu(sys, Initial.(unknowns(sys)))(newp))
669+
else
670+
T = StaticArrays.similar_type(newu0)
671+
newu0 = T(getu(sys, Initial.(unknowns(sys)))(newp))
672+
end
673+
end
661674
# non-symbolic u0 updates initials...
662675
if !(eltype(u0) <: Pair)
663676
# if `p` is not provided or is symbolic

test/initializationsystem.jl

+21
Original file line numberDiff line numberDiff line change
@@ -1496,3 +1496,24 @@ end
14961496
@test integ3.u [2.0, 3.0]
14971497
@test integ3.ps[c1] 2.0
14981498
end
1499+
1500+
@testset "Issue#3570: `Initial`s are copied to `u0` if `u0` not provided to `remake`" begin
1501+
@parameters g
1502+
@variables x(t) [state_priority = 10] y(t) λ(t)
1503+
eqs = [D(D(x)) ~ λ * x
1504+
D(D(y)) ~ λ * y - g
1505+
x^2 + y^2 ~ 1]
1506+
@mtkbuild pend = ODESystem(eqs, t)
1507+
1508+
prob = ODEProblem(
1509+
pend, [x => (2 / 2)], (0.0, 1.5), [g => 1], guesses ==> 1, y => 2 / 2])
1510+
sol = solve(prob)
1511+
1512+
setter = setsym_oop(prob, [Initial(x)])
1513+
(u0, p) = setter(prob, [0.8])
1514+
1515+
new_prob = remake(prob; p, initializealg = BrownFullBasicInit())
1516+
@test new_prob[x] 0.8
1517+
new_sol = solve(new_prob)
1518+
@test new_sol[x, 1] 0.8
1519+
end

0 commit comments

Comments
 (0)