Difficulty sampling a model with truncated normal Likelihood #1722

Open
@dlakelan


I've been having problems sampling models that use truncated normal distributions.

This minimal worked example, taken from the discourse.julialang.org discussion linked below, shows that something goes wrong with the initial vector I pass in. The first printout does not show the initial values I provided. Instead it shows a value with a very small standard deviation for the measurement errors, so the sampler immediately runs into numerical issues and reports zero probability at the initial point.

https://discourse.julialang.org/t/making-turing-fast-with-large-numbers-of-parameters/69072/99?u=dlakelan
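
To illustrate why a very small error standard deviation is a problem, here is a minimal sketch using only Distributions. The numbers (`mu`, `observed`, and the two `sigma` values) are made up for illustration and are not output from the models in this issue. It evaluates a zero-truncated normal density at a point 10 units from its mean, once with a realistic scale and once with a tiny one.

```julia
using Distributions

# Hypothetical values, chosen only for illustration
mu = 300.0          # predicted total weight
observed = 310.0    # an observation 10 units away from the prediction

for sigma in (20.0, 1e-3)
    # Same construction as truncatenormal in the script: Normal truncated to [0, Inf)
    d = truncated(Normal(mu, sigma), 0.0, Inf)
    println("sigma = ", sigma,
            "  pdf = ", pdf(d, observed),
            "  logpdf = ", logpdf(d, observed))
end
```

With the tiny `sigma` the observation sits about ten thousand standard deviations from the mean, so the density should underflow to zero in floating point even though the log density is still a finite, very large negative number.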

using Pkg
Pkg.activate(".")
using Turing, DataFrames, DataFramesMeta, LazyArrays, Distributions, DistributionsAD
using ReverseDiff, Memoization

## every few hours a random staff member comes and gets a random
## patient to bring them outside to a garden through a door that has a
## scale. Sometimes using a wheelchair, sometimes not. knowing the
## total weight of the two people and the wheelchair plus some errors
## (from the scale measurements), infer the individual weights of all
## individuals and the weight of the wheelchair.

nstaff = 100
npat = 100
staffids = collect(1:nstaff)
patientids = collect(1:npat)
staffweights = rand(Normal(150,30),length(staffids))
patientweights = rand(Normal(150,30),length(patientids))
wheelchairwt = 15
nobs = 300

data = DataFrame(staff=rand(staffids,nobs),patient=rand(patientids,nobs))
data.usewch = rand(0:1,nobs)
data.totweights = [staffweights[data.staff[i]] + patientweights[data.patient[i]] for i in 1:nrow(data)] .+ data.usewch .* wheelchairwt .+ rand(Normal(0.0,20.0),nrow(data))


Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
Turing.emptyrdcache()



@model function estweights(nstaff,staffid,npatients,patientid,usewch,totweight)
    # Gamma(shape, scale): the mode is (20 - 1) * 15/19 = 15, the simulated wheelchair weight
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(Normal(150,30),nstaff)
    patientweights ~ filldist(Normal(150,30),npatients)
    
    totweight ~ MvNormal(view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt,20.0)
end



@model function estweights2(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    
    totweight ~ arraydist([Gamma(15,(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt)/14) for i in 1:length(totweight)])
end


@model function estweights3(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    
    totweight ~ arraydist([truncated(Normal(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt, measerr),0.0,Inf) for i in 1:length(totweight)])
end

function truncatenormal(a,b)::UnivariateDistribution
    truncated(Normal(a,b),0.0,Inf)
end


@model function estweights3lazy(nstaff,staffid,npatients,patientid,usewch,totweight)

    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    theta = LazyArray(@~ view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt)
    println("""Evaluating model... 
wcwt: $wcwt
measerr: $measerr
exstaffweights: $(staffweights[1:10])
expatweights: $(patientweights[1:10])
""")
    
    totweight ~ arraydist(LazyArray(@~ truncatenormal.(theta,measerr)))
end



@model function estweights4(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    means = view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt
    totweight .~ Gamma.(12,means./11)
end




@model function estweightslazygamma(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    theta = LazyArray(@~ view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt)
    totweight ~ arraydist(LazyArray(@~ Gamma.(15, theta ./ 14)))
end




# ch1 = sample(estweights(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


# ch2 = sample(estweights2(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

# ch3 = sample(estweights3(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

ch3l = sample(estweights3lazy(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75,init_ϵ=.002),1000;
              init_theta = vcat([15.0,20.0],staffweights,patientweights))

# ch4 = sample(estweights4(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


#ch5 = sample(estweightslazygamma(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

When running this version, the initial
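
One thing that may be worth ruling out separately: the script draws `staffweights` and `patientweights` from an untruncated `Normal(150,30)` and passes them as initial values, while the priors in `estweights3lazy` are truncated below at 90.0. Any draw under 90 lies outside the prior's support, where the log density is `-Inf`. The sketch below shows such a check, assuming only Distributions. The names `prior`, `init_weights`, and `clamped` are hypothetical, and this is not claimed to be the cause of the behaviour described above.

```julia
using Distributions

# The prior used for staffweights and patientweights in the models
prior = truncated(Normal(150, 30), 90.0, Inf)

# Stand-in for the simulated weights that get passed as initial values
init_weights = rand(Normal(150, 30), 200)

# Count initial values that fall outside the prior's support
outside = count(w -> !insupport(prior, w), init_weights)
println("initial values outside the prior support: ", outside)

# A value below the lower bound has log density -Inf under the prior
println("logpdf(prior, 80.0) = ", logpdf(prior, 80.0))

# One possible workaround (an assumption, not a Turing feature):
# nudge offending values just inside the bound before using them
clamped = max.(init_weights, 90.0 + 1e-3)
```

Clamping changes the starting point slightly but keeps every initial weight at a location where the prior density is finite.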
