@@ -161,26 +161,59 @@ end
161161 @test mean (c[:x ]) > 0.7
162162 end
163163
164- # https://github.com/TuringLang/Turing.jl/issues/2007
165164 @testset " keyword argument handling" begin
166165 @model function kwarg_demo (y; n= 0.0 )
167166 x ~ Normal (n)
168167 return y ~ Normal (x)
169168 end
170- @test_throws " Models with keyword arguments" sample (kwarg_demo (5.0 ), PG (20 ), 10 )
171169
172- # Check that enabling `might_produce` does allow sampling
173- @might_produce kwarg_demo
174170 chain = sample (StableRNG (468 ), kwarg_demo (5.0 ), PG (20 ), 1000 )
175171 @test chain isa MCMCChains. Chains
176172 @test mean (chain[:x ]) ≈ 2.5 atol = 0.2
177173
178- # Check that the keyword argument's value is respected
179174 chain2 = sample (StableRNG (468 ), kwarg_demo (5.0 ; n= 10.0 ), PG (20 ), 1000 )
180175 @test chain2 isa MCMCChains. Chains
181176 @test mean (chain2[:x ]) ≈ 7.5 atol = 0.2
182177 end
183178
179+ @testset " submodels without kwargs" begin
180+ @model function inner (y, x)
181+ # Mark as noinline explicitly to make sure that behaviour is not reliant on the
182+ # Julia compiler inlining it.
183+ # See https://github.com/TuringLang/Turing.jl/issues/2772
184+ @noinline
185+ return y ~ Normal (x)
186+ end
187+ @model function nested (y)
188+ x ~ Normal ()
189+ return a ~ to_submodel (inner (y, x))
190+ end
191+ m1 = nested (1.0 )
192+ chn = sample (StableRNG (468 ), m1, PG (10 ), 1000 )
193+ @test mean (chn[:x ]) ≈ 0.5 atol = 0.1
194+ end
195+
196+ @testset " submodels with kwargs" begin
197+ @model function inner_kwarg (y; n= 0.0 )
198+ @noinline # See above
199+ x ~ Normal (n)
200+ return y ~ Normal (x)
201+ end
202+ @model function outer_kwarg1 ()
203+ return a ~ to_submodel (inner_kwarg (5.0 ))
204+ end
205+ m1 = outer_kwarg1 ()
206+ chn1 = sample (StableRNG (468 ), m1, PG (10 ), 1000 )
207+ @test mean (chn1[Symbol (" a.x" )]) ≈ 2.5 atol = 0.2
208+
209+ @model function outer_kwarg2 (n)
210+ return a ~ to_submodel (inner_kwarg (5.0 ; n= n))
211+ end
212+ m2 = outer_kwarg2 (10.0 )
213+ chn2 = sample (StableRNG (468 ), m2, PG (10 ), 1000 )
214+ @test mean (chn2[Symbol (" a.x" )]) ≈ 7.5 atol = 0.2
215+ end
216+
184217 @testset " refuses to run threadsafe eval" begin
185218 # PG can't run models that have nondeterministic evaluation order,
186219 # so it should refuse to run models marked as threadsafe.
0 commit comments