Skip to content

Commit 6657441

Browse files
sunxd3github-actions[bot]penelopeysm
authored
Move predict from Turing (#716)
* move `predict` from Turing * minor fixes * Update test/ext/DynamicPPLMCMCChainsExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix test error by discard burn-in's * add some comments * fix test error * Update test/ext/DynamicPPLMCMCChainsExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * refactor the code; add `predict` in Turing that takes array of varinfos * Update model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * stop using `PredictiveSample` type * use NamedTuple * remove predict with varinfos function * update implementation and tests; no longer using AdvancedHMC * try fixing naming conflict --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Penelope Yong <[email protected]>
1 parent d0cfaaf commit 6657441

8 files changed

+278
-7
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4646
[compat]
4747
ADTypes = "1"
4848
AbstractMCMC = "5"
49-
AbstractPPL = "0.8.4, 0.9"
49+
AbstractPPL = "0.10.1"
5050
Accessors = "0.1"
5151
BangBang = "0.4.1"
5252
Bijectors = "0.13.18, 0.14, 0.15"

ext/DynamicPPLMCMCChainsExt.jl

+142
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,148 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242
return keys(c.info.varname_to_symbol)
4343
end
4444

45+
"""
46+
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
47+
48+
Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
49+
in `chain`, and return the resulting `Chains`.
50+
51+
The `model` passed to `predict` is often different from the one used to generate `chain`.
52+
Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
53+
data points), while the model you pass to `predict` may mark these same variables as missing
54+
or unobserved. Calling `predict` then leverages the previously inferred parameter values to
55+
simulate what new, unobserved data might look like, given your posterior beliefs.
56+
57+
For each parameter configuration in `chain`:
58+
1. All random variables present in `chain` are fixed to their sampled values.
59+
2. Any variables not included in `chain` are sampled from their prior distributions.
60+
61+
If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
62+
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
63+
predictive distribution.
64+
65+
# Examples
66+
```jldoctest
67+
using AbstractMCMC, Distributions, DynamicPPL, Random
68+
69+
@model function linear_reg(x, y, σ = 0.1)
70+
β ~ Normal(0, 1)
71+
for i in eachindex(y)
72+
y[i] ~ Normal(β * x[i], σ)
73+
end
74+
end
75+
76+
# Generate synthetic chain using known ground truth parameter
77+
ground_truth_β = 2.0
78+
79+
# Create chain of samples from a normal distribution centered on ground truth
80+
β_chain = MCMCChains.Chains(
81+
rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
82+
)
83+
84+
# Generate predictions for two test points
85+
xs_test = [10.1, 10.2]
86+
87+
m_train = linear_reg(xs_test, fill(missing, length(xs_test)))
88+
89+
predictions = DynamicPPL.AbstractPPL.predict(
90+
Random.default_rng(), m_train, β_chain
91+
)
92+
93+
ys_pred = vec(mean(Array(predictions); dims=1))
94+
95+
# Check if predictions match expected values within tolerance
96+
(
97+
isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
98+
isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
99+
)
100+
101+
# output
102+
103+
(true, true)
104+
```
105+
"""
106+
function DynamicPPL.predict(
107+
rng::DynamicPPL.Random.AbstractRNG,
108+
model::DynamicPPL.Model,
109+
chain::MCMCChains.Chains;
110+
include_all=false,
111+
)
112+
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
113+
varinfo = DynamicPPL.VarInfo(model)
114+
115+
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116+
predictive_samples = map(iters) do (sample_idx, chain_idx)
117+
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118+
model(rng, varinfo, DynamicPPL.SampleFromPrior())
119+
120+
vals = DynamicPPL.values_as_in_model(model, varinfo)
121+
varname_vals = mapreduce(
122+
collect,
123+
vcat,
124+
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
125+
)
126+
127+
return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
128+
end
129+
130+
chain_result = reduce(
131+
MCMCChains.chainscat,
132+
[
133+
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
134+
chain_idx in 1:size(predictive_samples, 2)
135+
],
136+
)
137+
parameter_names = if include_all
138+
MCMCChains.names(chain_result, :parameters)
139+
else
140+
filter(
141+
k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
142+
names(chain_result, :parameters),
143+
)
144+
end
145+
return chain_result[parameter_names]
146+
end
147+
148+
function _predictive_samples_to_arrays(predictive_samples)
149+
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
150+
151+
sample_dicts = map(predictive_samples) do sample
152+
varname_value_pairs = sample.varname_and_values
153+
varnames = map(first, varname_value_pairs)
154+
values = map(last, varname_value_pairs)
155+
for varname in varnames
156+
push!(variable_names_set, varname)
157+
end
158+
159+
return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
160+
end
161+
162+
variable_names = collect(variable_names_set)
163+
variable_values = [
164+
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
165+
key in variable_names
166+
]
167+
168+
return variable_names, variable_values
169+
end
170+
171+
function _predictive_samples_to_chains(predictive_samples)
172+
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
173+
variable_names_symbols = map(Symbol, variable_names)
174+
175+
internal_parameters = [:lp]
176+
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
177+
178+
parameter_names = [variable_names_symbols; internal_parameters]
179+
parameter_values = hcat(variable_values, log_probabilities)
180+
parameter_values = MCMCChains.concretize(parameter_values)
181+
182+
return MCMCChains.Chains(
183+
parameter_values, parameter_names, (internals=internal_parameters,)
184+
)
185+
end
186+
45187
"""
46188
returned(model::Model, chain::MCMCChains.Chains)
47189

src/DynamicPPL.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using AbstractPPL
55
using Bijectors
66
using Compat
77
using Distributions
8-
using OrderedCollections: OrderedDict
8+
using OrderedCollections: OrderedCollections, OrderedDict
99

1010
using AbstractMCMC: AbstractMCMC
1111
using ADTypes: ADTypes
@@ -40,6 +40,8 @@ import Base:
4040
keys,
4141
haskey
4242

43+
import AbstractPPL: predict
44+
4345
# VarInfo
4446
export AbstractVarInfo,
4547
VarInfo,

src/model.jl

+20
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,26 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
11441144
end
11451145
end
11461146

1147+
"""
1148+
predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
1149+
1150+
Generate samples from the posterior predictive distribution by evaluating `model` at each set
1151+
of parameter values provided in `chain`. The number of posterior predictive samples matches
1152+
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
1153+
and the predicted values.
1154+
"""
1155+
function predict(
1156+
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
1157+
)
1158+
varinfo = DynamicPPL.VarInfo(model)
1159+
return map(chain) do params_varinfo
1160+
vi = deepcopy(varinfo)
1161+
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1162+
model(rng, vi, SampleFromPrior())
1163+
return vi
1164+
end
1165+
end
1166+
11471167
"""
11481168
returned(model::Model, parameters::NamedTuple)
11491169
returned(model::Model, values, keys)

test/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3333
[compat]
3434
ADTypes = "1"
3535
AbstractMCMC = "5"
36-
AbstractPPL = "0.8.4, 0.9"
36+
AbstractPPL = "0.10.1"
3737
Accessors = "0.1"
3838
Bijectors = "0.15.1"
3939
Combinatorics = "1"

test/contexts.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -202,27 +202,27 @@ end
202202
s, m = retval.s, retval.m
203203

204204
# Keword approach.
205-
model_fixed = fix(model; s=s)
205+
model_fixed = DynamicPPL.fix(model; s=s)
206206
@test model_fixed().s == s
207207
@test model_fixed().m != m
208208
# A fixed variable should not contribute at all to the logjoint.
209209
# Assuming `condition` is correctly implemented, the following should hold.
210210
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
211211

212212
# Positional approach.
213-
model_fixed = fix(model, (; s))
213+
model_fixed = DynamicPPL.fix(model, (; s))
214214
@test model_fixed().s == s
215215
@test model_fixed().m != m
216216
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
217217

218218
# Pairs approach.
219-
model_fixed = fix(model, @varname(s) => s)
219+
model_fixed = DynamicPPL.fix(model, @varname(s) => s)
220220
@test model_fixed().s == s
221221
@test model_fixed().m != m
222222
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
223223

224224
# Dictionary approach.
225-
model_fixed = fix(model, Dict(@varname(s) => s))
225+
model_fixed = DynamicPPL.fix(model, Dict(@varname(s) => s))
226226
@test model_fixed().s == s
227227
@test model_fixed().m != m
228228
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))

test/ext/DynamicPPLMCMCChainsExt.jl

+2
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@
77
@test size(chain_generated) == (1000, 1)
88
@test mean(chain_generated) 0 atol = 0.1
99
end
10+
11+
# test for `predict` is in `test/model.jl`

test/model.jl

+105
Original file line numberDiff line numberDiff line change
@@ -429,4 +429,109 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
429429
@test getlogp(varinfo_linked) getlogp(varinfo_linked_result)
430430
end
431431
end
432+
433+
@testset "predict" begin
434+
@testset "with MCMCChains.Chains" begin
435+
DynamicPPL.Random.seed!(100)
436+
437+
@model function linear_reg(x, y, σ=0.1)
438+
β ~ Normal(0, 1)
439+
for i in eachindex(y)
440+
y[i] ~ Normal* x[i], σ)
441+
end
442+
end
443+
444+
@model function linear_reg_vec(x, y, σ=0.1)
445+
β ~ Normal(0, 1)
446+
return y ~ MvNormal.* x, σ^2 * I)
447+
end
448+
449+
ground_truth_β = 2
450+
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])
451+
452+
xs_test = [10 + 0.1, 10 + 2 * 0.1]
453+
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
454+
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)
455+
456+
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
457+
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
458+
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
459+
460+
# Ensure that `rng` is respected
461+
rng = MersenneTwister(42)
462+
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
463+
predictions2 = DynamicPPL.predict(
464+
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
465+
)
466+
@test all(Array(predictions1) .== Array(predictions2))
467+
468+
# Predict on two last indices for vectorized
469+
m_lin_reg_test = linear_reg_vec(xs_test, missing)
470+
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
471+
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
472+
473+
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
474+
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
475+
476+
# Multiple chains
477+
multiple_β_chain = MCMCChains.Chains(
478+
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
479+
)
480+
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
481+
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
482+
@test size(multiple_β_chain, 3) == size(predictions, 3)
483+
484+
for chain_idx in MCMCChains.chains(multiple_β_chain)
485+
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
486+
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
487+
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
488+
end
489+
490+
# Predict on two last indices for vectorized
491+
m_lin_reg_test = linear_reg_vec(xs_test, missing)
492+
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
493+
494+
for chain_idx in MCMCChains.chains(multiple_β_chain)
495+
ys_pred_vec = vec(
496+
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
497+
)
498+
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
499+
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
500+
end
501+
end
502+
503+
@testset "with AbstractVector{<:AbstractVarInfo}" begin
504+
@model function linear_reg(x, y, σ=0.1)
505+
β ~ Normal(1, 1)
506+
for i in eachindex(y)
507+
y[i] ~ Normal* x[i], σ)
508+
end
509+
end
510+
511+
ground_truth_β = 2.0
512+
# the data will be ignored, as we are generating samples from the prior
513+
xs_train = 1:0.1:10
514+
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
515+
m_lin_reg = linear_reg(xs_train, ys_train)
516+
chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000]
517+
518+
# chain is generated from the prior
519+
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) 1.0 atol = 0.1
520+
521+
xs_test = [10 + 0.1, 10 + 2 * 0.1]
522+
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
523+
predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain)
524+
525+
@test size(predicted_vis) == size(chain)
526+
@test Set(keys(predicted_vis[1])) ==
527+
Set([@varname(β), @varname(y[1]), @varname(y[2])])
528+
# because β samples are from the prior, the std will be larger
529+
@test mean([
530+
predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis)
531+
]) 1.0 * xs_test[1] rtol = 0.1
532+
@test mean([
533+
predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis)
534+
]) 1.0 * xs_test[2] rtol = 0.1
535+
end
536+
end
432537
end

0 commit comments

Comments
 (0)