Skip to content

Commit 2c75502

Browse files
Merge pull request #3702 from AayushSabharwal/as/nonnumeric-init
fix: fix `get_mtkparameters_reconstructor` handling of nonnumerics
2 parents 263c870 + 81596b4 commit 2c75502

File tree

3 files changed

+54
-10
lines changed

3 files changed

+54
-10
lines changed

src/systems/problem_utils.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,8 @@ end
659659
$(TYPEDEF)
660660
661661
A callable struct which applies `p_constructor` to possibly nested arrays. It also
662-
ensures that views (including nested ones) are concretized.
662+
ensures that views (including nested ones) are concretized. This is implemented manually
663+
of using `narrow_buffer_type` to preserve type-stability.
663664
"""
664665
struct PConstructorApplicator{F}
665666
p_constructor::F
@@ -669,10 +670,18 @@ function (pca::PConstructorApplicator)(x::AbstractArray)
669670
pca.p_constructor(x)
670671
end
671672

673+
function (pca::PConstructorApplicator)(x::AbstractArray{Bool})
674+
pca.p_constructor(BitArray(x))
675+
end
676+
672677
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray)
673678
collect(x)
674679
end
675680

681+
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{Bool})
682+
BitArray(x)
683+
end
684+
676685
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray})
677686
collect(pca.(x))
678687
end
@@ -695,6 +704,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
695704
"""
696705
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
697706
initials = false, unwrap_initials = false, p_constructor = identity)
707+
_p_constructor = p_constructor
698708
p_constructor = PConstructorApplicator(p_constructor)
699709
# if we call `getu` on this (and it were able to handle empty tuples) we get the
700710
# fields of `MTKParameters` except caches.
@@ -748,14 +758,24 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
748758
Base.Fix1(broadcast, p_constructor)
749759
getu(srcsys, syms[3])
750760
end
751-
rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf
752-
if buf == ()
753-
return Returns(())
754-
else
755-
return Base.Fix1(broadcast, p_constructor) getu(srcsys, buf)
756-
end
761+
const_getter = if syms[4] == ()
762+
Returns(())
763+
else
764+
Base.Fix1(broadcast, p_constructor) getu(srcsys, syms[4])
757765
end
758-
getters = (tunable_getter, initials_getter, discs_getter, rest_getters...)
766+
nonnumeric_getter = if syms[5] == ()
767+
Returns(())
768+
else
769+
ic = get_index_cache(dstsys)
770+
buftypes = Tuple(map(ic.nonnumeric_buffer_sizes) do bufsize
771+
Vector{bufsize.type}
772+
end)
773+
# nonnumerics retain the assigned buffer type without narrowing
774+
Base.Fix1(broadcast, _p_constructor)
775+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) getu(srcsys, syms[5])
776+
end
777+
getters = (
778+
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter)
759779
getter = let getters = getters
760780
function _getter(valp, initprob)
761781
oldcache = parameter_values(initprob).caches
@@ -768,6 +788,10 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
768788
return getter
769789
end
770790

791+
function call(f, args...)
792+
f(args...)
793+
end
794+
771795
"""
772796
$(TYPEDSIGNATURES)
773797

test/code_generation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ end
8989
end
9090
@mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
9191
@test length(equations(sys)) == 1
92-
@test length(observed(sys)) == 3
92+
@test length(ModelingToolkit.observed(sys)) == 3
9393
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2])
9494
val[] = 0
9595
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
@@ -101,7 +101,7 @@ end
101101
@mtkbuild sys = ODESystem(
102102
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
103103
@test length(equations(sys)) == 5
104-
@test length(observed(sys)) == 0
104+
@test length(ModelingToolkit.observed(sys)) == 0
105105
prob = ODEProblem(
106106
sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2])
107107
val[] = 0

test/initializationsystem.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,3 +1671,23 @@ end
16711671
sol = solve(prob, Tsit5())
16721672
@test SciMLBase.successful_retcode(sol)
16731673
end
1674+
1675+
@testset "Nonnumerics aren't narrowed" begin
1676+
@mtkmodel Foo begin
1677+
@variables begin
1678+
x(t) = 1.0
1679+
end
1680+
@parameters begin
1681+
p::AbstractString
1682+
r = 1.0
1683+
end
1684+
@equations begin
1685+
D(x) ~ r * x
1686+
end
1687+
end
1688+
@mtkbuild sys = Foo(p = "a")
1689+
prob = ODEProblem(sys, [], (0.0, 1.0))
1690+
@test prob.p.nonnumeric[1] isa Vector{AbstractString}
1691+
integ = init(prob)
1692+
@test integ.p.nonnumeric[1] isa Vector{AbstractString}
1693+
end

0 commit comments

Comments
 (0)