@@ -7,6 +7,8 @@ import ..ADUtils
77using Bijectors: Bijectors
88using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
99using DynamicPPL: DynamicPPL, Sampler
10+ using DynamicPPL. TestUtils. AD: run_ad
11+ using DynamicPPL. TestUtils: DEMO_MODELS
1012import ForwardDiff
1113using HypothesisTests: ApproximateTwoSampleKSTest, pvalue
1214import ReverseDiff
@@ -18,9 +20,41 @@ import Mooncake
1820using Test: @test , @test_logs , @testset , @test_throws
1921using Turing
2022
21- @testset " Testing hmc.jl with $adbackend " for adbackend in ADUtils. adbackends
22- @info " Starting HMC tests with $adbackend "
23+ @testset " AD / hmc.jl" begin
24+ # AD tests need to be run with SamplingContext because samplers can potentially
25+ # use this to define custom behaviour in the tilde-pipeline and thus change the
26+ # code executed during model evaluation.
27+ @testset " adtype=$adtype " for adtype in ADUtils. adbackends
28+ @testset " alg=$alg " for alg in [
29+ HMC (0.1 , 10 ; adtype= adtype),
30+ HMCDA (0.8 , 0.75 ; adtype= adtype),
31+ NUTS (1000 , 0.8 ; adtype= adtype),
32+ ]
33+ @info " Testing AD for $alg "
34+
35+ @testset " model=$(model. f) " for model in DEMO_MODELS
36+ rng = StableRNG (123 )
37+ ctx = DynamicPPL. SamplingContext (rng, DynamicPPL. Sampler (alg))
38+ @test run_ad (model, adtype; context= ctx, test= true , benchmark= false ) isa Any
39+ end
40+ end
41+
42+ @testset " Check ADType" begin
43+ seed = 123
44+ alg = HMC (0.1 , 10 ; adtype= adbackend)
45+ m = DynamicPPL. contextualize (
46+ gdemo_default, ADTypeCheckContext (adbackend, gdemo_default. context)
47+ )
48+ # These will error if the adbackend being used is not the one set.
49+ sample (StableRNG (seed), m, alg, 10 )
50+ end
51+ end
52+ end
53+
54+ @testset " Testing hmc.jl" begin
55+ @info " Starting HMC tests"
2356 seed = 123
57+ adbackend = Turing. DEFAULT_ADTYPE
2458
2559 @testset " constrained bounded" begin
2660 obs = [0 , 1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
@@ -65,12 +99,6 @@ using Turing
6599 check_numerical (chain, [" ps[1]" , " ps[2]" ], [5 / 16 , 11 / 16 ]; atol= 0.015 )
66100 end
67101
68- @testset " hmc reverse diff" begin
69- alg = HMC (0.1 , 10 ; adtype= adbackend)
70- res = sample (StableRNG (seed), gdemo_default, alg, 4_000 )
71- check_gdemo (res; rtol= 0.1 )
72- end
73-
74102 # Test the sampling of a matrix-value distribution.
75103 @testset " matrix support" begin
76104 dist = Wishart (7 , [1 0.5 ; 0.5 1 ])
@@ -211,20 +239,20 @@ using Turing
211239 end
212240
213241 @testset " prior" begin
242+ # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
243+ # which means that it's _very_ difficult to find a good tolerance in the test below:)
244+ prior_dist = truncated (Normal (3 , 1 ); lower= 0 )
245+
214246 @model function demo_hmc_prior ()
215- # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
216- # which means that it's _very_ difficult to find a good tolerance in the test below:)
217- s ~ truncated (Normal (3 , 1 ); lower= 0 )
247+ s ~ prior_dist
218248 return m ~ Normal (0 , sqrt (s))
219249 end
220250 alg = NUTS (1000 , 0.8 ; adtype= adbackend)
221251 gdemo_default_prior = DynamicPPL. contextualize (
222252 demo_hmc_prior (), DynamicPPL. PriorContext ()
223253 )
224254 chain = sample (gdemo_default_prior, alg, 5_000 ; initial_params= [3.0 , 0.0 ])
225- check_numerical (
226- chain, [:s , :m ], [mean (truncated (Normal (3 , 1 ); lower= 0 )), 0 ]; atol= 0.2
227- )
255+ check_numerical (chain, [:s , :m ], [mean (prior_dist), 0 ]; atol= 0.2 )
228256 end
229257
230258 @testset " warning for difficult init params" begin
@@ -292,8 +320,8 @@ using Turing
292320
293321 # Extract the `x` like this because running `generated_quantities` was how
294322 # the issue was discovered, hence we also want to make sure that it works.
295- results = generated_quantities (model, chain)
296- results_prior = generated_quantities (model, chain_prior)
323+ results = returned (model, chain)
324+ results_prior = returned (model, chain_prior)
297325
298326 # Make sure none of the samples in the chains resulted in errors.
299327 @test all (! isnothing, results)
@@ -315,15 +343,6 @@ using Turing
315343 @test Turing. Inference. getstepsize (spl, hmc_state) isa Float64
316344 end
317345 end
318-
319- @testset " Check ADType" begin
320- alg = HMC (0.1 , 10 ; adtype= adbackend)
321- m = DynamicPPL. contextualize (
322- gdemo_default, ADTypeCheckContext (adbackend, gdemo_default. context)
323- )
324- # These will error if the adbackend being used is not the one set.
325- sample (StableRNG (seed), m, alg, 10 )
326- end
327346end
328347
329348end
0 commit comments