Description
Hi everyone,
Problem
There seems to be a problem with using predict in conjunction with models that use vectorisation.
Consider this fairly simple example.
We can generate a dataset by sampling from this model for (say) 10 observations, infer the parameters, and then forecast observations $11,\dots,20$ (the only information propagated forward is about the variance of the noise).
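For reference, here is the model as it can be read off the code below. This is a reconstruction rather than an authoritative statement: the index $i$ is notation introduced here, and it assumes that `MvNormal(μ, σ)` with a scalar `σ` is treated as an isotropic Gaussian with standard deviation `σ`, and that `MvNormal(n, 1.0)` is a zero-mean, unit-variance Gaussian of length `n`.

```latex
\begin{aligned}
\sigma &\sim \mathcal{N}^{+}(0,\ 0.1^2) \\
\mu_i &\sim \mathcal{N}(0,\ 1), \qquad i = 1,\dots,n \\
x_i \mid \mu_i, \sigma &\sim \mathcal{N}(\mu_i,\ \sigma^2), \qquad i = 1,\dots,n
\end{aligned}
```

Under this reading, the $\mu_i$ for observations 11 to 20 are not informed by the data, so a forecast should draw them fresh from the prior for every posterior sample.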
However, this fails as per below:
using Turing, StatsPlots, DynamicPPL, Random

Random.seed!(1234)

@model function mv_normal(n)
    σ ~ truncated(Normal(0., 0.1), lower = 0.)
    μ ~ MvNormal(n, 1.0) # Means
    x ~ MvNormal(μ, σ) # noise
    return x
end

mdl_10 = mv_normal(10)
# Sample data
x_data = mdl_10()
# infer means and obs noise
chn = sample(mdl_10 | (x = x_data,), NUTS(), 2_000)
# forecast
forecast_mdl = mv_normal(20)
forecast_chn = predict(forecast_mdl, chn; include_all = true)

let
    obs = generated_quantities(forecast_mdl, forecast_chn) |> X -> reduce(hcat, X)
    plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
    scatter!(plt, x_data, c = :red, lab = "observed", title = "BAD FORECAST", ms = 6)
end
The failure mode here seems to be that the same underlying random variables are reused for every sample in chn.
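As a purely illustrative analogy (this is not Turing's code, and whether predict behaves this way internally is an assumption): if a length-20 vector is drawn once and only its first 10 entries are overwritten for each posterior sample, then entries 11 to 20 come out identical across all samples. The function names `partial_overwrite` and `redraw_each_time` are hypothetical and exist only for this sketch.

```julia
using Random

# Toy analogy only, NOT Turing internals; both function names are hypothetical.
# `draws` (n_old × n_samples) stands in for posterior samples of the first n_old entries.

# Buffer is drawn once; only its first n_old entries are refreshed per sample.
function partial_overwrite(rng::AbstractRNG, draws::AbstractMatrix, n_new::Integer)
    n_old, n_samples = size(draws)
    buffer = randn(rng, n_new)
    out = Matrix{Float64}(undef, n_new, n_samples)
    for s in 1:n_samples
        buffer[1:n_old] .= draws[:, s]   # entries n_old+1:n_new are never redrawn
        out[:, s] .= buffer
    end
    return out
end

# Buffer is redrawn for every sample before the known entries are filled in.
function redraw_each_time(rng::AbstractRNG, draws::AbstractMatrix, n_new::Integer)
    n_old, n_samples = size(draws)
    out = Matrix{Float64}(undef, n_new, n_samples)
    for s in 1:n_samples
        buffer = randn(rng, n_new)       # fresh draw per sample
        buffer[1:n_old] .= draws[:, s]
        out[:, s] .= buffer
    end
    return out
end

rng = MersenneTwister(1234)
draws = randn(rng, 10, 200)
stuck = partial_overwrite(rng, draws, 20)
fresh = redraw_each_time(rng, draws, 20)

# Range (max - min) across samples for each of the new entries 11:20
spread(m) = [maximum(row) - minimum(row) for row in eachrow(m[11:20, :])]
stuck_spread = spread(stuck)
fresh_spread = spread(fresh)
```

In this toy, `stuck_spread` is zero for every new entry while `fresh_spread` is strictly positive. If something similar is happening inside predict, that would be consistent with one predict call per sample behaving differently from a single call over the whole chain.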
Fix 1: mapreduce across forecast calls
So if you instead loop over the samples and run the forecast for each sample separately, this seems to work:
forecast_chn_mapreduce = mapreduce(vcat, 1:size(chn, 1)) do i
    c = predict(forecast_mdl, chn[i,:,1]; include_all = true)
    # Take care to set the range sequentially
    setrange(c, i:i)
end

let
    obs = generated_quantities(forecast_mdl, forecast_chn_mapreduce) |> X -> reduce(hcat, X)
    plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
    scatter!(plt, x_data, c = :black, lab = "observed", title = "OK FORECAST")
end
Fix 2: Non-vectorised sampling
Or you can modify the underlying model to not use calls to vectorised random variables (although IMO this is non-ideal).
@model function mv_normal_2(n)
    σ ~ truncated(Normal(0., 0.1), lower = 0.)
    μ = Vector{eltype(σ)}(undef, n)
    for i = 1:n
        μ[i] ~ Normal()
    end
    x ~ MvNormal(μ, σ) # noise
    return x
end

mdl2_10 = mv_normal_2(10)
x_data2 = mdl2_10()
chn2 = sample(mdl2_10 | (x = x_data2,), NUTS(), 2_000)
forecast_mdl2 = mv_normal_2(20)
forecast_chn2 = predict(forecast_mdl2, chn2; include_all = true)

let
    obs = generated_quantities(forecast_mdl2, forecast_chn2) |> X -> reduce(hcat, X)
    plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
    scatter!(plt, x_data2, c = :black, lab = "observed", title = "ALSO OK FORECAST?")
end
Ideal situation
Obviously, it would be ideal if predict "just worked" with vectorised random variables. Given the failure mode of naive usage of predict, I'm assuming that this is a problem with how the random numbers are generated around here?