@@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
383383 example_values = DynamicPPL. TestUtils. rand_prior_true(model)
384384 varinfos = DynamicPPL. TestUtils. setup_varinfos(model, example_values, vns)
385385 @testset " $(short_varinfo_name(varinfo)) " for varinfo in varinfos
386- realizations = values_as_in_model(model, varinfo)
386+ # We can set the include_colon_eq arg to false because none of
387+ # the demo models contain :=. The behaviour when
388+ # include_colon_eq is true is tested in test/compiler.jl
389+ realizations = values_as_in_model(model, false , varinfo)
387390 # Ensure that all variables are found.
388391 vns_found = collect(keys(realizations))
389392 @test vns ∩ vns_found == vns ∪ vns_found
@@ -393,6 +396,22 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
393396 end
394397 end
395398 end
399+
400+ @testset " check that sampling obeys rng if passed" begin
401+ @model function f()
402+ x ~ Normal(0 )
403+ return y ~ Normal(x)
404+ end
405+ model = f()
406+ # Call values_as_in_model with the rng
407+ values = values_as_in_model(Random. Xoshiro(43 ), model, false )
408+ # Check that they match the values that would be used if vi was seeded
409+ # with that seed instead
410+ expected_vi = VarInfo(Random. Xoshiro(43 ), model)
411+ for vn in keys(values)
412+ @test values[vn] == expected_vi[vn]
413+ end
414+ end
396415 end
397416
398417 @testset " Erroneous model call" begin
@@ -432,72 +451,87 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
432451
433452 @testset " predict" begin
434453 @testset " with MCMCChains.Chains" begin
435- DynamicPPL. Random. seed!(100 )
436-
437454 @model function linear_reg(x, y, σ= 0.1 )
438455 β ~ Normal(0 , 1 )
439456 for i in eachindex(y)
440457 y[i] ~ Normal(β * x[i], σ)
441458 end
459+ # Insert a := block to test that it is not included in predictions
460+ return σ2 := σ^ 2
442461 end
443462
444- @model function linear_reg_vec(x, y, σ= 0.1 )
445- β ~ Normal(0 , 1 )
446- return y ~ MvNormal(β .* x, σ^ 2 * I)
447- end
448-
463+ # Construct a chain with 'sampled values' of β
449464 ground_truth_β = 2
450465 β_chain = MCMCChains. Chains(rand(Normal(ground_truth_β, 0.002 ), 1000 ), [:β])
451466
467+ # Generate predictions from that chain
452468 xs_test = [10 + 0.1 , 10 + 2 * 0.1 ]
453469 m_lin_reg_test = linear_reg(xs_test, fill(missing , length(xs_test)))
454470 predictions = DynamicPPL. predict(m_lin_reg_test, β_chain)
455471
456- ys_pred = vec(mean(Array(group(predictions, :y)); dims= 1 ))
457- @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
458- @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
459-
460- # Ensure that `rng` is respected
461- rng = MersenneTwister(42 )
462- predictions1 = DynamicPPL. predict(rng, m_lin_reg_test, β_chain[1 : 2 ])
463- predictions2 = DynamicPPL. predict(
464- MersenneTwister(42 ), m_lin_reg_test, β_chain[1 : 2 ]
465- )
466- @test all(Array(predictions1) .== Array(predictions2))
467-
468- # Predict on two last indices for vectorized
469- m_lin_reg_test = linear_reg_vec(xs_test, missing )
470- predictions_vec = DynamicPPL. predict(m_lin_reg_test, β_chain)
471- ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims= 1 ))
472-
473- @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
474- @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
472+ # Also test a vectorized model
473+ @model function linear_reg_vec(x, y, σ= 0.1 )
474+ β ~ Normal(0 , 1 )
475+ return y ~ MvNormal(β .* x, σ^ 2 * I)
476+ end
477+ m_lin_reg_test_vec = linear_reg_vec(xs_test, missing )
475478
476- # Multiple chains
477- multiple_β_chain = MCMCChains. Chains(
478- reshape(rand(Normal(ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β]
479- )
480- m_lin_reg_test = linear_reg(xs_test, fill(missing , length(xs_test)))
481- predictions = DynamicPPL. predict(m_lin_reg_test, multiple_β_chain)
482- @test size(multiple_β_chain, 3 ) == size(predictions, 3 )
479+ @testset " variables in chain" begin
480+ # Note that this also checks that variables on the lhs of :=,
481+ # such as σ2, are not included in the resulting chain
482+ @test Set(keys(predictions)) == Set([Symbol(" y[1]" ), Symbol(" y[2]" )])
483+ end
483484
484- for chain_idx in MCMCChains . chains(multiple_β_chain)
485- ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx] , :y)); dims= 1 ))
485+ @testset " accuracy " begin
486+ ys_pred = vec(mean(Array(group(predictions, :y)); dims= 1 ))
486487 @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
487488 @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
488489 end
489490
490- # Predict on two last indices for vectorized
491- m_lin_reg_test = linear_reg_vec(xs_test, missing )
492- predictions_vec = DynamicPPL. predict(m_lin_reg_test, multiple_β_chain)
493-
494- for chain_idx in MCMCChains. chains(multiple_β_chain)
495- ys_pred_vec = vec(
496- mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims= 1 )
491+ @testset " ensure that rng is respected" begin
492+ rng = MersenneTwister(42 )
493+ predictions1 = DynamicPPL. predict(rng, m_lin_reg_test, β_chain[1 : 2 ])
494+ predictions2 = DynamicPPL. predict(
495+ MersenneTwister(42 ), m_lin_reg_test, β_chain[1 : 2 ]
497496 )
497+ @test all(Array(predictions1) .== Array(predictions2))
498+ end
499+
500+ @testset " accuracy on vectorized model" begin
501+ predictions_vec = DynamicPPL. predict(m_lin_reg_test_vec, β_chain)
502+ ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims= 1 ))
503+
498504 @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
499505 @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
500506 end
507+
508+ @testset " prediction from multiple chains" begin
509+ # Normal linreg model
510+ multiple_β_chain = MCMCChains. Chains(
511+ reshape(rand(Normal(ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β]
512+ )
513+ predictions = DynamicPPL. predict(m_lin_reg_test, multiple_β_chain)
514+ @test size(multiple_β_chain, 3 ) == size(predictions, 3 )
515+
516+ for chain_idx in MCMCChains. chains(multiple_β_chain)
517+ ys_pred = vec(
518+ mean(Array(group(predictions[:, :, chain_idx], :y)); dims= 1 )
519+ )
520+ @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
521+ @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
522+ end
523+
524+ # Vectorized linreg model
525+ predictions_vec = DynamicPPL. predict(m_lin_reg_test_vec, multiple_β_chain)
526+
527+ for chain_idx in MCMCChains. chains(multiple_β_chain)
528+ ys_pred_vec = vec(
529+ mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims= 1 )
530+ )
531+ @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
532+ @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
533+ end
534+ end
501535 end
502536
503537 @testset " with AbstractVector{<:AbstractVarInfo}" begin
0 commit comments