Open
Description
I've been having problems sampling models that use truncated normal distributions.
This minimal worked example, taken from the discourse.julialang.org discussion, shows that the initial vector passed in seems to be problematic. The first printout does not show the provided initial value, but rather a value with a very small standard deviation of the errors; it therefore runs into numerical issues immediately and reports zero probability at the initial point.
using Pkg
Pkg.activate(".")
using Turing, DataFrames,DataFramesMeta,LazyArrays,Distributions,DistributionsAD
using LazyArrays, ReverseDiff, Memoization
## Every few hours a random staff member fetches a random patient and
## takes them outside to a garden through a door that has a scale,
## sometimes using a wheelchair and sometimes not. Knowing the total
## weight of the two people and the wheelchair, plus some error (from
## the scale measurement), infer the individual weights of all
## individuals and the weight of the wheelchair.
# Simulation sizes and identifiers for staff and patients.
nstaff = 100
npat = 100
staffids = collect(1:nstaff)
patientids = collect(1:npat)
# True individual weights. They are drawn from the same lower-truncated normal
# (bound at 90) that the truncated models below use as their prior. The true
# weights are later passed to the sampler as its starting point; an untruncated
# Normal(150, 30) puts about 2.3% of its draws below 90, so with 200 individuals
# the starting point would almost surely fall outside the prior's support and
# have zero density.
trueweightdist = truncated(Normal(150,30),90.0,Inf)
staffweights = rand(trueweightdist,length(staffids))
# One weight per patient (sized by the patient ids, not the staff ids).
patientweights = rand(trueweightdist,length(patientids))
wheelchairwt = 15
nobs = 300
# Each observation pairs a random staff member with a random patient and
# records whether the wheelchair was used (0 or 1).
data = DataFrame(staff=rand(staffids,nobs),patient=rand(patientids,nobs))
data.usewch = rand(0:1,nobs)
# Scale reading = staff weight + patient weight (+ wheelchair when used)
# + Normal(0, 20) measurement noise.
data.totweights = staffweights[data.staff] .+ patientweights[data.patient] .+
    data.usewch .* wheelchairwt .+ rand(Normal(0.0,20.0),nrow(data))
# Select ReverseDiff as the AD backend, turn on its cache, and clear any
# previously cached entries.
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
Turing.emptyrdcache()
# Baseline model: untruncated normal priors on every individual weight and a
# multivariate normal likelihood with a fixed scale argument of 20.
#   nstaff, npatients  - number of staff / patient weights to infer
#   staffid, patientid - per-observation indices into the weight vectors
#   usewch             - per-observation wheelchair indicator
#   totweight          - observed total weights
@model function estweights(nstaff,staffid,npatients,patientid,usewch,totweight)
    # Wheelchair weight: Gamma with shape 20 and scale 15/19, i.e. mode 15.
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(Normal(150,30),nstaff)
    patientweights ~ filldist(Normal(150,30),npatients)
    # Expected scale reading per observation: staff + patient (+ wheelchair).
    expected = view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt
    totweight ~ MvNormal(expected,20.0)
end
# Variant with priors truncated below at 90 and a Gamma likelihood.
#   nstaff, npatients  - number of staff / patient weights to infer
#   staffid, patientid - per-observation indices into the weight vectors
#   usewch             - per-observation wheelchair indicator
#   totweight          - observed total weights
@model function estweights2(nstaff,staffid,npatients,patientid,usewch,totweight)
    # Wheelchair weight: Gamma with shape 20 and scale 15/19, i.e. mode 15.
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    # One Gamma per observation with shape 15 and scale (summed weight)/14, so
    # that (shape - 1) * scale, the mode, equals the summed weight.
    # Iterate with eachindex rather than 1:length so the indices always match
    # totweight's own axes.
    totweight ~ arraydist([
        Gamma(15,(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt)/14)
        for i in eachindex(totweight)
    ])
end
# Variant with priors truncated below at 90, an inferred measurement error, and
# a likelihood made of normals truncated below at zero.
#   nstaff, npatients  - number of staff / patient weights to infer
#   staffid, patientid - per-observation indices into the weight vectors
#   usewch             - per-observation wheelchair indicator
#   totweight          - observed total weights
@model function estweights3(nstaff,staffid,npatients,patientid,usewch,totweight)
    # Wheelchair weight: Gamma with shape 20 and scale 15/19, i.e. mode 15.
    wcwt ~ Gamma(20.0,15.0/19)
    # Measurement-error scale: Gamma with shape 10 and scale 20/9, i.e. mode 20.
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    # One truncated normal per observation, located at the summed weight.
    # Iterate with eachindex rather than 1:length so the indices always match
    # totweight's own axes.
    totweight ~ arraydist([
        truncated(Normal(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt, measerr),0.0,Inf)
        for i in eachindex(totweight)
    ])
end
# Return Normal(a, b) truncated to the interval [0, Inf). The return value is
# declared as the abstract type `UnivariateDistribution`.
function truncatenormal(a,b)::UnivariateDistribution
    untruncated = Normal(a,b)
    return truncated(untruncated,0.0,Inf)
end
# Lazy-array version of the truncated-normal model: the likelihood is again a
# set of normals truncated below at zero (via `truncatenormal`), but the
# per-observation means and distributions are wrapped in `LazyArray`s.
#   nstaff, npatients  - number of staff / patient weights to infer
#   staffid, patientid - per-observation indices into the weight vectors
#   usewch             - per-observation wheelchair indicator
#   totweight          - observed total weights
@model function estweights3lazy(nstaff,staffid,npatients,patientid,usewch,totweight)
# Wheelchair weight: Gamma with shape 20 and scale 15/19, i.e. mode 15.
wcwt ~ Gamma(20.0,15.0/19)
# Measurement-error scale: Gamma with shape 10 and scale 20/9, i.e. mode 20.
measerr ~ Gamma(10.0,20.0/9)
# Individual weights: Normal(150, 30) truncated below at 90.
staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
# Per-observation mean: staff + patient (+ wheelchair when used), built with
# LazyArrays' `@~` rather than as an ordinary broadcast.
theta = LazyArray(@~ view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt)
# Debug output: print the current parameter values (first ten weights of each
# group) whenever the model body is executed.
println("""Evaluating model...
wcwt: $wcwt
measerr: $measerr
exstaffweights: $(staffweights[1:10])
expatweights: $(patientweights[1:10])
""")
# Likelihood: one zero-truncated normal per observation with scale `measerr`.
totweight ~ arraydist(LazyArray(@~ truncatenormal.(theta,measerr)))
end
# Variant with priors truncated below at 90 and a broadcasted Gamma likelihood.
#   nstaff, npatients  - number of staff / patient weights to infer
#   staffid, patientid - per-observation indices into the weight vectors
#   usewch             - per-observation wheelchair indicator
#   totweight          - observed total weights
@model function estweights4(nstaff,staffid,npatients,patientid,usewch,totweight)
    # Wheelchair weight: Gamma with shape 20 and scale 15/19, i.e. mode 15.
    wcwt ~ Gamma(20.0,15.0/19)
    # NOTE(review): measerr receives a prior here but is not used in the
    # likelihood below -- confirm whether that is intended.
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    # Expected scale reading per observation: staff + patient (+ wheelchair).
    expected = view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt
    # Shape 12 with scale expected/11, so the mode equals the expected reading.
    totweight .~ Gamma.(12,expected./11)
end
# Lazy-array variant with a Gamma likelihood: per-observation means and
# distributions are wrapped in `LazyArray`s built with `@~`.
#   nstaff, npatients  - number of staff / patient weights to infer
#   staffid, patientid - per-observation indices into the weight vectors
#   usewch             - per-observation wheelchair indicator
#   totweight          - observed total weights
@model function estweightslazygamma(nstaff,staffid,npatients,patientid,usewch,totweight)
    # Wheelchair weight: Gamma with shape 20 and scale 15/19, i.e. mode 15.
    wcwt ~ Gamma(20.0,15.0/19)
    # NOTE(review): measerr receives a prior here but is not used in the
    # likelihood below -- confirm whether that is intended.
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    # Expected scale reading per observation: staff + patient (+ wheelchair).
    expected = LazyArray(@~ view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt)
    # Shape 15 with scale expected/14, so the mode equals the expected reading.
    totweight ~ arraydist(LazyArray(@~ Gamma.(15, expected ./ 14)))
end
# ch1 = sample(estweights(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
# ch2 = sample(estweights2(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
# ch3 = sample(estweights3(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
# Draw 1000 NUTS samples from the lazy truncated-normal model, starting from an
# explicitly supplied initial parameter vector.
ch3l = let
    model = estweights3lazy(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights)
    sampler = NUTS(500,.75,init_ϵ=.002)
    # Start point laid out in the order the model declares its parameters:
    # wcwt, measerr, staff weights, patient weights.
    start = vcat([15.0,20.0],staffweights,patientweights)
    sample(model,sampler,1000; init_theta = start)
end
# ch4 = sample(estweights4(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
#ch5 = sample(estweightslazygamma(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
When running this version, the initial