Slice sampling as a Gibbs sampler #2300

Open · Red-Portal opened this issue Aug 11, 2024 · 5 comments

Red-Portal (Member) commented Aug 11, 2024

Hi all,

I looked into using the samplers in SliceSampling as components of `Experimental.Gibbs`. After a few patches to SliceSampling, this seems feasible. Here is a snippet that works with SliceSampling#main:

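```julia
using Distributions
using LinearAlgebra   # provides the identity `I` used in MvNormal below
using Turing
using Turing.Experimental
using SliceSampling

@model function demo()
    s ~ InverseGamma(3, 3)
    # Isotropic normal with covariance s * I
    # (the older `MvNormal(zeros(10), sqrt(s))` form is deprecated)
    m ~ MvNormal(zeros(10), s * I)
end

# Tell Turing how to read the parameter values out of SliceSampling's sampler states
Turing.Inference.getparams(::Turing.DynamicPPL.Model, state::SliceSampling.UnivariateSliceState) = state.transition.params
Turing.Inference.getparams(::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState) = state.transition.params

# Tell Experimental.Gibbs not to recompute the log density when the
# destination sampler is an external sampler
Turing.Experimental.gibbs_requires_recompute_logprob(
    model_dst,
    ::Turing.DynamicPPL.Sampler{<:Turing.Inference.ExternalSampler},
    sampler_src,
    state_dst,
    state_src
) = false

n_samples = 10000
model     = demo()
sample(
    model,
    Experimental.Gibbs(
        (
            # Coordinate-wise slice sampling over the 10 entries of `m`, in random order
            m = externalsampler(RandPermGibbs(SliceSteppingOut(0.1))),
            # Univariate slice sampling for the scalar `s`
            s = externalsampler(SliceSteppingOut(0.1)),
        ),
    ),
    n_samples
)
```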
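As far as we understand SliceSampling.jl, `RandPermGibbs(SliceSteppingOut(0.1))` updates the coordinates of `m` one at a time, in a freshly drawn random order on each sweep, with a univariate stepping-out slice sampler (initial width 0.1). The sketch below illustrates that idea in plain Julia, following the stepping-out and shrinkage procedures of Neal (2003). It does not call SliceSampling.jl: the names `slice_stepping_out` and `randperm_gibbs_sweep!` are hypothetical, the stepping-out budget `m_max = 100` is an assumption, and a 10-dimensional standard normal stands in for the conditional of `m`.

```julia
using Random, Statistics

# One univariate slice-sampling update using stepping-out and shrinkage
# (after Neal, 2003). `logf` is an unnormalized log density of a scalar.
# Hypothetical helper, not SliceSampling.jl's API.
function slice_stepping_out(rng::AbstractRNG, logf, x0::Real, w::Real; m_max::Int = 100)
    # Slice level: y = u * f(x0) with u ~ Uniform(0, 1), i.e. log(y) = logf(x0) - Exp(1).
    logy = logf(x0) - randexp(rng)
    # Place an interval of width w at a random offset around x0.
    L = x0 - w * rand(rng)
    R = L + w
    # Split the stepping-out budget (an assumed m_max) randomly between the two ends.
    J = floor(Int, m_max * rand(rng))
    K = (m_max - 1) - J
    while J > 0 && logy < logf(L)
        L -= w
        J -= 1
    end
    while K > 0 && logy < logf(R)
        R += w
        K -= 1
    end
    # Shrinkage: propose uniformly on [L, R]; on rejection, shrink toward x0.
    while true
        x1 = L + rand(rng) * (R - L)
        logy < logf(x1) && return x1
        if x1 < x0
            L = x1
        else
            R = x1
        end
    end
end

# One random-permutation Gibbs sweep: every coordinate of `x` is updated once,
# in a random order, by slice sampling its full conditional under `logp`.
function randperm_gibbs_sweep!(rng::AbstractRNG, logp, x::Vector{Float64}, w::Real)
    for i in randperm(rng, length(x))
        xtmp = copy(x)
        logf_i = xi -> (xtmp[i] = xi; logp(xtmp))  # conditional log density of x[i]
        x[i] = slice_stepping_out(rng, logf_i, x[i], w)
    end
    return x
end

# Usage: a 10-dimensional standard normal target (a stand-in for the model).
rng = MersenneTwister(1)
logp(x) = -0.5 * sum(abs2, x)
x = zeros(10)
n_sweeps = 5_000
draws = Matrix{Float64}(undef, 10, n_sweeps)
for t in 1:n_sweeps
    randperm_gibbs_sweep!(rng, logp, x, 0.1)
    draws[:, t] .= x
end
println("mean = ", mean(draws), ", var = ", var(vec(draws)))
```

Because the stepping-out budget is split randomly between the two ends of the interval, the update leaves the target invariant even when the budget runs out before the whole slice is covered; a too-small `w` then costs extra log-density evaluations and somewhat slower mixing rather than a biased chain.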
Red-Portal (Member · Author)

I think I will release the changes to SliceSampling.jl very soon. How should I proceed from here? Should I open a PR to the main Turing repo adding a `Sampler`, or should I keep everything in the Turing extension of SliceSampling.jl?

sunxd3 (Collaborator) commented Aug 12, 2024

In the long run, we should move some of these interface functions into AbstractMCMC (TuringLang/AbstractMCMC.jl#144). Then we would only need to update SliceSampling to conform to the AbstractMCMC interface.

For now, I think a package extension should work. @torfjelde, @yebai, @mhauru

torfjelde (Member)

Awesome stuff @Red-Portal :) And yes, an extension for now, but hopefully this will become part of the AbstractMCMC.jl interface soon-ish 👍

Red-Portal (Member · Author)

Sounds great. I'll keep this issue open until then to keep track!

Red-Portal (Member · Author) commented Aug 12, 2024

Support for `Turing.Experimental.Gibbs` is now officially coming in SliceSampling.jl 0.6.0!

The results are very promising. Consider the following example from Turing's website:

```julia
using Distributions
using Turing
using SliceSampling

@model function simple_choice(xs)
    p ~ Beta(2, 2)
    z ~ Bernoulli(p)
    for i in 1:length(xs)
        if z == 1
            xs[i] ~ Normal(0, 1)
        else
            xs[i] ~ Normal(2, 1)
        end
    end
end
model = simple_choice([1.5, 2.0, 0.3])

# Baseline: HMC for p within the classic Gibbs sampler, PG for z
sample(model, Gibbs(HMC(0.2, 3, :p), PG(20, :z)), 1000)

# Slice samplers for p within Experimental.Gibbs, PG for z
# (`n_samples` is defined in the first snippet of this thread)
sample(model, Turing.Experimental.Gibbs((p = externalsampler(SliceSteppingOut(2.0)), z = PG(20, :z))), n_samples)
sample(model, Turing.Experimental.Gibbs((p = externalsampler(SliceDoublingOut(2.0)), z = PG(20, :z))), n_samples)
```

HMC (within `Gibbs`, step size 0.2, 3 leapfrog steps):

```
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

           p    0.4259    0.1945    0.0176   119.6925   301.2020    1.0027        9.2327
```

Slice sampling with stepping-out:

```
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

           p    0.5936    0.1897    0.0062   916.9961   759.7287    1.0053       59.9814
```

Slice sampling with doubling-out:

```
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

           p    0.5970    0.1919    0.0062   994.2301   786.9206    1.0044       67.0373
```

Of course, HMC's performance can be improved by tuning its step size, but the point is that slice sampling performs well with no tuning at all: in these runs, its bulk ESS per second is roughly 6.5–7× that of HMC.
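Stepping-out and doubling-out slice sampling differ in how they grow an interval around the current point before shrinking it. The sketch below shows doubling-out in plain Julia, following Neal's (2003) doubling procedure together with the acceptance test that doubling needs to keep the target invariant. It is not SliceSampling.jl's implementation: the names `doubling_interval`, `doubling_accept`, `slice_doubling_out`, and `run_chain`, the doubling budget `p = 20`, and the standard-normal target are assumptions for illustration. Running it with initial widths four orders of magnitude apart hints at why little tuning is needed: doubling grows a too-small interval geometrically, and shrinkage trims a too-large one geometrically.

```julia
using Random, Statistics

# Doubling procedure (after Neal, 2003): start from a width-w interval at a random
# offset around x0 and double it on a random side until both ends are outside the
# slice {x : logf(x) > logy} or the (assumed) budget of p doublings is used up.
function doubling_interval(rng::AbstractRNG, logf, x0, logy, w, p)
    L = x0 - w * rand(rng)
    R = L + w
    K = p
    while K > 0 && (logy < logf(L) || logy < logf(R))
        if rand(rng) < 0.5
            L -= R - L
        else
            R += R - L
        end
        K -= 1
    end
    return L, R
end

# Acceptance test (after Neal, 2003): accept x1 only if doubling started from x1
# could have produced the same interval [L, R]; this keeps the update reversible.
function doubling_accept(logf, x0, x1, logy, w, L, R)
    D = false
    while R - L > 1.1 * w
        M = (L + R) / 2
        if (x0 < M && x1 >= M) || (x0 >= M && x1 < M)
            D = true  # x0 and x1 fall in different halves
        end
        if x1 < M
            R = M
        else
            L = M
        end
        if D && logy >= logf(L) && logy >= logf(R)
            return false  # doubling from x1 would have stopped earlier
        end
    end
    return true
end

# One univariate slice-sampling update with doubling-out and shrinkage.
function slice_doubling_out(rng::AbstractRNG, logf, x0::Real, w::Real; p::Int = 20)
    logy = logf(x0) - randexp(rng)  # log of the slice level u * f(x0)
    L, R = doubling_interval(rng, logf, x0, logy, w, p)
    Lbar, Rbar = L, R  # shrink a copy; the acceptance test uses the doubled [L, R]
    while true
        x1 = Lbar + rand(rng) * (Rbar - Lbar)
        if logy < logf(x1) && doubling_accept(logf, x0, x1, logy, w, L, R)
            return x1
        end
        if x1 < x0
            Lbar = x1
        else
            Rbar = x1
        end
    end
end

# Usage: standard normal target, run with a tiny and a huge initial width.
target_logf(x) = -0.5 * x^2
function run_chain(rng::AbstractRNG, w, n)
    x = 0.0
    draws = Vector{Float64}(undef, n)
    for t in 1:n
        x = slice_doubling_out(rng, target_logf, x, w)
        draws[t] = x
    end
    return draws
end
rng = MersenneTwister(2)
draws_small = run_chain(rng, 0.01, 20_000)
draws_large = run_chain(rng, 100.0, 20_000)
println("w = 0.01:  mean = ", mean(draws_small), ", var = ", var(draws_small))
println("w = 100:   mean = ", mean(draws_large), ", var = ", var(draws_large))
```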
