Skip to content

Commit 14a67d3

Browse files
Merge pull request #561 from ChrisRackauckas-Claude/fix-basis-parameter-unwrap-559
Fix symbolic to numeric conversion using Symbolics.unwrap (#559)
2 parents 6caca6b + e7acb81 commit 14a67d3

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

src/basis/type.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -529,14 +529,17 @@ If no default value is stored, returns `zero(T)` where `T` is the `symtype` of t
529529
## Note
530530
531531
This extends `getmetadata` in a way that all parameters have a numeric value.
532+
Values are unwrapped from symbolic wrappers to ensure compatibility with ODEProblem.
532533
"""
533534
function get_parameter_values(x::Basis)
534535
map(parameters(x)) do p
535-
if hasmetadata(p, Symbolics.VariableDefaultValue)
536-
return Symbolics.getdefaultval(p)
536+
val = if hasmetadata(p, Symbolics.VariableDefaultValue)
537+
Symbolics.getdefaultval(p)
537538
else
538-
return zero(Symbolics.symtype(p))
539+
zero(Symbolics.symtype(p))
539540
end
541+
# Unwrap symbolic values to numeric values for use in ODEProblem
542+
return Symbolics.unwrap(val)
540543
end
541544
end
542545

@@ -549,14 +552,17 @@ If no default value is stored, returns `zero(T)` where `T` is the `symtype` of t
549552
## Note
550553
551554
This extends `getmetadata` in a way that all parameters have a numeric value.
555+
Values are unwrapped from symbolic wrappers to ensure compatibility with ODEProblem.
552556
"""
553557
function get_parameter_map(x::Basis)
554558
map(parameters(x)) do p
555-
if hasmetadata(p, Symbolics.VariableDefaultValue)
556-
return p => Symbolics.getdefaultval(p)
559+
val = if hasmetadata(p, Symbolics.VariableDefaultValue)
560+
Symbolics.getdefaultval(p)
557561
else
558-
return p => zero(Symbolics.symtype(p))
562+
zero(Symbolics.symtype(p))
559563
end
564+
# Unwrap symbolic values to numeric values for use in ODEProblem
565+
return p => Symbolics.unwrap(val)
560566
end
561567
end
562568

test/basis/basis.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,57 @@ end
170170
@test get_parameter_values(b) == [1.0; 2.0]
171171
@test last.(get_parameter_map(b)) == [1.0; 2.0]
172172
end
173+
174+
@testset "ODEProblem from Basis (Issue #559)" begin
175+
# Regression test for issue #559: solve throws MethodError when creating
176+
# ODEProblem from Basis due to symbolic to numeric conversion issues
177+
using OrdinaryDiffEqTsit5
178+
179+
# Create a simple basis with parameters that have no default values
180+
@variables u[1:2]
181+
@parameters w[1:2]
182+
u = collect(u)
183+
w = collect(w)
184+
185+
# Create a basis with parameters without default values
186+
# This tests the zero(Symbolics.symtype(p)) code path
187+
h = [u[1]^2 + w[1] * u[2]; sin(w[2] * u[1])]
188+
basis = Basis(h, u, parameters = w)
189+
190+
# Test that get_parameter_values returns unwrapped numeric values, not symbolic
191+
params = get_parameter_values(basis)
192+
@test params isa Vector
193+
@test all(p -> !(p isa Num), params) # Should not be Num/symbolic
194+
@test all(iszero, params) # Parameters without defaults should be zero
195+
196+
# Test that get_parameter_map also returns unwrapped numeric values
197+
param_map = get_parameter_map(basis)
198+
@test all(pair -> !(last(pair) isa Num), param_map)
199+
200+
# Test that we can create an ODEProblem from the basis
201+
# This is the key test from issue #559 - should not throw MethodError
202+
# about "Cannot convert BasicSymbolic{Real} to Float64"
203+
u0 = [1.0, 2.0]
204+
tspan = (0.0, 0.1) # Very short timespan
205+
p_values = [0.01, 0.01] # Very small parameter values
206+
recovered_model = ODEProblem(basis, u0, tspan, p_values)
207+
@test recovered_model isa ODEProblem
208+
209+
# Test that we can initialize the integrator without the symbolic conversion error
210+
sol = solve(recovered_model, Tsit5(), save_everystep = false)
211+
212+
# Also test with parameters that have default values
213+
@parameters w2[1:2] = [1.5, 2.5]
214+
w2 = collect(w2)
215+
h2 = [u[1]^2 + w2[1] * u[2]; sin(w2[2] * u[1])]
216+
basis2 = Basis(h2, u, parameters = w2)
217+
218+
# Test that get_parameter_values returns the default values unwrapped
219+
params2 = get_parameter_values(basis2)
220+
@test all(p -> !(p isa Num), params2)
221+
@test params2 [1.5, 2.5]
222+
223+
# Test creating ODEProblem with default parameter values
224+
recovered_model2 = ODEProblem(basis2, u0, tspan, params2)
225+
@test recovered_model2 isa ODEProblem
226+
end

0 commit comments

Comments
 (0)