diff --git a/src/sde_default_alg.jl b/src/sde_default_alg.jl index d084cca8f..d789e5a94 100644 --- a/src/sde_default_alg.jl +++ b/src/sde_default_alg.jl @@ -9,12 +9,14 @@ function default_algorithm(prob::DiffEqBase.AbstractSDEProblem{uType,tType,isinp alg = RKMilCommute() end - if :stiff ∈ alg_hints + is_stiff = :stiff ∈ alg_hints + is_stratonovich = :stratonovich ∈ alg_hints + if is_stiff || prob.f.mass_matrix !== I alg = ImplicitRKMil(autodiff=false) end - if :stratonovich ∈ alg_hints - if :stiff ∈ alg_hints + if is_stratonovich + if is_stiff || prob.f.mass_matrix !== I alg = ImplicitRKMil(autodiff=false,interpretation=:stratonovich) else alg = RKMil(interpretation=:stratonovich) @@ -22,14 +24,14 @@ function default_algorithm(prob::DiffEqBase.AbstractSDEProblem{uType,tType,isinp end if prob.noise_rate_prototype != nothing || prob.noise != nothing - if :stratonovich ∈ alg_hints - if :stiff ∈ alg_hints + if is_stratonovich + if is_stiff || prob.f.mass_matrix !== I alg = ImplicitEulerHeun(autodiff=false) else alg = LambaEulerHeun() end else - if :stiff ∈ alg_hints + if is_stiff || prob.f.mass_matrix !== I alg = ISSEM(autodiff=false) else alg = LambaEM() @@ -38,7 +40,7 @@ function default_algorithm(prob::DiffEqBase.AbstractSDEProblem{uType,tType,isinp end if :additive ∈ alg_hints - if :stiff ∈ alg_hints + if is_stiff || prob.f.mass_matrix !== I alg = SKenCarp(autodiff=false) else alg = SOSRA()