Skip to content

Commit fd0912f

Browse files
Merge pull request #3673 from AayushSabharwal/as/fix-promotion
[backport-v9] fix: fix type promotion in `late_binding_update_u0_p` with non-dual types
2 parents f35c378 + 3381916 commit fd0912f

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -598,24 +598,42 @@ function SciMLBase.remake_initialization_data(
598598
return SciMLBase.remake_initialization_data(sys, odefn, newu0, t0, newp, newu0, newp)
599599
end
600600

601-
function promote_u0_p(u0, p::MTKParameters, t0)
602-
u0 = DiffEqBase.promote_u0(u0, p.tunable, t0)
603-
u0 = DiffEqBase.promote_u0(u0, p.initials, t0)
601+
promote_type_with_nothing(::Type{T}, ::Nothing) where {T} = T
602+
promote_type_with_nothing(::Type{T}, ::SizedVector{0}) where {T} = T
603+
function promote_type_with_nothing(::Type{T}, ::AbstractArray{T2}) where {T, T2}
604+
promote_type(T, T2)
605+
end
604606

605-
if !isempty(p.tunable)
606-
tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
607-
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
608-
end
609-
if !isempty(p.initials)
610-
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
611-
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
607+
promote_with_nothing(::Type, ::Nothing) = nothing
608+
promote_with_nothing(::Type, x::SizedVector{0}) = x
609+
promote_with_nothing(::Type{T}, x::AbstractArray{T}) where {T} = x
610+
function promote_with_nothing(::Type{T}, x::AbstractArray{T2}) where {T, T2}
611+
if ArrayInterface.ismutable(x)
612+
y = similar(x, T)
613+
copyto!(y, x)
614+
return y
615+
else
616+
yT = similar_type(x, T)
617+
return yT(x)
612618
end
613-
614-
return u0, p
619+
end
620+
function promote_with_nothing(::Type{T}, p::MTKParameters) where {T}
621+
tunables = promote_with_nothing(T, p.tunable)
622+
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
623+
initials = promote_with_nothing(T, p.initials)
624+
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
625+
return p
615626
end
616627

617628
function promote_u0_p(u0, p, t0)
618-
return DiffEqBase.promote_u0(u0, p, t0), DiffEqBase.promote_u0(p, u0, t0)
629+
T = Union{}
630+
T = promote_type_with_nothing(T, u0)
631+
T = promote_type_with_nothing(T, p.tunable)
632+
T = promote_type_with_nothing(T, p.initials)
633+
634+
u0 = promote_with_nothing(T, u0)
635+
p = promote_with_nothing(T, p)
636+
return u0, p
619637
end
620638

621639
function SciMLBase.late_binding_update_u0_p(

test/initial_values.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,13 @@ end
362362
@test state_values(initdata.initializeprob) isa SVector
363363
@test parameter_values(initdata.initializeprob) isa SVector
364364
end
365+
366+
@testset "Type promotion of `p` works with non-dual types" begin
367+
@variables x(t) y(t)
368+
@mtkbuild sys = ODESystem([D(x) ~ x + y, x^3 + y^3 ~ 5], t; guesses = [y => 1.0])
369+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
370+
prob2 = remake(prob; u0 = BigFloat.(prob.u0))
371+
@test prob2.p.initials isa Vector{BigFloat}
372+
sol = solve(prob2)
373+
@test SciMLBase.successful_retcode(sol)
374+
end

0 commit comments

Comments
 (0)