Skip to content

Commit 3a21368

Browse files
committed
Add a variety of tests
1 parent b3e1702 commit 3a21368

1 file changed

Lines changed: 38 additions & 5 deletions

File tree

test/mcmc/particle_mcmc.jl

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,26 +161,59 @@ end
161161
@test mean(c[:x]) > 0.7
162162
end
163163

164-
# https://github.com/TuringLang/Turing.jl/issues/2007
165164
@testset "keyword argument handling" begin
166165
@model function kwarg_demo(y; n=0.0)
167166
x ~ Normal(n)
168167
return y ~ Normal(x)
169168
end
170-
@test_throws "Models with keyword arguments" sample(kwarg_demo(5.0), PG(20), 10)
171169

172-
# Check that enabling `might_produce` does allow sampling
173-
@might_produce kwarg_demo
174170
chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000)
175171
@test chain isa MCMCChains.Chains
176172
@test mean(chain[:x]) 2.5 atol = 0.2
177173

178-
# Check that the keyword argument's value is respected
179174
chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000)
180175
@test chain2 isa MCMCChains.Chains
181176
@test mean(chain2[:x]) 7.5 atol = 0.2
182177
end
183178

179+
@testset "submodels without kwargs" begin
180+
@model function inner(y, x)
181+
# Mark as noinline explicitly to make sure that behaviour is not reliant on the
182+
# Julia compiler inlining it.
183+
# See https://github.com/TuringLang/Turing.jl/issues/2772
184+
@noinline
185+
return y ~ Normal(x)
186+
end
187+
@model function nested(y)
188+
x ~ Normal()
189+
return a ~ to_submodel(inner(y, x))
190+
end
191+
m1 = nested(1.0)
192+
chn = sample(StableRNG(468), m1, PG(10), 1000)
193+
@test mean(chn[:x]) 0.5 atol = 0.1
194+
end
195+
196+
@testset "submodels with kwargs" begin
197+
@model function inner_kwarg(y; n=0.0)
198+
@noinline # See above
199+
x ~ Normal(n)
200+
return y ~ Normal(x)
201+
end
202+
@model function outer_kwarg1()
203+
return a ~ to_submodel(inner_kwarg(5.0))
204+
end
205+
m1 = outer_kwarg1()
206+
chn1 = sample(StableRNG(468), m1, PG(10), 1000)
207+
@test mean(chn1[Symbol("a.x")]) 2.5 atol = 0.2
208+
209+
@model function outer_kwarg2(n)
210+
return a ~ to_submodel(inner_kwarg(5.0; n=n))
211+
end
212+
m2 = outer_kwarg2(10.0)
213+
chn2 = sample(StableRNG(468), m2, PG(10), 1000)
214+
@test mean(chn2[Symbol("a.x")]) 7.5 atol = 0.2
215+
end
216+
184217
@testset "refuses to run threadsafe eval" begin
185218
# PG can't run models that have nondeterministic evaluation order,
186219
# so it should refuse to run models marked as threadsafe.

0 commit comments

Comments
 (0)