@@ -42,6 +42,148 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
42
42
return keys (c. info. varname_to_symbol)
43
43
end
44
44
45
+ """
46
+ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
47
+
48
+ Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
49
+ in `chain`, and return the resulting `Chains`.
50
+
51
+ The `model` passed to `predict` is often different from the one used to generate `chain`.
52
+ Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
53
+ data points), while the model you pass to `predict` may mark these same variables as missing
54
+ or unobserved. Calling `predict` then leverages the previously inferred parameter values to
55
+ simulate what new, unobserved data might look like, given your posterior beliefs.
56
+
57
+ For each parameter configuration in `chain`:
58
+ 1. All random variables present in `chain` are fixed to their sampled values.
59
+ 2. Any variables not included in `chain` are sampled from their prior distributions.
60
+
61
+ If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
62
+ the samples in `chain`. This is useful when you want to sample only new variables from the posterior
63
+ predictive distribution.
64
+
65
+ # Examples
66
+ ```jldoctest
67
+ using AbstractMCMC, Distributions, DynamicPPL, Random
68
+
69
+ @model function linear_reg(x, y, σ = 0.1)
70
+ β ~ Normal(0, 1)
71
+ for i in eachindex(y)
72
+ y[i] ~ Normal(β * x[i], σ)
73
+ end
74
+ end
75
+
76
+ # Generate synthetic chain using known ground truth parameter
77
+ ground_truth_β = 2.0
78
+
79
+ # Create chain of samples from a normal distribution centered on ground truth
80
+ β_chain = MCMCChains.Chains(
81
+ rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
82
+ )
83
+
84
+ # Generate predictions for two test points
85
+ xs_test = [10.1, 10.2]
86
+
87
+ m_train = linear_reg(xs_test, fill(missing, length(xs_test)))
88
+
89
+ predictions = DynamicPPL.AbstractPPL.predict(
90
+ Random.default_rng(), m_train, β_chain
91
+ )
92
+
93
+ ys_pred = vec(mean(Array(predictions); dims=1))
94
+
95
+ # Check if predictions match expected values within tolerance
96
+ (
97
+ isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
98
+ isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
99
+ )
100
+
101
+ # output
102
+
103
+ (true, true)
104
+ ```
105
+ """
106
+ function DynamicPPL. predict (
107
+ rng:: DynamicPPL.Random.AbstractRNG ,
108
+ model:: DynamicPPL.Model ,
109
+ chain:: MCMCChains.Chains ;
110
+ include_all= false ,
111
+ )
112
+ parameter_only_chain = MCMCChains. get_sections (chain, :parameters )
113
+ varinfo = DynamicPPL. VarInfo (model)
114
+
115
+ iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
116
+ predictive_samples = map (iters) do (sample_idx, chain_idx)
117
+ DynamicPPL. setval_and_resample! (varinfo, parameter_only_chain, sample_idx, chain_idx)
118
+ model (rng, varinfo, DynamicPPL. SampleFromPrior ())
119
+
120
+ vals = DynamicPPL. values_as_in_model (model, varinfo)
121
+ varname_vals = mapreduce (
122
+ collect,
123
+ vcat,
124
+ map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals)),
125
+ )
126
+
127
+ return (varname_and_values= varname_vals, logp= DynamicPPL. getlogp (varinfo))
128
+ end
129
+
130
+ chain_result = reduce (
131
+ MCMCChains. chainscat,
132
+ [
133
+ _predictive_samples_to_chains (predictive_samples[:, chain_idx]) for
134
+ chain_idx in 1 : size (predictive_samples, 2 )
135
+ ],
136
+ )
137
+ parameter_names = if include_all
138
+ MCMCChains. names (chain_result, :parameters )
139
+ else
140
+ filter (
141
+ k -> ! (k in MCMCChains. names (parameter_only_chain, :parameters )),
142
+ names (chain_result, :parameters ),
143
+ )
144
+ end
145
+ return chain_result[parameter_names]
146
+ end
147
+
148
+ function _predictive_samples_to_arrays (predictive_samples)
149
+ variable_names_set = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
150
+
151
+ sample_dicts = map (predictive_samples) do sample
152
+ varname_value_pairs = sample. varname_and_values
153
+ varnames = map (first, varname_value_pairs)
154
+ values = map (last, varname_value_pairs)
155
+ for varname in varnames
156
+ push! (variable_names_set, varname)
157
+ end
158
+
159
+ return DynamicPPL. OrderedCollections. OrderedDict (zip (varnames, values))
160
+ end
161
+
162
+ variable_names = collect (variable_names_set)
163
+ variable_values = [
164
+ get (sample_dicts[i], key, missing ) for i in eachindex (sample_dicts),
165
+ key in variable_names
166
+ ]
167
+
168
+ return variable_names, variable_values
169
+ end
170
+
171
+ function _predictive_samples_to_chains (predictive_samples)
172
+ variable_names, variable_values = _predictive_samples_to_arrays (predictive_samples)
173
+ variable_names_symbols = map (Symbol, variable_names)
174
+
175
+ internal_parameters = [:lp ]
176
+ log_probabilities = reshape ([sample. logp for sample in predictive_samples], :, 1 )
177
+
178
+ parameter_names = [variable_names_symbols; internal_parameters]
179
+ parameter_values = hcat (variable_values, log_probabilities)
180
+ parameter_values = MCMCChains. concretize (parameter_values)
181
+
182
+ return MCMCChains. Chains (
183
+ parameter_values, parameter_names, (internals= internal_parameters,)
184
+ )
185
+ end
186
+
45
187
"""
46
188
returned(model::Model, chain::MCMCChains.Chains)
47
189
0 commit comments