Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .JuliaFormatter.toml

This file was deleted.

14 changes: 10 additions & 4 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
name: "Format Check"
name: format-check

on:
push:
branches:
- 'master'
- 'main'
- 'release-'
tags: '*'
pull_request:

jobs:
format-check:
name: "Format Check"
uses: "SciML/.github/.github/workflows/format-check.yml@v1"
runic:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: fredrikekre/runic-action@v1
with:
version: '1'
102 changes: 51 additions & 51 deletions src/SimpleDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,56 @@ __precompile__()

module SimpleDiffEq

using Reexport: @reexport
using MuladdMacro: @muladd
@reexport using DiffEqBase: DiffEqBase, ODEProblem, SDEProblem, DiscreteProblem,
isinplace, reinit!, u_modified!, ODE_DEFAULT_NORM,
set_t!, solve, step!, init, DESolution, @..,
AbstractODEIntegrator, DEIntegrator, ConstantInterpolation,
__init, __solve, build_solution, has_analytic,
calculate_solution_errors!, is_diagonal_noise,
AbstractSDEAlgorithm, AbstractODEAlgorithm, isdiscrete, SciMLBase
import DiffEqBase.SciMLBase: allows_arbitrary_number_types, allowscomplex, isautodifferentiable, isadaptive
using StaticArrays: SArray, SVector, MVector
using RecursiveArrayTools: recursivecopy!
using LinearAlgebra: mul!
using Parameters: @unpack

@inline _copy(a::SArray) = a
@inline _copy(a) = copy(a)

abstract type AbstractSimpleDiffEqODEAlgorithm <: AbstractODEAlgorithm end
isautodifferentiable(alg::AbstractSimpleDiffEqODEAlgorithm) = true
allows_arbitrary_number_types(alg::AbstractSimpleDiffEqODEAlgorithm) = true
allowscomplex(alg::AbstractSimpleDiffEqODEAlgorithm) = true
isadaptive(alg::AbstractSimpleDiffEqODEAlgorithm) = false # except 2, handled individually

function build_adaptive_controller_cache(::Type{T}) where {T}
beta1 = T(7 / 50)
beta2 = T(2 / 25)
qmax = T(10.0)
qmin = T(1 / 5)
gamma = T(9 / 10)
qoldinit = T(1e-4)
qold = qoldinit

return beta1, beta2, qmax, qmin, gamma, qoldinit, qold
end

include("functionmap.jl")
include("euler_maruyama.jl")
include("rk4/rk4.jl")
include("rk4/gpurk4.jl")
include("rk4/looprk4.jl")
include("euler/euler.jl")
include("euler/gpueuler.jl")
include("euler/loopeuler.jl")
include("tsit5/atsit5_cache.jl")
include("tsit5/tsit5.jl")
include("tsit5/atsit5.jl")
include("tsit5/gpuatsit5.jl")
include("verner/verner_tableaus.jl")
include("verner/gpuvern7.jl")
include("verner/gpuvern9.jl")
using Reexport: @reexport
using MuladdMacro: @muladd
@reexport using DiffEqBase: DiffEqBase, ODEProblem, SDEProblem, DiscreteProblem,
isinplace, reinit!, u_modified!, ODE_DEFAULT_NORM,
set_t!, solve, step!, init, DESolution, @..,
AbstractODEIntegrator, DEIntegrator, ConstantInterpolation,
__init, __solve, build_solution, has_analytic,
calculate_solution_errors!, is_diagonal_noise,
AbstractSDEAlgorithm, AbstractODEAlgorithm, isdiscrete, SciMLBase
import DiffEqBase.SciMLBase: allows_arbitrary_number_types, allowscomplex, isautodifferentiable, isadaptive
using StaticArrays: SArray, SVector, MVector
using RecursiveArrayTools: recursivecopy!
using LinearAlgebra: mul!
using Parameters: @unpack

@inline _copy(a::SArray) = a
@inline _copy(a) = copy(a)

abstract type AbstractSimpleDiffEqODEAlgorithm <: AbstractODEAlgorithm end
isautodifferentiable(alg::AbstractSimpleDiffEqODEAlgorithm) = true
allows_arbitrary_number_types(alg::AbstractSimpleDiffEqODEAlgorithm) = true
allowscomplex(alg::AbstractSimpleDiffEqODEAlgorithm) = true
isadaptive(alg::AbstractSimpleDiffEqODEAlgorithm) = false # except 2, handled individually

function build_adaptive_controller_cache(::Type{T}) where {T}
beta1 = T(7 / 50)
beta2 = T(2 / 25)
qmax = T(10.0)
qmin = T(1 / 5)
gamma = T(9 / 10)
qoldinit = T(1.0e-4)
qold = qoldinit

return beta1, beta2, qmax, qmin, gamma, qoldinit, qold
end

include("functionmap.jl")
include("euler_maruyama.jl")
include("rk4/rk4.jl")
include("rk4/gpurk4.jl")
include("rk4/looprk4.jl")
include("euler/euler.jl")
include("euler/gpueuler.jl")
include("euler/loopeuler.jl")
include("tsit5/atsit5_cache.jl")
include("tsit5/tsit5.jl")
include("tsit5/atsit5.jl")
include("tsit5/gpuatsit5.jl")
include("verner/verner_tableaus.jl")
include("verner/gpuvern7.jl")
include("verner/gpuvern9.jl")

end # module
48 changes: 31 additions & 17 deletions src/euler/euler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct SimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end
export SimpleEuler

mutable struct SimpleEulerIntegrator{IIP, S, T, P, F} <:
DiffEqBase.AbstractODEIntegrator{SimpleEuler, IIP, S, T}
DiffEqBase.AbstractODEIntegrator{SimpleEuler, IIP, S, T}
f::F # ..................................... Equations of motion
uprev::S # .......................................... Previous state
u::S # ........................................... Current state
Expand All @@ -71,18 +71,24 @@ DiffEqBase.isinplace(::SEI{IIP}) where {IIP} = IIP
# Initialization
################################################################################

function DiffEqBase.__init(prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm"))
simpleeuler_init(prob.f,
function DiffEqBase.__init(
prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm")
)
return simpleeuler_init(
prob.f,
DiffEqBase.isinplace(prob),
prob.u0,
prob.tspan[1],
dt,
prob.p)
prob.p
)
end

function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm"))
function DiffEqBase.__solve(
prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm")
)
u0 = prob.u0
tspan = prob.tspan
ts = Array(tspan[1]:dt:tspan[2])
Expand All @@ -91,8 +97,10 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler;

@inbounds us[1] = _copy(u0)

integ = simpleeuler_init(prob.f, DiffEqBase.isinplace(prob), prob.u0,
prob.tspan[1], dt, prob.p)
integ = simpleeuler_init(
prob.f, DiffEqBase.isinplace(prob), prob.u0,
prob.tspan[1], dt, prob.p
)

for i in 1:(n - 1)
step!(integ)
Expand All @@ -102,17 +110,22 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler;
sol = DiffEqBase.build_solution(prob, alg, ts, us, calculate_error = false)

DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol;
timeseries_errors = true,
dense_errors = false)
DiffEqBase.calculate_solution_errors!(
sol;
timeseries_errors = true,
dense_errors = false
)

return sol
end

@inline function simpleeuler_init(f::F, IIP::Bool, u0::S, t0::T, dt::T,
p::P) where
{F, P, T, S}
integ = SEI{IIP, S, T, P, F}(f,
@inline function simpleeuler_init(
f::F, IIP::Bool, u0::S, t0::T, dt::T,
p::P
) where
{F, P, T, S}
integ = SEI{IIP, S, T, P, F}(
f,
_copy(u0),
_copy(u0),
_copy(u0),
Expand All @@ -122,7 +135,8 @@ end
dt,
sign(dt),
p,
true)
true
)

return integ
end
Expand Down
18 changes: 12 additions & 6 deletions src/euler/gpueuler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ sol = solve(prob, GPUSimpleEuler(), dt = 0.1)
struct GPUSimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end
export GPUSimpleEuler

@muladd function DiffEqBase.solve(prob::ODEProblem,
@muladd function DiffEqBase.solve(
prob::ODEProblem,
alg::GPUSimpleEuler;
dt = error("dt is required for this algorithm"))
dt = error("dt is required for this algorithm")
)
@assert !isinplace(prob)
u0 = prob.u0
tspan = prob.tspan
Expand All @@ -67,11 +69,15 @@ export GPUSimpleEuler
us[i] = u
end

sol = DiffEqBase.build_solution(prob, alg, ts, SArray(us),
sol = DiffEqBase.build_solution(
prob, alg, ts, SArray(us),
k = nothing, stats = nothing,
calculate_error = false)
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
DiffEqBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
sol
end
64 changes: 38 additions & 26 deletions src/euler/loopeuler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,17 @@ export LoopEuler

# Out-of-place
# No caching, good for static arrays, bad for arrays
@muladd function DiffEqBase.__solve(prob::ODEProblem{uType, tType, false},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
@muladd function DiffEqBase.__solve(
prob::ODEProblem{uType, tType, false},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...
) where {uType, tType}
@assert !adaptive
@assert !dense
u0 = prob.u0
Expand Down Expand Up @@ -90,27 +92,33 @@ export LoopEuler

!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(prob, alg, ts, us,
sol = DiffEqBase.build_solution(
prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false)
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
DiffEqBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
sol
end

# In-place
# Good for mutable objects like arrays
# Use DiffEqBase.@.. for simd ivdep
@muladd function DiffEqBase.solve(prob::ODEProblem{uType, tType, true},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
@muladd function DiffEqBase.solve(
prob::ODEProblem{uType, tType, true},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...
) where {uType, tType}
@assert !adaptive
@assert !dense
u0 = prob.u0
Expand Down Expand Up @@ -145,11 +153,15 @@ end

!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(prob, alg, ts, us,
sol = DiffEqBase.build_solution(
prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false)
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
DiffEqBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
sol
end
Loading
Loading