Skip to content

Commit ad848b2

Browse files
authored
Stricter types for evaluate!! methods (#629)
1 parent 3b3840d commit ad848b2

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/model.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -909,13 +909,18 @@ function AbstractPPL.evaluate!!(model::Model, context::AbstractContext)
909909
return evaluate!!(model, VarInfo(), context)
910910
end
911911

912-
function AbstractPPL.evaluate!!(model::Model, args...)
912+
function AbstractPPL.evaluate!!(
913+
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
914+
)
913915
return evaluate!!(model, Random.default_rng(), args...)
914916
end
915917

916918
# without VarInfo
917919
function AbstractPPL.evaluate!!(
918-
model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args...
920+
model::Model,
921+
rng::Random.AbstractRNG,
922+
sampler::AbstractSampler,
923+
args::AbstractContext...,
919924
)
920925
return evaluate!!(model, rng, VarInfo(), sampler, args...)
921926
end

test/model.jl

+14
Original file line numberDiff line numberDiff line change
@@ -396,4 +396,18 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
396396
end
397397
end
398398
end
399+
400+
@testset "Erroneous model call" begin
401+
# Calling a model with the wrong arguments used to lead to infinite recursion, see
402+
# https://github.com/TuringLang/Turing.jl/issues/2182. This guards against it.
403+
@model function a_model(x)
404+
m ~ Normal(0, 1)
405+
x ~ Normal(m, 1)
406+
return nothing
407+
end
408+
instance = a_model(1.0)
409+
# `instance` should be called with rng, context, etc., but one may easily get
410+
# confused and call it the way you are meant to call `a_model`.
411+
@test_throws MethodError instance(1.0)
412+
end
399413
end

0 commit comments

Comments
 (0)