
Update to tilde overloads in mh.jl #2360

Open · wants to merge 1 commit into base: ch

Conversation

torfjelde (Member)

@penelopeysm I was looking at your PR #2341 and found some bad bugs in the existing MH codebase, so I figured we should just get these fixed too.

Comment on lines +464 to +469
# Just defer to `SampleFromPrior`.
retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dist, vns[1], var, vi)
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
DynamicPPL.updategid!.((vi,), vns, (spl,))
# Return.
return retval
torfjelde (Member, Author)

I'm wondering if maybe we should just move this "default" impl, which uses SampleFromPrior + updategid!, to DynamicPPL.jl itself. Thoughts @penelopeysm @mhauru?
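
For concreteness, a rough sketch of what that default could look like if it lived in DynamicPPL.jl, just transplanting the overload from this diff (hypothetical; the actual signature would need to match DynamicPPL's tilde pipeline):

# Hypothetical sketch only, not DynamicPPL's actual API: a generic fallback
# that defers to `SampleFromPrior` and then points the Gibbs IDs at `spl`,
# mirroring the MH-specific overload in this diff.
function DynamicPPL.dot_assume(rng, spl::DynamicPPL.Sampler, dist, vns, var, vi)
    # Defer to `SampleFromPrior` for the actual assume logic.
    retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dist, vns[1], var, vi)
    # Re-assign the Gibbs IDs, since the `SampleFromPrior` call may have set them.
    DynamicPPL.updategid!.((vi,), vns, (spl,))
    return retval
end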

Member

Would this be useful for multiple samplers beyond MH?

codecov bot commented Oct 4, 2024

Codecov Report

Attention: Patch coverage is 33.33333% with 6 lines in your changes missing coverage. Please review.

Project coverage is 84.36%. Comparing base (452d0d0) to head (679989a).

Files with missing lines   Patch %   Lines
src/mcmc/mh.jl             33.33%    6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##               ch    #2360      +/-   ##
==========================================
+ Coverage   83.86%   84.36%   +0.50%     
==========================================
  Files          24       24              
  Lines        1580     1573       -7     
==========================================
+ Hits         1325     1327       +2     
+ Misses        255      246       -9     


penelopeysm (Member) commented Oct 4, 2024

It occurs to me that the value generated by DynamicPPL.assume(..., SampleFromPrior(), ...) is only used if a proposal isn't specified – if we have one then is this doing extra computation?

I don't know if there's an easy fix for this at all though.

torfjelde (Member, Author)

> It occurs to me that the value generated by DynamicPPL.assume(..., SampleFromPrior(), ...) is only used if a proposal isn't specified – if we have one then is this doing extra computation?

SampleFromPrior just extracts the value from vi if it's "not to be sampled", so it shouldn't be doing anything extra here AFAIK. Could you clarify a bit what you mean here?
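
For anyone following along, here's a heavily simplified sketch of the behaviour being described; this is not DynamicPPL's real implementation, and assume_sketch is a made-up name:

# Simplified sketch, per the discussion above, of what `assume` with
# `SampleFromPrior` does; not the actual DynamicPPL source.
function assume_sketch(rng, ::SampleFromPrior, dist, vn, vi)
    if haskey(vi, vn) && !DynamicPPL.is_flagged(vi, vn, "del")
        # The variable already has a value in `vi` and isn't flagged for
        # resampling: just read the stored value back out.
        r = vi[vn]
    else
        # Otherwise draw a fresh value from the prior and store it.
        r = rand(rng, dist)
        vi = DynamicPPL.push!!(vi, vn, r, dist)
    end
    return r, logpdf(dist, r), vi
end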

penelopeysm (Member) commented Oct 4, 2024

Oh, I see, I'm just confused. The values are stored in varinfo from the previous iteration when calculating the lp, so it just extracts those values. I had assumed it would call rand() on the prior, but now I see that's not the case ;)

Edit: oh, I see, it's this function that's doing it.

Turing.jl/src/mcmc/mh.jl, lines 267 to 282 in 40a0d84:

function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)
# TODO: Make this work with immutable `f.varinfo` too.
sampler = DynamicPPL.getsampler(f)
vi = f.varinfo
x_old, lj_old = vi[sampler], getlogp(vi)
set_namedtuple!(vi, x)
vi_new = last(DynamicPPL.evaluate!!(f.model, vi, DynamicPPL.getcontext(f)))
lj = getlogp(vi_new)
# Reset old `vi`.
setindex!!(vi, x_old, sampler)
setlogp!!(vi, lj_old)
return lj
end

penelopeysm (Member)

Sorry, not trying to be annoying and get a final word in, but having spent a good amount of time figuring out the interplay between Turing and AdvancedMH, I feel like this behaviour is a bit unexpected (and I actually now understand why it happens 😂):

using Turing, AdvancedMH

@model function gdemo(x, y)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x ~ Normal(m, sqrt(s²))
    y ~ Normal(m, sqrt(s²))
end
chain = sample(
    gdemo(1.5, 2.0),
    MH(:m => AdvancedMH.RandomWalkProposal(Normal(0, 0.25))),
    10
)

Here s² is sampled once at the start of the chain and never touched again. I feel like people (perhaps just me 😄) would expect that it either defaults to a static proposal (the prior) or that it errors. If you agree, I can open a separate issue :) (this is obviously not part of this PR).
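
In the meantime, one way to avoid the silent single draw is to give every variable an explicit proposal. A sketch (per Turing's docs, a plain distribution passed to MH is used as a static proposal):

# Sketch: give `s²` an explicit proposal too, so nothing falls back to a
# single prior draw. A plain distribution acts as a static proposal in MH.
chain = sample(
    gdemo(1.5, 2.0),
    MH(
        :s² => InverseGamma(2, 3),                             # static proposal (here: the prior)
        :m => AdvancedMH.RandomWalkProposal(Normal(0, 0.25)),  # random-walk proposal
    ),
    10,
)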
