diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 453925c..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1 +0,0 @@ -style = "sciml" \ No newline at end of file diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index c240796..6762c6f 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,13 +1,19 @@ -name: "Format Check" +name: format-check on: push: branches: - 'master' + - 'main' + - 'release-' tags: '*' pull_request: jobs: - format-check: - name: "Format Check" - uses: "SciML/.github/.github/workflows/format-check.yml@v1" + runic: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: fredrikekre/runic-action@v1 + with: + version: '1' diff --git a/src/SimpleDiffEq.jl b/src/SimpleDiffEq.jl index 5786d1f..caa6fd2 100644 --- a/src/SimpleDiffEq.jl +++ b/src/SimpleDiffEq.jl @@ -2,56 +2,56 @@ __precompile__() module SimpleDiffEq -using Reexport: @reexport -using MuladdMacro: @muladd -@reexport using DiffEqBase: DiffEqBase, ODEProblem, SDEProblem, DiscreteProblem, - isinplace, reinit!, u_modified!, ODE_DEFAULT_NORM, - set_t!, solve, step!, init, DESolution, @.., - AbstractODEIntegrator, DEIntegrator, ConstantInterpolation, - __init, __solve, build_solution, has_analytic, - calculate_solution_errors!, is_diagonal_noise, - AbstractSDEAlgorithm, AbstractODEAlgorithm, isdiscrete, SciMLBase -import DiffEqBase.SciMLBase: allows_arbitrary_number_types, allowscomplex, isautodifferentiable, isadaptive -using StaticArrays: SArray, SVector, MVector -using RecursiveArrayTools: recursivecopy! -using LinearAlgebra: mul! -using Parameters: @unpack - -@inline _copy(a::SArray) = a -@inline _copy(a) = copy(a) - -abstract type AbstractSimpleDiffEqODEAlgorithm <: AbstractODEAlgorithm end -isautodifferentiable(alg::AbstractSimpleDiffEqODEAlgorithm) = true -allows_arbitrary_number_types(alg::AbstractSimpleDiffEqODEAlgorithm) = true -allowscomplex(alg::AbstractSimpleDiffEqODEAlgorithm) = true -isadaptive(alg::AbstractSimpleDiffEqODEAlgorithm) = false # except 2, handled individually - -function build_adaptive_controller_cache(::Type{T}) where {T} - beta1 = T(7 / 50) - beta2 = T(2 / 25) - qmax = T(10.0) - qmin = T(1 / 5) - gamma = T(9 / 10) - qoldinit = T(1e-4) - qold = qoldinit - - return beta1, beta2, qmax, qmin, gamma, qoldinit, qold -end - -include("functionmap.jl") -include("euler_maruyama.jl") -include("rk4/rk4.jl") -include("rk4/gpurk4.jl") -include("rk4/looprk4.jl") -include("euler/euler.jl") -include("euler/gpueuler.jl") -include("euler/loopeuler.jl") -include("tsit5/atsit5_cache.jl") -include("tsit5/tsit5.jl") -include("tsit5/atsit5.jl") -include("tsit5/gpuatsit5.jl") -include("verner/verner_tableaus.jl") -include("verner/gpuvern7.jl") -include("verner/gpuvern9.jl") + using Reexport: @reexport + using MuladdMacro: @muladd + @reexport using DiffEqBase: DiffEqBase, ODEProblem, SDEProblem, DiscreteProblem, + isinplace, reinit!, u_modified!, ODE_DEFAULT_NORM, + set_t!, solve, step!, init, DESolution, @.., + AbstractODEIntegrator, DEIntegrator, ConstantInterpolation, + __init, __solve, build_solution, has_analytic, + calculate_solution_errors!, is_diagonal_noise, + AbstractSDEAlgorithm, AbstractODEAlgorithm, isdiscrete, SciMLBase + import DiffEqBase.SciMLBase: allows_arbitrary_number_types, allowscomplex, isautodifferentiable, isadaptive + using StaticArrays: SArray, SVector, MVector + using RecursiveArrayTools: recursivecopy! + using LinearAlgebra: mul! + using Parameters: @unpack + + @inline _copy(a::SArray) = a + @inline _copy(a) = copy(a) + + abstract type AbstractSimpleDiffEqODEAlgorithm <: AbstractODEAlgorithm end + isautodifferentiable(alg::AbstractSimpleDiffEqODEAlgorithm) = true + allows_arbitrary_number_types(alg::AbstractSimpleDiffEqODEAlgorithm) = true + allowscomplex(alg::AbstractSimpleDiffEqODEAlgorithm) = true + isadaptive(alg::AbstractSimpleDiffEqODEAlgorithm) = false # except 2, handled individually + + function build_adaptive_controller_cache(::Type{T}) where {T} + beta1 = T(7 / 50) + beta2 = T(2 / 25) + qmax = T(10.0) + qmin = T(1 / 5) + gamma = T(9 / 10) + qoldinit = T(1.0e-4) + qold = qoldinit + + return beta1, beta2, qmax, qmin, gamma, qoldinit, qold + end + + include("functionmap.jl") + include("euler_maruyama.jl") + include("rk4/rk4.jl") + include("rk4/gpurk4.jl") + include("rk4/looprk4.jl") + include("euler/euler.jl") + include("euler/gpueuler.jl") + include("euler/loopeuler.jl") + include("tsit5/atsit5_cache.jl") + include("tsit5/tsit5.jl") + include("tsit5/atsit5.jl") + include("tsit5/gpuatsit5.jl") + include("verner/verner_tableaus.jl") + include("verner/gpuvern7.jl") + include("verner/gpuvern9.jl") end # module diff --git a/src/euler/euler.jl b/src/euler/euler.jl index f93217c..7c831e5 100644 --- a/src/euler/euler.jl +++ b/src/euler/euler.jl @@ -47,7 +47,7 @@ struct SimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end export SimpleEuler mutable struct SimpleEulerIntegrator{IIP, S, T, P, F} <: - DiffEqBase.AbstractODEIntegrator{SimpleEuler, IIP, S, T} + DiffEqBase.AbstractODEIntegrator{SimpleEuler, IIP, S, T} f::F # ..................................... Equations of motion uprev::S # .......................................... Previous state u::S # ........................................... Current state @@ -71,18 +71,24 @@ DiffEqBase.isinplace(::SEI{IIP}) where {IIP} = IIP # Initialization ################################################################################ -function DiffEqBase.__init(prob::ODEProblem, alg::SimpleEuler; - dt = error("dt is required for this algorithm")) - simpleeuler_init(prob.f, +function DiffEqBase.__init( + prob::ODEProblem, alg::SimpleEuler; + dt = error("dt is required for this algorithm") + ) + return simpleeuler_init( + prob.f, DiffEqBase.isinplace(prob), prob.u0, prob.tspan[1], dt, - prob.p) + prob.p + ) end -function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler; - dt = error("dt is required for this algorithm")) +function DiffEqBase.__solve( + prob::ODEProblem, alg::SimpleEuler; + dt = error("dt is required for this algorithm") + ) u0 = prob.u0 tspan = prob.tspan ts = Array(tspan[1]:dt:tspan[2]) @@ -91,8 +97,10 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler; @inbounds us[1] = _copy(u0) - integ = simpleeuler_init(prob.f, DiffEqBase.isinplace(prob), prob.u0, - prob.tspan[1], dt, prob.p) + integ = simpleeuler_init( + prob.f, DiffEqBase.isinplace(prob), prob.u0, + prob.tspan[1], dt, prob.p + ) for i in 1:(n - 1) step!(integ) @@ -102,17 +110,22 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler; sol = DiffEqBase.build_solution(prob, alg, ts, us, calculate_error = false) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; - timeseries_errors = true, - dense_errors = false) + DiffEqBase.calculate_solution_errors!( + sol; + timeseries_errors = true, + dense_errors = false + ) return sol end -@inline function simpleeuler_init(f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P) where - {F, P, T, S} - integ = SEI{IIP, S, T, P, F}(f, +@inline function simpleeuler_init( + f::F, IIP::Bool, u0::S, t0::T, dt::T, + p::P + ) where + {F, P, T, S} + integ = SEI{IIP, S, T, P, F}( + f, _copy(u0), _copy(u0), _copy(u0), @@ -122,7 +135,8 @@ end dt, sign(dt), p, - true) + true + ) return integ end diff --git a/src/euler/gpueuler.jl b/src/euler/gpueuler.jl index c52817c..41e05f7 100644 --- a/src/euler/gpueuler.jl +++ b/src/euler/gpueuler.jl @@ -44,9 +44,11 @@ sol = solve(prob, GPUSimpleEuler(), dt = 0.1) struct GPUSimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end export GPUSimpleEuler -@muladd function DiffEqBase.solve(prob::ODEProblem, +@muladd function DiffEqBase.solve( + prob::ODEProblem, alg::GPUSimpleEuler; - dt = error("dt is required for this algorithm")) + dt = error("dt is required for this algorithm") + ) @assert !isinplace(prob) u0 = prob.u0 tspan = prob.tspan @@ -67,11 +69,15 @@ export GPUSimpleEuler us[i] = u end - sol = DiffEqBase.build_solution(prob, alg, ts, SArray(us), + sol = DiffEqBase.build_solution( + prob, alg, ts, SArray(us), k = nothing, stats = nothing, - calculate_error = false) + calculate_error = false + ) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) + DiffEqBase.calculate_solution_errors!( + sol; timeseries_errors = true, + dense_errors = false + ) sol end diff --git a/src/euler/loopeuler.jl b/src/euler/loopeuler.jl index b6386fd..52d2313 100644 --- a/src/euler/loopeuler.jl +++ b/src/euler/loopeuler.jl @@ -47,15 +47,17 @@ export LoopEuler # Out-of-place # No caching, good for static arrays, bad for arrays -@muladd function DiffEqBase.__solve(prob::ODEProblem{uType, tType, false}, - alg::LoopEuler; - dt = error("dt is required for this algorithm"), - save_everystep = true, - save_start = true, - adaptive = false, - dense = false, - save_end = true, - kwargs...) where {uType, tType} +@muladd function DiffEqBase.__solve( + prob::ODEProblem{uType, tType, false}, + alg::LoopEuler; + dt = error("dt is required for this algorithm"), + save_everystep = true, + save_start = true, + adaptive = false, + dense = false, + save_end = true, + kwargs... + ) where {uType, tType} @assert !adaptive @assert !dense u0 = prob.u0 @@ -90,27 +92,33 @@ export LoopEuler !save_everystep && save_end && (us[end] = u) - sol = DiffEqBase.build_solution(prob, alg, ts, us, + sol = DiffEqBase.build_solution( + prob, alg, ts, us, k = nothing, stats = nothing, - calculate_error = false) + calculate_error = false + ) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) + DiffEqBase.calculate_solution_errors!( + sol; timeseries_errors = true, + dense_errors = false + ) sol end # In-place # Good for mutable objects like arrays # Use DiffEqBase.@.. for simd ivdep -@muladd function DiffEqBase.solve(prob::ODEProblem{uType, tType, true}, - alg::LoopEuler; - dt = error("dt is required for this algorithm"), - save_everystep = true, - save_start = true, - adaptive = false, - dense = false, - save_end = true, - kwargs...) where {uType, tType} +@muladd function DiffEqBase.solve( + prob::ODEProblem{uType, tType, true}, + alg::LoopEuler; + dt = error("dt is required for this algorithm"), + save_everystep = true, + save_start = true, + adaptive = false, + dense = false, + save_end = true, + kwargs... + ) where {uType, tType} @assert !adaptive @assert !dense u0 = prob.u0 @@ -145,11 +153,15 @@ end !save_everystep && save_end && (us[end] = u) - sol = DiffEqBase.build_solution(prob, alg, ts, us, + sol = DiffEqBase.build_solution( + prob, alg, ts, us, k = nothing, stats = nothing, - calculate_error = false) + calculate_error = false + ) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) + DiffEqBase.calculate_solution_errors!( + sol; timeseries_errors = true, + dense_errors = false + ) sol end diff --git a/src/euler_maruyama.jl b/src/euler_maruyama.jl index 39d9a02..13b13a5 100644 --- a/src/euler_maruyama.jl +++ b/src/euler_maruyama.jl @@ -34,10 +34,14 @@ sol = solve(prob, SimpleEM(), dt = 0.01) struct SimpleEM <: DiffEqBase.AbstractSDEAlgorithm end export SimpleEM -@muladd function DiffEqBase.solve(prob::SDEProblem{uType, tType, false}, alg::SimpleEM, - args...; - dt = error("dt required for SimpleEM")) where {uType, - tType} +@muladd function DiffEqBase.solve( + prob::SDEProblem{uType, tType, false}, alg::SimpleEM, + args...; + dt = error("dt required for SimpleEM") + ) where { + uType, + tType, + } f = prob.f g = prob.g u0 = prob.u0 @@ -60,25 +64,31 @@ export SimpleEM if is_diagonal_noise if u0 isa Number u[i] = uprev + f(uprev, p, tprev) * dt + - sqdt * g(uprev, p, tprev) * randn(typeof(u0)) + sqdt * g(uprev, p, tprev) * randn(typeof(u0)) else u[i] = uprev + f(uprev, p, tprev) * dt + - sqdt * g(uprev, p, tprev) .* randn(typeof(u0)) + sqdt * g(uprev, p, tprev) .* randn(typeof(u0)) end else u[i] = uprev + f(uprev, p, tprev) * dt + - sqdt * g(uprev, p, tprev) * randn(size(prob.noise_rate_prototype, 2)) + sqdt * g(uprev, p, tprev) * randn(size(prob.noise_rate_prototype, 2)) end end - sol = DiffEqBase.build_solution(prob, alg, t, u, - calculate_error = false) + sol = DiffEqBase.build_solution( + prob, alg, t, u, + calculate_error = false + ) end -@muladd function DiffEqBase.solve(prob::SDEProblem{uType, tType, true}, alg::SimpleEM, - args...; - dt = error("dt required for SimpleEM")) where {uType, - tType} +@muladd function DiffEqBase.solve( + prob::SDEProblem{uType, tType, true}, alg::SimpleEM, + args...; + dt = error("dt required for SimpleEM") + ) where { + uType, + tType, + } f = prob.f g = prob.g u0 = prob.u0 @@ -88,7 +98,7 @@ end gtmp = DiffEqBase.is_diagonal_noise(prob) ? zero(u0) : zero(prob.noise_rate_prototype) gtmp2 = DiffEqBase.is_diagonal_noise(prob) ? nothing : zero(u0) dW = DiffEqBase.is_diagonal_noise(prob) ? zero(u0) : - false .* prob.noise_rate_prototype[1, :] + false .* prob.noise_rate_prototype[1, :] @inbounds begin n = Int((tspan[2] - tspan[1]) / dt) + 1 @@ -112,6 +122,8 @@ end end end - sol = DiffEqBase.build_solution(prob, alg, t, u, - calculate_error = false) + sol = DiffEqBase.build_solution( + prob, alg, t, u, + calculate_error = false + ) end diff --git a/src/functionmap.jl b/src/functionmap.jl index 360d545..afc668a 100644 --- a/src/functionmap.jl +++ b/src/functionmap.jl @@ -37,9 +37,11 @@ export SimpleFunctionMap SciMLBase.isdiscrete(alg::SimpleFunctionMap) = true # ConstantCache version -function DiffEqBase.__solve(prob::DiffEqBase.DiscreteProblem{uType, tupType, false}, +function DiffEqBase.__solve( + prob::DiffEqBase.DiscreteProblem{uType, tupType, false}, alg::SimpleFunctionMap; - calculate_values = true) where {uType, tupType} + calculate_values = true + ) where {uType, tupType} tType = eltype(tupType) tspan = prob.tspan f = prob.f @@ -55,15 +57,19 @@ function DiffEqBase.__solve(prob::DiffEqBase.DiscreteProblem{uType, tupType, fal u[i] = f(u[i - 1], p, t[i]) end end - sol = DiffEqBase.build_solution(prob, alg, t, u, dense = false, + return sol = DiffEqBase.build_solution( + prob, alg, t, u, dense = false, interp = DiffEqBase.ConstantInterpolation(t, u), - calculate_error = false) + calculate_error = false + ) end # Cache version -function DiffEqBase.__solve(prob::DiscreteProblem{uType, tupType, true}, +function DiffEqBase.__solve( + prob::DiscreteProblem{uType, tupType, true}, alg::SimpleFunctionMap; - calculate_values = true) where {uType, tupType} + calculate_values = true + ) where {uType, tupType} tType = eltype(tupType) tspan = prob.tspan f = prob.f @@ -80,16 +86,18 @@ function DiffEqBase.__solve(prob::DiscreteProblem{uType, tupType, true}, f(u[i], u[i - 1], p, t[i]) end end - sol = DiffEqBase.build_solution(prob, alg, t, u, dense = false, + return sol = DiffEqBase.build_solution( + prob, alg, t, u, dense = false, interp = DiffEqBase.ConstantInterpolation(t, u), - calculate_error = false) + calculate_error = false + ) end ################################################## # Integrator version mutable struct DiscreteIntegrator{F, IIP, uType, tType, P, S} <: - DiffEqBase.DEIntegrator{SimpleFunctionMap, IIP, uType, tType} + DiffEqBase.DEIntegrator{SimpleFunctionMap, IIP, uType, tType} f::F u::uType t::tType @@ -100,8 +108,10 @@ mutable struct DiscreteIntegrator{F, IIP, uType, tType, P, S} <: tdir::tType end -function DiffEqBase.__init(prob::DiscreteProblem, - alg::SimpleFunctionMap) +function DiffEqBase.__init( + prob::DiscreteProblem, + alg::SimpleFunctionMap + ) sol = DiffEqBase.__solve(prob, alg; calculate_values = false) F = typeof(prob.f) IIP = isinplace(prob) @@ -109,9 +119,11 @@ function DiffEqBase.__init(prob::DiscreteProblem, tType = typeof(prob.tspan[1]) P = typeof(prob.p) S = typeof(sol) - DiscreteIntegrator{F, IIP, uType, tType, P, S}(prob.f, prob.u0, prob.tspan[1], + return DiscreteIntegrator{F, IIP, uType, tType, P, S}( + prob.f, prob.u0, prob.tspan[1], copy(prob.u0), prob.p, sol, 1, - one(tType)) + one(tType) + ) end function DiffEqBase.step!(integrator::DiscreteIntegrator) @@ -123,7 +135,7 @@ function DiffEqBase.step!(integrator::DiscreteIntegrator) f = integrator.f i = integrator.i - if isinplace(integrator.sol.prob) + return if isinplace(integrator.sol.prob) f(integrator.sol.u[i], uprev, p, i) integrator.uprev = integrator.u integrator.u = integrator.sol.u[i] diff --git a/src/rk4/gpurk4.jl b/src/rk4/gpurk4.jl index 2104e43..e7985a3 100644 --- a/src/rk4/gpurk4.jl +++ b/src/rk4/gpurk4.jl @@ -44,9 +44,11 @@ sol = solve(prob, GPUSimpleRK4(), dt = 0.1) struct GPUSimpleRK4 <: AbstractSimpleDiffEqODEAlgorithm end export GPUSimpleRK4 -@muladd function DiffEqBase.solve(prob::ODEProblem, +@muladd function DiffEqBase.solve( + prob::ODEProblem, alg::GPUSimpleRK4; - dt = error("dt is required for this algorithm")) + dt = error("dt is required for this algorithm") + ) @assert !isinplace(prob) u0 = prob.u0 tspan = prob.tspan @@ -75,11 +77,15 @@ export GPUSimpleRK4 us[i] = u end - sol = DiffEqBase.build_solution(prob, alg, ts, SArray(us), + sol = DiffEqBase.build_solution( + prob, alg, ts, SArray(us), k = nothing, stats = nothing, - calculate_error = false) + calculate_error = false + ) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) + DiffEqBase.calculate_solution_errors!( + sol; timeseries_errors = true, + dense_errors = false + ) sol end diff --git a/src/rk4/looprk4.jl b/src/rk4/looprk4.jl index 1d45355..34fe9ab 100644 --- a/src/rk4/looprk4.jl +++ b/src/rk4/looprk4.jl @@ -47,15 +47,17 @@ export LoopRK4 # Out-of-place # No caching, good for static arrays, bad for arrays -@muladd function DiffEqBase.__solve(prob::ODEProblem{uType, tType, false}, - alg::LoopRK4; - dt = error("dt is required for this algorithm"), - save_everystep = true, - save_start = true, - adaptive = false, - dense = false, - save_end = true, - kwargs...) where {uType, tType} +@muladd function DiffEqBase.__solve( + prob::ODEProblem{uType, tType, false}, + alg::LoopRK4; + dt = error("dt is required for this algorithm"), + save_everystep = true, + save_start = true, + adaptive = false, + dense = false, + save_end = true, + kwargs... + ) where {uType, tType} @assert !adaptive @assert !dense u0 = prob.u0 @@ -98,27 +100,33 @@ export LoopRK4 !save_everystep && save_end && (us[end] = u) - sol = DiffEqBase.build_solution(prob, alg, ts, us, + sol = DiffEqBase.build_solution( + prob, alg, ts, us, k = nothing, stats = nothing, - calculate_error = false) + calculate_error = false + ) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) + DiffEqBase.calculate_solution_errors!( + sol; timeseries_errors = true, + dense_errors = false + ) sol end # In-place # Good for mutable objects like arrays # Use DiffEqBase.@.. for simd ivdep -@muladd function DiffEqBase.solve(prob::ODEProblem{uType, tType, true}, - alg::LoopRK4; - dt = error("dt is required for this algorithm"), - save_everystep = true, - save_start = true, - adaptive = false, - dense = false, - save_end = true, - kwargs...) where {uType, tType} +@muladd function DiffEqBase.solve( + prob::ODEProblem{uType, tType, true}, + alg::LoopRK4; + dt = error("dt is required for this algorithm"), + save_everystep = true, + save_start = true, + adaptive = false, + dense = false, + save_end = true, + kwargs... + ) where {uType, tType} @assert !adaptive @assert !dense u0 = prob.u0 @@ -166,11 +174,15 @@ end !save_everystep && save_end && (us[end] = u) - sol = DiffEqBase.build_solution(prob, alg, ts, us, + sol = DiffEqBase.build_solution( + prob, alg, ts, us, k = nothing, stats = nothing, - calculate_error = false) + calculate_error = false + ) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) + DiffEqBase.calculate_solution_errors!( + sol; timeseries_errors = true, + dense_errors = false + ) sol end diff --git a/src/rk4/rk4.jl b/src/rk4/rk4.jl index fb122da..e1ffbf8 100644 --- a/src/rk4/rk4.jl +++ b/src/rk4/rk4.jl @@ -51,7 +51,7 @@ struct SimpleRK4 <: AbstractSimpleDiffEqODEAlgorithm end export SimpleRK4 mutable struct SimpleRK4Integrator{IIP, S, T, P, F} <: - DiffEqBase.AbstractODEIntegrator{SimpleRK4, IIP, S, T} + DiffEqBase.AbstractODEIntegrator{SimpleRK4, IIP, S, T} f::F # ..................................... Equations of motion uprev::S # .......................................... Previous state u::S # ........................................... Current state @@ -76,18 +76,24 @@ DiffEqBase.isinplace(::SRK4{IIP}) where {IIP} = IIP # Initialization ################################################################################ -function DiffEqBase.__init(prob::ODEProblem, alg::SimpleRK4; - dt = error("dt is required for this algorithm")) - simplerk4_init(prob.f, +function DiffEqBase.__init( + prob::ODEProblem, alg::SimpleRK4; + dt = error("dt is required for this algorithm") + ) + return simplerk4_init( + prob.f, DiffEqBase.isinplace(prob), prob.u0, prob.tspan[1], dt, - prob.p) + prob.p + ) end -function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleRK4; - dt = error("dt is required for this algorithm")) +function DiffEqBase.__solve( + prob::ODEProblem, alg::SimpleRK4; + dt = error("dt is required for this algorithm") + ) u0 = prob.u0 tspan = prob.tspan ts = Array(tspan[1]:dt:tspan[2]) @@ -96,8 +102,10 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleRK4; @inbounds us[1] = _copy(u0) - integ = simplerk4_init(prob.f, DiffEqBase.isinplace(prob), prob.u0, - prob.tspan[1], dt, prob.p) + integ = simplerk4_init( + prob.f, DiffEqBase.isinplace(prob), prob.u0, + prob.tspan[1], dt, prob.p + ) # FSAL for i in 1:(n - 1) @@ -108,21 +116,26 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleRK4; sol = DiffEqBase.build_solution(prob, alg, ts, us, calculate_error = false) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; - timeseries_errors = true, - dense_errors = false) + DiffEqBase.calculate_solution_errors!( + sol; + timeseries_errors = true, + dense_errors = false + ) return sol end -@inline function simplerk4_init(f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P) where - {F, P, T, S} +@inline function simplerk4_init( + f::F, IIP::Bool, u0::S, t0::T, dt::T, + p::P + ) where + {F, P, T, S} # Allocate the vector with the interpolants. For RK4, we need 5. ks = [zero(u0) for _ in 1:5] - integ = SRK4{IIP, S, T, P, F}(f, + integ = SRK4{IIP, S, T, P, F}( + f, _copy(u0), _copy(u0), _copy(u0), @@ -133,7 +146,8 @@ end sign(dt), p, true, - ks) + ks + ) return integ end @@ -254,18 +268,22 @@ end # Hermite interpolation. @inbounds if !isinplace(integ) u = (1 - Θ) * y₀ + Θ * y₁ + - Θ * (Θ - 1) * ((1 - 2Θ) * (y₁ - y₀) + - (Θ - 1) * dt * ks[1] + - Θ * dt * ks[5]) + Θ * (Θ - 1) * ( + (1 - 2Θ) * (y₁ - y₀) + + (Θ - 1) * dt * ks[1] + + Θ * dt * ks[5] + ) return u else u = similar(y₁) for i in 1:length(u) u[i] = (1 - Θ) * y₀[i] + Θ * y₁[i] + - Θ * (Θ - 1) * - ((1 - 2Θ) * (y₁[i] - y₀[i]) + + Θ * (Θ - 1) * + ( + (1 - 2Θ) * (y₁[i] - y₀[i]) + (Θ - 1) * dt * ks[1][i] + - Θ * dt * ks[5][i]) + Θ * dt * ks[5][i] + ) end return u diff --git a/src/tsit5/atsit5.jl b/src/tsit5/atsit5.jl index 594a3ed..8589812 100644 --- a/src/tsit5/atsit5.jl +++ b/src/tsit5/atsit5.jl @@ -56,10 +56,10 @@ const beta2 = 2 / 25 const qmax = 10.0 const qmin = 1 / 5 const gamma = 9 / 10 -const qoldinit = 1e-4 +const qoldinit = 1.0e-4 mutable struct SimpleATsit5Integrator{IIP, S, T, P, F, N} <: - DiffEqBase.AbstractODEIntegrator{SimpleATsit5, IIP, S, T} + DiffEqBase.AbstractODEIntegrator{SimpleATsit5, IIP, S, T} f::F # eom uprev::S # previous state u::S # current state @@ -91,19 +91,25 @@ DiffEqBase.u_modified!(i::SAT5I, bool) = (i.u_modified = bool) ####################################################################################### # Initialization ####################################################################################### -function DiffEqBase.__init(prob::ODEProblem, alg::SimpleATsit5; +function DiffEqBase.__init( + prob::ODEProblem, alg::SimpleATsit5; dt = 0.1, - abstol = 1e-6, reltol = 1e-3, - internalnorm = DiffEqBase.ODE_DEFAULT_NORM, kwargs...) - simpleatsit5_init(prob.f, DiffEqBase.isinplace(prob), prob.u0, + abstol = 1.0e-6, reltol = 1.0e-3, + internalnorm = DiffEqBase.ODE_DEFAULT_NORM, kwargs... + ) + return simpleatsit5_init( + prob.f, DiffEqBase.isinplace(prob), prob.u0, prob.tspan[1], prob.tspan[2], dt, prob.p, abstol, reltol, - internalnorm) + internalnorm + ) end -function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleATsit5; +function DiffEqBase.__solve( + prob::ODEProblem, alg::SimpleATsit5; dt = 0.1, saveat = nothing, save_everystep = true, - abstol = 1e-6, reltol = 1e-3, - internalnorm = DiffEqBase.ODE_DEFAULT_NORM) + abstol = 1.0e-6, reltol = 1.0e-3, + internalnorm = DiffEqBase.ODE_DEFAULT_NORM + ) u0 = prob.u0 tspan = prob.tspan ts = Vector{eltype(dt)}(undef, 1) @@ -123,8 +129,10 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleATsit5; end end - integ = simpleatsit5_init(prob.f, DiffEqBase.isinplace(prob), prob.u0, - tspan[1], tspan[2], dt, prob.p, abstol, reltol, internalnorm) + integ = simpleatsit5_init( + prob.f, DiffEqBase.isinplace(prob), prob.u0, + tspan[1], tspan[2], dt, prob.p, abstol, reltol, internalnorm + ) # FSAL while integ.t < tspan[2] step!(integ) @@ -148,25 +156,33 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleATsit5; push!(ts, integ.t) end - sol = DiffEqBase.build_solution(prob, alg, ts, us, - calculate_error = false) + sol = DiffEqBase.build_solution( + prob, alg, ts, us, + calculate_error = false + ) DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) - sol + DiffEqBase.calculate_solution_errors!( + sol; timeseries_errors = true, + dense_errors = false + ) + return sol end -@inline function simpleatsit5_init(f::F, +@inline function simpleatsit5_init( + f::F, IIP::Bool, u0::S, t0::T, tf::T, dt::T, p::P, abstol, reltol, - internalnorm::N) where {F, P, S, T, N} + internalnorm::N + ) where {F, P, S, T, N} cs, as, btildes, rs = _build_atsit5_caches(T) ks = _initialize_ks(u0) - integ = SAT5I{IIP, S, T, P, F, N}(f, recursivecopy(u0), recursivecopy(u0), + return integ = SAT5I{IIP, S, T, P, F, N}( + f, recursivecopy(u0), recursivecopy(u0), recursivecopy(u0), t0, t0, t0, tf, dt, dt, sign(tf - t0), p, true, ks, cs, as, btildes, rs, - qoldinit, abstol, reltol, internalnorm) + qoldinit, abstol, reltol, internalnorm + ) end @inline _initialize_ks(u0::AbstractArray{T}) where {T <: Number} = [zero(u0) for i in 1:7] @@ -174,7 +190,7 @@ end return [[zero(u0[j]) for j in 1:length(u0)] for i in 1:7] end @inline function _initialize_ks(u0::T) where {T <: Number} - [zero(u0) for i in 1:7] + return [zero(u0) for i in 1:7] end ####################################################################################### @@ -189,7 +205,7 @@ end p = integ.p tf = integ.tf a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, - a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76 = integ.as + a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76 = integ.as btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7 = integ.btildes k1, k2, k3, k4, k5, k6, k7 = integ.ks @@ -214,7 +230,7 @@ end EEst = Inf @inbounds while EEst > 1 - dt < 1e-14 && error("dt 1 - dt < 1e-14 && error("dt 1 - dt < 1e-14 && error("dt 1 - dt < 1e-14 && error("dt 1 - dt < 1e-14 && error("dt 1 - dt < 1e-14 && error("dt norm(u[:, step!(iip); step!(iip); -@test oop.u≈iip.u atol=1e-5 -@test oop.t≈iip.t atol=1e-5 +@test oop.u ≈ iip.u atol = 1.0e-5 +@test oop.t ≈ iip.t atol = 1.0e-5 ################################################################################### # VectorVector test: @@ -157,12 +161,16 @@ function vviip(du, u, p, t) # takes Vector{Vector} end ran = rand(3) -odevoop = ODEProblem{true}(vvoop, [SVector{3}(u0), SVector{3}(ran)], (0.0, 100.0), - [10, 28, 8 / 3]) +odevoop = ODEProblem{true}( + vvoop, [SVector{3}(u0), SVector{3}(ran)], (0.0, 100.0), + [10, 28, 8 / 3] +) odeviip = ODEProblem{true}(vviip, [u0, ran], (0.0, 100.0), [10, 28, 8 / 3]) -viip = init(odeviip, SimpleATsit5(), dt = dt; - internalnorm = (u, t) -> DiffEqBase.ODE_DEFAULT_NORM(u[1], t)) +viip = init( + odeviip, SimpleATsit5(), dt = dt; + internalnorm = (u, t) -> DiffEqBase.ODE_DEFAULT_NORM(u[1], t) +) step!(viip); step!(viip); @@ -170,10 +178,12 @@ iip = init(odeiip, SimpleATsit5(), dt = dt) step!(iip); step!(iip); -@test iip.u≈viip.u[1] atol=1e-5 +@test iip.u ≈ viip.u[1] atol = 1.0e-5 -voop = init(odevoop, SimpleATsit5(), dt = dt, - internalnorm = (u, t) -> DiffEqBase.ODE_DEFAULT_NORM(u[1], t)) +voop = init( + odevoop, SimpleATsit5(), dt = dt, + internalnorm = (u, t) -> DiffEqBase.ODE_DEFAULT_NORM(u[1], t) +) step!(voop); step!(voop); @@ -181,10 +191,10 @@ oop = init(odeoop, SimpleATsit5(), dt = dt) step!(oop); step!(oop); -@test oop.u≈voop.u[1] atol=1e-5 +@test oop.u ≈ voop.u[1] atol = 1.0e-5 # Final test that the states of both methods should be the same: -@test voop.u[2]≈viip.u[2] atol=1e-5 +@test voop.u[2] ≈ viip.u[2] atol = 1.0e-5 # viip = init(odeviip,SimpleATsit5(),dt=dt; internalnorm = u -> SimpleDiffEq.defaultnorm(u[1])) # step!(viip); step!(viip) @@ -208,11 +218,11 @@ using SimpleDiffEq, OrdinaryDiffEq, StaticArrays dx2 = f * cos(ω * t) - β * x[1] - x[1]^3 - d * x[2] return SVector(dx1, dx2) end -prob = ODEProblem(duffing_rule, SVector(0.1, 0.2), (0.0, 1e12), [0.3, 0.1, 0.2, -1]) +prob = ODEProblem(duffing_rule, SVector(0.1, 0.2), (0.0, 1.0e12), [0.3, 0.1, 0.2, -1]) T = 2π / 0.3 # this is the period of the oscillator -integ2 = init(prob, SimpleATsit5(), reltol = 1e-9) +integ2 = init(prob, SimpleATsit5(), reltol = 1.0e-9) step!(integ2, T * 20, true) v = zeros(200, 2) for k in 1:200 @@ -230,23 +240,23 @@ hence changing the final tspan to test with save_everystep = false odeiip = remake(odeiip; tspan = (0.0, 10.0)) odeoop = remake(odeoop; tspan = (0.0, 10.0)) -sol = solve(odeiip, Tsit5(), reltol = 1e-9, abstol = 1e-9, save_everystep = false) -sol1 = solve(odeiip, SimpleATsit5(), reltol = 1e-9, abstol = 1e-9, save_everystep = false) +sol = solve(odeiip, Tsit5(), reltol = 1.0e-9, abstol = 1.0e-9, save_everystep = false) +sol1 = solve(odeiip, SimpleATsit5(), reltol = 1.0e-9, abstol = 1.0e-9, save_everystep = false) -@test sol.u≈sol1.u atol=1e-5 +@test sol.u ≈ sol1.u atol = 1.0e-5 @test sol.t ≈ sol1.t -sol = solve(odeoop, Tsit5(), reltol = 1e-9, abstol = 1e-9, save_everystep = false) -sol1 = solve(odeoop, SimpleATsit5(), reltol = 1e-9, abstol = 1e-9, save_everystep = false) +sol = solve(odeoop, Tsit5(), reltol = 1.0e-9, abstol = 1.0e-9, save_everystep = false) +sol1 = solve(odeoop, SimpleATsit5(), reltol = 1.0e-9, abstol = 1.0e-9, save_everystep = false) -@test sol.u≈sol1.u atol=1e-5 +@test sol.u ≈ sol1.u atol = 1.0e-5 @test sol.t ≈ sol1.t simple_f(u, p, t) = 1.01 * u u0 = 1 / 2 tspan = (0.0, 1.0) prob = ODEProblem(simple_f, u0, tspan) -sol = solve(prob, Tsit5(), reltol = 1e-9, abstol = 1e-9, save_everystep = false) -sol1 = solve(prob, SimpleATsit5(), reltol = 1e-9, abstol = 1e-9, save_everystep = false) -@test sol.u≈sol1.u atol=1e-5 +sol = solve(prob, Tsit5(), reltol = 1.0e-9, abstol = 1.0e-9, save_everystep = false) +sol1 = solve(prob, SimpleATsit5(), reltol = 1.0e-9, abstol = 1.0e-9, save_everystep = false) +@test sol.u ≈ sol1.u atol = 1.0e-5 @test sol.t ≈ sol1.t diff --git a/test/simpleem_tests.jl b/test/simpleem_tests.jl index d5dd1b5..776a9c5 100644 --- a/test/simpleem_tests.jl +++ b/test/simpleem_tests.jl @@ -64,7 +64,7 @@ function g_oop(du, u, p, t) du[2, 1] = 1.2u[2] du[2, 2] = 0.2u[2] du[2, 3] = 0.3u[2] - du[2, 4] = 1.8u[2] + return du[2, 4] = 1.8u[2] end prob = SDEProblem(f_oop, g_oop, ones(2), (0.0, 1.0), noise_rate_prototype = zeros(2, 4)) diff --git a/test/simpleeuler_tests.jl b/test/simpleeuler_tests.jl index 0205352..5835313 100644 --- a/test/simpleeuler_tests.jl +++ b/test/simpleeuler_tests.jl @@ -43,12 +43,14 @@ end u0 = 10ones(3) dt = 0.01 -oop = SimpleDiffEq.simpleeuler_init(loop, +oop = SimpleDiffEq.simpleeuler_init( + loop, false, SVector{3}(u0), 0.0, dt, - [10, 28, 8 / 3]) + [10, 28, 8 / 3] +) step!(oop) for i in 1:10000 @@ -62,12 +64,14 @@ end # In-place version of the algorithm # --------------------------------- -iip = SimpleDiffEq.simpleeuler_init(liip, +iip = SimpleDiffEq.simpleeuler_init( + liip, true, copy(u0), 0.0, dt, - [10, 28, 8 / 3]) + [10, 28, 8 / 3] +) step!(iip) @@ -88,20 +92,24 @@ end u0 = 10ones(3) dt = 0.01 -odeoop = ODEProblem{false}(loop, +odeoop = ODEProblem{false}( + loop, SVector{3}(u0), (0.0, 100.0), - [10, 28, 8 / 3]) + [10, 28, 8 / 3] +) oop = init(odeoop, SimpleEuler(), dt = dt) step!(oop) step!(oop) -deoop = DiffEqBase.init(odeoop, +deoop = DiffEqBase.init( + odeoop, Euler(); adaptive = false, save_everystep = false, - dt = dt) + dt = dt +) step!(deoop) step!(deoop) @@ -118,20 +126,24 @@ sol = solve(odeoop, LoopEuler(), dt = dt) # In-place version of the algorithm # --------------------------------- -odeiip = ODEProblem{true}(liip, +odeiip = ODEProblem{true}( + liip, u0, (0.0, 100.0), - [10, 28, 8 / 3]) + [10, 28, 8 / 3] +) iip = init(odeiip, SimpleEuler(), dt = dt) step!(iip) step!(iip) -deiip = DiffEqBase.init(odeiip, +deiip = DiffEqBase.init( + odeiip, Euler(); adaptive = false, save_everystep = false, - dt = dt) + dt = dt +) step!(deiip) step!(deiip) diff --git a/test/simplerk4_tests.jl b/test/simplerk4_tests.jl index 49686bc..387be89 100644 --- a/test/simplerk4_tests.jl +++ b/test/simplerk4_tests.jl @@ -43,12 +43,14 @@ end u0 = 10ones(3) dt = 0.01 -oop = SimpleDiffEq.simplerk4_init(loop, +oop = SimpleDiffEq.simplerk4_init( + loop, false, SVector{3}(u0), 0.0, dt, - [10, 28, 8 / 3]) + [10, 28, 8 / 3] +) step!(oop) for i in 1:10000 @@ -62,12 +64,14 @@ end # In-place version of the algorithm # --------------------------------- -iip = SimpleDiffEq.simplerk4_init(liip, +iip = SimpleDiffEq.simplerk4_init( + liip, true, copy(u0), 0.0, dt, - [10, 28, 8 / 3]) + [10, 28, 8 / 3] +) step!(iip) @@ -88,20 +92,24 @@ end u0 = 10ones(3) dt = 0.01 -odeoop = ODEProblem{false}(loop, +odeoop = ODEProblem{false}( + loop, SVector{3}(u0), (0.0, 100.0), - [10, 28, 8 / 3]) + [10, 28, 8 / 3] +) oop = init(odeoop, SimpleRK4(), dt = dt) step!(oop) step!(oop) -deoop = DiffEqBase.init(odeoop, +deoop = DiffEqBase.init( + odeoop, RK4(); adaptive = false, save_everystep = false, - dt = dt) + dt = dt +) step!(deoop) step!(deoop) @@ -118,20 +126,24 @@ sol = solve(odeoop, LoopRK4(), dt = dt) # In-place version of the algorithm # --------------------------------- -odeiip = ODEProblem{true}(liip, +odeiip = ODEProblem{true}( + liip, u0, (0.0, 100.0), - [10, 28, 8 / 3]) + [10, 28, 8 / 3] +) iip = init(odeiip, SimpleRK4(), dt = dt) step!(iip) step!(iip) -deiip = DiffEqBase.init(odeiip, +deiip = DiffEqBase.init( + odeiip, RK4(); adaptive = false, save_everystep = false, - dt = dt) + dt = dt +) step!(deiip) step!(deiip) diff --git a/test/simpletsit5_tests.jl b/test/simpletsit5_tests.jl index 99dcf46..7e24616 100644 --- a/test/simpletsit5_tests.jl +++ b/test/simpletsit5_tests.jl @@ -55,18 +55,22 @@ iip = init(odeiip, SimpleTsit5(), dt = dt) step!(iip); step!(iip); -deoop = DiffEqBase.init(odeoop, Tsit5(); adaptive = false, - save_everystep = false, dt = dt) +deoop = DiffEqBase.init( + odeoop, Tsit5(); adaptive = false, + save_everystep = false, dt = dt +) step!(deoop); step!(deoop); @test oop.u == deoop.u -deiip = DiffEqBase.init(odeiip, Tsit5(); +deiip = DiffEqBase.init( + odeiip, Tsit5(); adaptive = false, save_everystep = false, - dt = dt) + dt = dt +) step!(deiip); step!(deiip); -@test iip.u≈deiip.u atol=1e-14 +@test iip.u ≈ deiip.u atol = 1.0e-14 sol = solve(odeoop, SimpleTsit5(), dt = dt) @@ -82,7 +86,7 @@ function ode(x, p, t) return ([dx]) end prob = ODEProblem(ode, [1.0], (0.0, 0.05), nothing) -sol = solve(prob, SimpleTsit5(), dt = 0.05/11) # On my PC, the integration ends at 0.04545... +sol = solve(prob, SimpleTsit5(), dt = 0.05 / 11) # On my PC, the integration ends at 0.04545... @test sol.t[end] == 0.05 #=