Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleDiffEq"
uuid = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
repo = "https://github.com/SciML/SimpleDiffEq.jl.git"
version = "1.12.0"
repo = "https://github.com/SciML/SimpleDiffEq.jl.git"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand All @@ -15,6 +15,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[compat]
DiffEqBase = "6.122"
ExplicitImports = "1.14.0"
JET = "0.11.3"
JLArrays = "0.1, 0.2, 0.3"
MuladdMacro = "0.2"
OrdinaryDiffEq = "6"
Expand All @@ -26,10 +27,11 @@ julia = "1.6"

[extras]
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ExplicitImports", "JLArrays", "OrdinaryDiffEq", "SafeTestsets", "Test"]
test = ["ExplicitImports", "JET", "JLArrays", "OrdinaryDiffEq", "SafeTestsets", "Test"]
17 changes: 4 additions & 13 deletions src/tsit5/atsit5.jl
Original file line number Diff line number Diff line change
Expand Up @@ -692,25 +692,16 @@ end
# Interpolation
#######################################################################################
# Interpolation function, both OOP and IIP
@inline @muladd function (
integ::SAT5I{
IIP,
S,
T,
}
)(t::Real) where {
IIP,
S <:
AbstractArray{<:Number},
T,
}
@inline @muladd function (integ::SAT5I{IIP, S, T})(
t::Real
) where {IIP, S <: AbstractArray{<:Number}, T}
tnext, tprev, dt = integ.t, integ.tprev, integ.dt

θ = (t - tprev) / dt
b1θ, b2θ, b3θ, b4θ, b5θ, b6θ, b7θ = bθs(integ.rs, θ)

ks = integ.ks
if !IIP
if !isinplace(integ)
u = @inbounds integ.uprev +
dt * (
b1θ * ks[1] + b2θ * ks[2] + b3θ * ks[3] + b4θ * ks[4] +
Expand Down
67 changes: 67 additions & 0 deletions test/jet_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using SimpleDiffEq
using JET
using DiffEqBase
using Test

@testset "JET Static Analysis" begin
# Define test problems
f_scalar(u, p, t) = 1.01 * u
u0_scalar = 1.5
tspan = (0.0, 1.0)
prob_scalar = ODEProblem(f_scalar, u0_scalar, tspan)

function f_iip!(du, u, p, t)
du[1] = 1.01 * u[1]
du[2] = 2.0 * u[2]
end
u0_iip = [1.5, 2.0]
prob_iip = ODEProblem(f_iip!, u0_iip, tspan)

@testset "SimpleEuler type stability" begin
# OOP scalar
integ_oop = DiffEqBase.__init(prob_scalar, SimpleEuler(), dt = 0.1)
rep = JET.report_opt(DiffEqBase.step!, (typeof(integ_oop),))
@test length(JET.get_reports(rep)) == 0

# IIP array
integ_iip = DiffEqBase.__init(prob_iip, SimpleEuler(), dt = 0.1)
rep = JET.report_opt(DiffEqBase.step!, (typeof(integ_iip),))
@test length(JET.get_reports(rep)) == 0
end

@testset "SimpleRK4 type stability" begin
# OOP scalar
integ_oop = DiffEqBase.__init(prob_scalar, SimpleRK4(), dt = 0.1)
rep = JET.report_opt(DiffEqBase.step!, (typeof(integ_oop),))
@test length(JET.get_reports(rep)) == 0

# IIP array
integ_iip = DiffEqBase.__init(prob_iip, SimpleRK4(), dt = 0.1)
rep = JET.report_opt(DiffEqBase.step!, (typeof(integ_iip),))
@test length(JET.get_reports(rep)) == 0
end

@testset "SimpleTsit5 type stability" begin
# OOP scalar
integ_oop = DiffEqBase.__init(prob_scalar, SimpleTsit5(), dt = 0.1)
rep = JET.report_opt(DiffEqBase.step!, (typeof(integ_oop),))
@test length(JET.get_reports(rep)) == 0

# IIP array
integ_iip = DiffEqBase.__init(prob_iip, SimpleTsit5(), dt = 0.1)
rep = JET.report_opt(DiffEqBase.step!, (typeof(integ_iip),))
@test length(JET.get_reports(rep)) == 0
end

@testset "SimpleATsit5 type stability" begin
# OOP scalar
integ_oop = DiffEqBase.__init(prob_scalar, SimpleATsit5(), dt = 0.1)
rep = JET.report_opt(DiffEqBase.step!, (typeof(integ_oop),))
@test length(JET.get_reports(rep)) == 0

# IIP array
integ_iip = DiffEqBase.__init(prob_iip, SimpleATsit5(), dt = 0.1)
rep = JET.report_opt(DiffEqBase.step!, (typeof(integ_iip),))
@test length(JET.get_reports(rep)) == 0
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using SimpleDiffEq, SafeTestsets, Test

@time begin
@time @safetestset "ExplicitImports Tests" include("explicit_imports_tests.jl")
@time @safetestset "JET Static Analysis Tests" include("jet_tests.jl")
@time @safetestset "Discrete Tests" include("discrete_tests.jl")
@time @safetestset "SimpleEM Tests" include("simpleem_tests.jl")
@time @safetestset "SimpleTsit5 Tests" include("simpletsit5_tests.jl")
Expand Down
Loading