Skip to content

Support in-place interpolation of symbolic idxs #988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
@@ -213,6 +213,7 @@ function is_discrete_expression(indp, expr)
length(ts_idxs) > 1 || length(ts_idxs) == 1 && only(ts_idxs) != ContinuousTimeseries()
end

# These are the two main documented user-facing interpolation API functions (out-of-place and in-place versions)
function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
continuity = :left) where {deriv}
if t isa IndexedClock
@@ -225,9 +226,12 @@ function (sol::AbstractODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing
if t isa IndexedClock
t = canonicalize_indexed_clock(t, sol)
end
sol.interp(v, t, idxs, deriv, sol.prob.p, continuity)
sol(v, t, deriv, idxs, continuity)
end

# Below are many internal dispatches for different combinations of arguments to the main API
# TODO: could use a clever rewrite, since a lot of reused code has accumulated
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that the symbolic dispatch is kept open, since there's more than one issymbolic type. We should union that in the future, but it needs JuliaSymbolics/SymbolicUtils.jl#737


function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::Nothing,
continuity) where {deriv}
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
@@ -365,6 +369,43 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
return DiffEqArray(u, t, p, sol; discretes)
end

function (sol::AbstractODESolution)(
v, t::Union{Number, AbstractVector{<:Number}}, ::Type{deriv},
idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv}
return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity)
end
function (sol::AbstractODESolution)(
v, t::Union{Number, AbstractVector{<:Number}}, ::Type{deriv}, idxs,
continuity) where {deriv}
if idxs isa AbstractArray && any(idx -> idx == NotSymbolic(), symbolic_type.(idxs)) ||
!(idxs isa AbstractArray) && symbolic_type(idxs) == NotSymbolic()
error("Incorrect specification of `idxs`")
end
error_if_observed_derivative(sol, idxs, deriv)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
getter = getsym(sol, idxs) # TODO: breaks type inference and allocates
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
u = zeros(eltype(sol), size(sol)[1])
if t isa AbstractVector
for ti in eachindex(t)
sol.interp(u, t[ti], nothing, deriv, p, continuity)
state = ProblemState(; u = u, p = p, t = t[ti])
if eltype(v) <: Number
v[ti] = getter(state)
else
v[ti] .= getter(state)
end
end
else # t isa Number
sol.interp(u, t, nothing, deriv, p, continuity)
state = ProblemState(; u = u, p = p, t = t)
v .= getter(state)
end
return v
end
error("In-place interpolation with discretes is not implemented.")
end

struct DDESolutionHistoryWrapper{T}
sol::T
end
31 changes: 30 additions & 1 deletion test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, StochasticDiffEq, Test
using StochasticDiffEq
using SymbolicIndexingInterface
using ModelingToolkit: t_nounits as t, D_nounits as D
using ModelingToolkit: observed, t_nounits as t, D_nounits as D
using Plots: Plots, plot

### Tests on non-layered model (everything should work). ###
@@ -148,6 +148,35 @@ sol9 = sol(0.0:1.0:10.0, idxs = 2)
sol10 = sol(0.1, idxs = 2)
@test sol10 isa Real

# in-place interpolation with single (unknown) symbolic index
ts = 0.0:0.1:10.0
out = zeros(eltype(sol), size(ts))
idxs = unknowns(sys)[1]
@test sol(out, ts; idxs) == sol(ts; idxs)
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
@test_nowarn @inferred sol(out, ts; idxs)

# in-place interpolation with single (observed) symbolic index
idxs = observed(sys)[1].lhs
@test sol(out, ts; idxs) == sol(ts; idxs)
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
@test_nowarn @inferred sol(out, ts; idxs)

# in-place interpolation with multiple (unknown+observed) symbolic indices
idxs = [unknowns(sys)[1], observed(sys)[1].lhs]
out = [zeros(eltype(sol), size(idxs)) for _ in eachindex(ts)]
@test sol(out, ts; idxs) == sol(ts; idxs).u
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
@test_nowarn @inferred sol(out, ts; idxs)

# same as above, but with one time value only
@test sol(out[1], ts[1]; idxs) == sol(ts[1]; idxs)
#@test (@allocated sol(out[1], ts[1]; idxs)) < (@allocated sol(ts[1]; idxs)) # TODO: reduce allocations and fix
@test_nowarn @inferred sol(out[1], ts[1]; idxs)

idxs = [unknowns(sys)[1], 1]
@test_throws "Incorrect specification of `idxs`" sol(out, ts; idxs)

@testset "Plot idxs" begin
@variables x(t) y(t)
@parameters p