Skip to content

Commit 19b556b

Browse files
committed
strict allowed kwargs
1 parent 00485a1 commit 19b556b

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

src/Integrals.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,31 @@ function transformation_if_inf(prob, do_inf_transformation = nothing)
198198
transformation_if_inf(prob, do_inf_transformation)
199199
end
200200

201+
const allowedkeywords = (:maxiters, :abstol, :reltol)
202+
const KWARGERROR_MESSAGE = """
203+
Unrecognized keyword arguments found.
204+
The only allowed keyword arguments to `solve` are:
205+
$allowedkeywords
206+
See https://docs.sciml.ai/Integrals/stable/basics/solve/ for more details.
207+
"""
208+
struct CommonKwargError <: Exception
209+
kwargs::Any
210+
end
211+
function Base.showerror(io::IO, e::CommonKwargError)
212+
println(io, KWARGERROR_MESSAGE)
213+
notin = collect(map(x -> x.first allowedkeywords, e.kwargs))
214+
unrecognized = collect(map(x -> x.first, e.kwargs))[notin]
215+
print(io, "Unrecognized keyword arguments: ")
216+
printstyled(io, unrecognized; bold = true, color = :red)
217+
print(io, "\n\n")
218+
end
219+
function checkkwargs(kwargs...)
220+
if any(x -> x.first allowedkeywords, kwargs)
221+
throw(CommonKwargError(kwargs))
222+
end
223+
return nothing
224+
end
225+
ChainRulesCore.@non_differentiable checkkwargs(kwargs...)
201226
function SciMLBase.solve(prob::IntegralProblem; sensealg = ReCallVJP(ZygoteVJP()),
202227
do_inf_transformation = nothing, kwargs...)
203228
if prob.batch != 0
@@ -241,6 +266,7 @@ function SciMLBase.solve(prob::IntegralProblem,
241266
alg::SciMLBase.AbstractIntegralAlgorithm;
242267
sensealg = ReCallVJP(ZygoteVJP()),
243268
do_inf_transformation = nothing, kwargs...)
269+
checkkwargs(kwargs...)
244270
prob = transformation_if_inf(prob, do_inf_transformation)
245271
__solvebp(prob, alg, sensealg, prob.lb, prob.ub, prob.p; kwargs...)
246272
end

test/interface_tests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,11 @@ end
343343
end
344344
end
345345
end
346+
347+
@testset "Allowed keyword test" begin
348+
f(u, p) = sum(sin.(u))
349+
prob = IntegralProblem(f, ones(3), 3ones(3))
350+
@test_throws "Unrecognized keyword arguments found." solve(prob, HCubatureJL();
351+
relztol = 1e-3,
352+
abstol = 1e-3)
353+
end

0 commit comments

Comments
 (0)