Accumulators, stage 1 #885
base: breaking
* AbstractPPL 0.11; change prefixing behaviour
* Use DynamicPPL.prefix rather than overloading

* Unify {Untyped,Typed}{Vector,}VarInfo constructors
* Update invocations
* NTVarInfo
* Fix tests
* More fixes
* Fixes
* Fixes
* Fixes
* Use lowercase functions, don't deprecate VarInfo
* Rewrite VarInfo docstring
* Fix methods
* Fix methods (really)
Benchmark Report for Commit 0b08237 (sections: Computer Information, Benchmark Results)
Not reviewing actual code, just one high-level thought that struck me.
src/abstract_varinfo.jl
function setlogp!!(vi::AbstractVarInfo, logp)
    vi = setlogprior!!(vi, zero(logp))
    vi = setloglikelihood!!(vi, logp)
    return vi
end
I was thinking about this the other day and thought I may as well post now. The `...logp()` family of functions is no longer well-defined in a world where everything is cleanly split into prior and likelihood (only `getlogp` and `resetlogp` still make sense). I think last time we chatted about it the decision was to maybe forward the others to the likelihood methods, but I was wondering if it's actually safer to remove them (or make them error informatively) and force people to use likelihood or prior, as otherwise it risks introducing subtle bugs. Backward compatibility is important, but if it comes at the cost of correctness I feel kinda uneasy.
My hope was that we could deprecate them but provide the same functionality through the new functions, like above. It's a good question as to whether there are edge cases where they do not provide the same functionality. I think this is helped by the fact that `PriorContext` and `LikelihoodContext` won't exist, and hence one can't be running code where the expectation would be that `...logp()` refers to logprior or loglikelihood in particular. And I think as long as one expects to get the logjoint out of `...logp()`, we can do things like above, shoving things into the likelihood, and get the same results. Do you think that solves it and lets us use deprecations rather than straight-up removals, or do you see other edge cases?
Something like this is a case where setlogp is ill-defined:
DynamicPPL.jl/src/test_utils/varinfo.jl
Lines 47 to 62 in c7bdc3f
lp = getlogp(vi_typed_metadata)
varinfos = map((
    vi_untyped_metadata,
    vi_untyped_vnv,
    vi_typed_metadata,
    vi_typed_vnv,
    svi_typed,
    svi_untyped,
    svi_vnv,
    svi_typed_ref,
    svi_untyped_ref,
    svi_vnv_ref,
)) do vi
    # Set them all to the same values.
    DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end
The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.
Of course, we can fix this on our end - we would get and set logprior and loglikelihood manually, and we can grep the codebase to make sure that there are no other ill-defined calls to setlogp. We can't guarantee that other people will be similarly careful, though (and us or anyone being careful also doesn't guarantee that everything will be fixed correctly).
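A minimal sketch of the inconsistency, using a hypothetical two-field `ToyVI` in place of a real `VarInfo` and mirroring the `setlogp!!` fallback quoted above (zero the prior, store everything as likelihood). The joint survives; the prior/likelihood split does not.

```julia
# Hypothetical two-component stand-in for a VarInfo; not a DynamicPPL type.
struct ToyVI
    logprior::Float64
    loglikelihood::Float64
end

toy_getlogp(vi::ToyVI) = vi.logprior + vi.loglikelihood
toy_getlogprior(vi::ToyVI) = vi.logprior

# Mirrors the proposed fallback: zero the prior and store the whole logp as likelihood.
toy_setlogp(vi::ToyVI, logp::Real) = ToyVI(zero(logp), logp)

vi = ToyVI(-1.5, -2.5)
vi2 = toy_setlogp(vi, toy_getlogp(vi))

toy_getlogp(vi2) == toy_getlogp(vi)   # true: the joint is unchanged
toy_getlogprior(vi2)                  # 0.0, although vi had a log prior of -1.5
```

Anyone who only ever calls `toy_getlogp` cannot tell the two apart; the problem surfaces only when prior and likelihood are read separately.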
While looking for other uses of `setlogp`, I encountered this:
`AdvancedHMC.Transition` only contains a single notion of log density, so it's not obvious to me how we're going to extract the prior and likelihood components from it 😓 This might require upstream changes to AdvancedHMC. Since the contexts will be removed, I suspect `LogDensityFunction` also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).
(For the record, I'd be quite happy with making all of these changes!)
> The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.

It is inconsistent, but as long as the user only uses `getlogp`, they would never see the difference, right? If some of the logprior is accidentally stored in the loglikelihood or vice versa, that should be undetectable as long as one is using `getlogp` and `DefaultContext`. What would be trouble is if someone mixes e.g. `setlogp!!` and `getlogprior`, which would require adding calls to `getlogprior` after upgrading to a version that has deprecated `setlogp!!`; probably people would end up doing that. Maybe the deprecation warning could say something about this?

> Since the contexts will be removed, I suspect LogDensityFunction also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).

Yeah, this sort of stuff will come up (and is coming up) in multiple places. Anything that explicitly uses `PriorContext` or `LikelihoodContext` would need to be changed to use `LogPrior` and `LogLikelihood` accumulators instead. I'm currently doing this for `pointwiselogdensities`.
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
return x, vi
How do we deal with tempering of logpdf and such now that it happens in the leaf of the call stack?
In the past, we would do this by altering the `logpdf` coming from the `assume` higher up in the call tree.
What are the needs of tempering? What does it need to alter?
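One possible shape for tempering under the new design, offered only as a sketch under the assumption that tempering needs to reweight likelihood terms: because the leaf now hands each density to the accumulators, a custom accumulator could hold the inverse temperature β itself. `TemperedLogLik`, `accumulate_observe`, and `normal_logpdf` are hypothetical names, not DynamicPPL API.

```julia
# Hypothetical tempered-likelihood accumulator; not part of DynamicPPL.
# The Normal log density is written out by hand to keep the sketch dependency-free.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

struct TemperedLogLik{T<:Real}
    β::T      # inverse temperature, typically in [0, 1]
    logp::T   # running sum of β * log-likelihood terms
end

# Called once per observation: the leaf supplies the data point and distribution
# parameters, and the accumulator decides how to weight the resulting density.
function accumulate_observe(acc::TemperedLogLik, x, μ, σ)
    return TemperedLogLik(acc.β, acc.logp + acc.β * normal_logpdf(x, μ, σ))
end

# Two observations, x = 0.0 and x = 1.0, each ~ Normal(0, 1), tempered at β = 0.5.
acc = foldl((a, x) -> accumulate_observe(a, x, 0.0, 1.0), (0.0, 1.0);
            init = TemperedLogLik(0.5, 0.0))
```

The same pattern (an accumulator that rescales what it receives) would also be one way to stand in for a likelihood-scaling context.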
Co-authored-by: Tor Erlend Fjelde <[email protected]>
@@ -1,11 +0,0 @@
@testset "Turing independence" begin
This whole test seemed very redundant given the rest of the test suite.
docs/src/api.md
getlogp
setlogp!!
acclogp!!
getlogprior
I've implemented the NamedTuple versions of `get/set/acclogp`, made `setlogp!!(vi, x::Number)` error, and made `acclogp!!(vi, x::Number)` fall back on `accloglikelihood!!` with a deprecation warning.
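A minimal sketch of that deprecation behaviour, written against a hypothetical `ShimVI` stand-in rather than a real `VarInfo`. The method bodies are illustrative, not DynamicPPL's code; only the dispatch pattern (Number forwards with a warning, NamedTuple is explicit, `setlogp!!` with a Number errors) follows the comment.

```julia
# Hypothetical stand-in for a VarInfo with separate prior and likelihood terms.
struct ShimVI
    logprior::Float64
    loglikelihood::Float64
end

# Illustrative component-wise accumulators (not DynamicPPL's implementations).
acclogprior!!(vi::ShimVI, x) = ShimVI(vi.logprior + x, vi.loglikelihood)
accloglikelihood!!(vi::ShimVI, x) = ShimVI(vi.logprior, vi.loglikelihood + x)

# Bare number: deprecated, forwarded to the log likelihood.
function acclogp!!(vi::ShimVI, x::Number)
    Base.depwarn(
        "`acclogp!!(vi, x::Number)` is deprecated; use `accloglikelihood!!` or pass a NamedTuple.",
        :acclogp!!,
    )
    return accloglikelihood!!(vi, x)
end

# NamedTuple: each field names the component it belongs to.
function acclogp!!(vi::ShimVI, logp::NamedTuple)
    haskey(logp, :logprior) && (vi = acclogprior!!(vi, logp.logprior))
    haskey(logp, :loglikelihood) && (vi = accloglikelihood!!(vi, logp.loglikelihood))
    return vi
end

# A bare number cannot be split into prior and likelihood, so setting one errors.
function setlogp!!(::ShimVI, ::Number)
    return error("setlogp!!(vi, x::Number) is no longer supported; set the log prior and log likelihood separately.")
end

vi = ShimVI(0.0, 0.0)
vi = acclogp!!(vi, -1.0)                                       # ShimVI(0.0, -1.0)
vi = acclogp!!(vi, (logprior = -0.5, loglikelihood = -2.0))   # ShimVI(-0.5, -3.0)
```

`Base.depwarn` only prints when Julia runs with `--depwarn=yes`, which is the usual channel for deprecations.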
This is ready for review. Apologies for the huge line count. I did try to not make edits that weren't core to the new features or required to make tests pass. Open questions:
Things I would like to get to, but not in this PR:
I don't know why the benchmarks are failing; they work for me locally. I'll look into it next week. Results on my laptop, this branch:
Current main:
I'll look into this a bit more, see what might be happening with LDA and if something could be done about the very smallest models taking a substantial hit. I also want to try making
Yes, I never manage to remember what
LogPrior
LogLikelihood
NumProduce
For clarity
Suggested change:
`LogPrior` → `LogPriorAccumulator`
`LogLikelihood` → `LogLikelihoodAccumulator`
`NumProduce` → `NumProduceAccumulator`
Yes, good idea. I had this in mind as well but then forgot about it. I may hold off implementing until a final decision has been made on the term "accumulator".
@@ -427,9 +450,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`.
```@docs
SamplingContext
DefaultContext
```
For clarity
Suggested change: `DefaultContext` → `AccumulatorContext`
- `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPrior(),))`.
- `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LogLikelihood` accumulator. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself.
- `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any context that needs to modify observation behaviour should overload it instead. We may further rework the call stack under `tilde_observe!!` in the near future.
- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.
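The first bullet's idea of accumulating only what you ask for can be sketched with a hypothetical NamedTuple of accumulators (not DynamicPPL's `AccumulatorTuple` or `setaccs!!`): each step checks whether its accumulator is present and skips the density evaluation otherwise.

```julia
# Hypothetical sketch: accumulators held in a NamedTuple keyed by name.
# Each step only computes what a present accumulator needs.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

function assume_step(accs::NamedTuple, x, μ, σ)
    haskey(accs, :LogPrior) || return accs
    return merge(accs, (LogPrior = accs.LogPrior + normal_logpdf(x, μ, σ),))
end

function observe_step(accs::NamedTuple, x, μ, σ)
    # No likelihood accumulator: skip the logpdf evaluation entirely.
    haskey(accs, :LogLikelihood) || return accs
    return merge(accs, (LogLikelihood = accs.LogLikelihood + normal_logpdf(x, μ, σ),))
end

# Toy model: m ~ Normal(0, 1) with m = 0.5, then y = 1.0 ~ Normal(m, 1).
run_model(accs) = observe_step(assume_step(accs, 0.5, 0.0, 1.0), 1.0, 0.5, 1.0)

both = run_model((LogPrior = 0.0, LogLikelihood = 0.0))
prior_only = run_model((LogPrior = 0.0,))   # only a LogPrior field; no likelihood computed
```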
Any reason why we can't remove `tilde_assume` like `tilde_observe` and `observe`?
It's on my list to do, but I didn't want to put it in the same PR since I didn't have to. (I had to do `tilde_observe` because of complications with `PointwiseLogdensityAccumulator`.) Or, more precisely, what's on my list is to revisit the whole call stack for both `tilde_assume!!` and `tilde_observe!!` and see what the best way to do things is.
I'm mostly happy with what's been done now, but I am strongly in favour of going further. I think Markus and I discussed this on a call last week, but I'll explain my rationale again for the benefit of everybody reading.
Basically, I don't understand why we are committing to support something that we don't really support; it just means that the next time we visit the
I don't really mind either way. I think I have mentioned before that formally these are monoidal structures which would be the most precise term, but also rather impenetrable to someone who doesn't know it, so I'm quite happy to go with either Accumulator or State.
What is the motivation for changing I find
If it were called
- For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather, we call `tilde_observe!!` with `vn` set to `nothing`.
- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
- `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
Adding to the log likelihood is a sensible default. For advanced use by programmable inference, let's also support
`@addlogprob! (logprior=0., loglikelihood=0.)`
It doesn't help to remove
I think I see why @penelopeysm wants to avoid
I wouldn't even call this 'request changes', just a bunch of thoughts. I haven't reviewed all the code yet, only the interface bits (abstract_varinfo.jl and accumulator.jl - helpfully they're first alphabetically!) but I didn't want to put like 50 comments at once. Happy to continue once we've worked some of these out.
macro addlogprob!(ex)
    return quote
        # Removed in this PR:
        # $(esc(:(__varinfo__))) = acclogp!!(
        #     $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex))
        # )
        # Added in this PR:
        if $hasacc($(esc(:(__varinfo__))), Val(:LogLikelihood))
            $(esc(:(__varinfo__))) = $accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex)))
        end
    end
end
This appears to silently do nothing if the varinfo has no likelihood accumulator. One might ask what it should do instead. The problem is that it's simply not clear what `addlogprob` should do when it can't find a likelihood accumulator. Should it attempt to add to the prior instead? Should it silently do nothing? Should it error? Personally, I would argue that given that the expected behaviour is not clear, it should error.
It's precisely because of cases like these that I think these logp functions should just be removed: they simply don't have well-defined behaviour.
But if we insist on having them, we should at least make it error when there's no loglikelihood. I guess the easiest way is to remove the `hasacc` check.
The definition of `addlogprob!` would, after this PR, be that it adds to the log likelihood. (It's in the docstring too.) So it should either silently do nothing or error, but not e.g. add to the log prior.
The reason I think it should silently do nothing is that
- it's in line with what `tilde_observe!!` statements do. They should add to the log likelihood, but if there's no accumulator for it they simply skip that step.
- erroring would make it impossible to execute a model with `@addlogprob!` calls in it without accumulating the log likelihood (which would then mean executing `logpdf` calls at every `tilde_observe!!`). This goes against the "choose what you want to accumulate" idea that having the separate accumulators (and `PriorContext` and `LikelihoodContext`) is largely about.
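A sketch of the skip-if-absent semantics argued for above, plus the NamedTuple form proposed earlier in the thread, written as plain functions over a NamedTuple of accumulators. `addlogprob` here is a hypothetical function, not the DynamicPPL macro.

```julia
# A bare number targets the log likelihood and, like an observe statement,
# is skipped when no likelihood accumulator is present.
function addlogprob(accs::NamedTuple, x::Real)
    haskey(accs, :LogLikelihood) || return accs
    return merge(accs, (LogLikelihood = accs.LogLikelihood + x,))
end

# A NamedTuple (logprior = ..., loglikelihood = ...) targets each component explicitly.
function addlogprob(accs::NamedTuple, x::NamedTuple)
    if haskey(x, :logprior) && haskey(accs, :LogPrior)
        accs = merge(accs, (LogPrior = accs.LogPrior + x.logprior,))
    end
    if haskey(x, :loglikelihood)
        accs = addlogprob(accs, x.loglikelihood)
    end
    return accs
end

addlogprob((LogPrior = 0.0,), -1.0)   # (LogPrior = 0.0,): silently skipped
addlogprob((LogPrior = 0.0, LogLikelihood = 0.0), (logprior = -1.0, loglikelihood = -2.0))
```

The erroring alternative would amount to replacing the early `return accs` in the `Real` method with an `error(...)` call.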
macro addlogprob!(ex)
    return quote
I think at the very least we should provide `@addloglikelihood!` and `@addlogprior!` macros as well.
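The two proposed macros could be as thin as the existing one. A minimal sketch, assuming a `__varinfo__` variable in scope at the call site as in the `@addlogprob!` snippet, with hypothetical NamedTuple-based helpers (`acc_loglik`, `acc_logprior`) standing in for the real accumulation functions:

```julia
# Hypothetical helpers over a NamedTuple of accumulators (stand-in for a VarInfo).
acc_loglik(accs::NamedTuple, x) = merge(accs, (LogLikelihood = accs.LogLikelihood + x,))
acc_logprior(accs::NamedTuple, x) = merge(accs, (LogPrior = accs.LogPrior + x,))

# Each macro rebinds a variable named `__varinfo__` that must be in scope where it is used.
macro addloglikelihood!(ex)
    vi = esc(:__varinfo__)
    return :($vi = acc_loglik($vi, $(esc(ex))))
end

macro addlogprior!(ex)
    vi = esc(:__varinfo__)
    return :($vi = acc_logprior($vi, $(esc(ex))))
end

# Usage inside a function body that plays the role of a model.
function toy_model_body()
    __varinfo__ = (LogPrior = 0.0, LogLikelihood = 0.0)
    @addlogprior!(-1.0)
    @addloglikelihood!(-2.0)
    return __varinfo__
end

toy_model_body()   # (LogPrior = -1.0, LogLikelihood = -2.0)
```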
# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo
# has to implement. Not sure what the full list is for parameter values, but for
# accumulators we only need `getaccs` and `setaccs!!`.
I put a list of AbstractVarInfo interface methods together the other day, feel free to ask me for it on Slack or something. We'd have to add the accumulator bits in, of course.
Could be a good contribution to #899
Return a boolean for whether `vi` has an accumulator with name `accname`.
"""
hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname)
function hassacc(vi::AbstractVarInfo, accname::Symbol)
Suggested change: `function hassacc(vi::AbstractVarInfo, accname::Symbol)` → `function hasacc(vi::AbstractVarInfo, accname::Symbol)`
function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)
    return error(
        """
        The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)
        does not exist. For type stability reasons use
        map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead.
        """
    )
end
Are these mostly reminders for ourselves, or is it meant to be user-facing?
If it's the former, I'd opine that we don't need these 'reminder methods' and we could just add a note to the docstring, because if you get a MethodError it'd be fairly easy to bring up the docstring and go 'ah, yes, I know what I need to do'.
These are meant to be user-facing to the same degree as `map_accumulator!!` is user-facing (which is not really user-facing, but e.g. inference-algorithm-developer-facing). The reason I put these in is that morally `map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)` is exactly the right thing to call, and I might find it weird to get a `MethodError` when I do the right thing. The fact that you should wrap the `accname` argument in a `Val` is an annoying technicality, due to type stability, that no one should waste brain cycles trying to understand the deeper meaning of. I myself made the mistake of calling these with a raw `Symbol` multiple times while developing this code, and would have appreciated this reminder when trying to understand why my tests failed.
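A minimal, hypothetical sketch of the pattern being discussed: `map_acc` stands in for `map_accumulator!!` and works on a plain NamedTuple. The `Val` method makes the accumulator name part of the argument type, so the field access is known at compile time; the `Symbol` method exists only to turn a confusing `MethodError` into an actionable message.

```julia
# Hypothetical sketch, not DynamicPPL's implementation.
function map_acc(f, accs::NamedTuple, ::Val{name}) where {name}
    return merge(accs, NamedTuple{(name,)}((f(getfield(accs, name)),)))
end

# A bare Symbol is only known at runtime; explain what to call instead.
function map_acc(f, accs::NamedTuple, name::Symbol)
    return error("map_acc(f, accs, name::Symbol) does not exist. For type stability, use map_acc(f, accs, Val(:$name)) instead.")
end

accs = (LogPrior = -1.0, LogLikelihood = -2.0)
map_acc(x -> 2x, accs, Val(:LogPrior))   # (LogPrior = -2.0, LogLikelihood = -2.0)
```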
if !(
    names == (:logprior, :loglikelihood) ||
    names == (:loglikelihood, :logprior) ||
    names == (:logprior,) ||
    names == (:loglikelihood,)
)
    error("logp must have fields logprior and/or loglikelihood and no other fields.")
We're lucky there are only two types of logp, eh...
It's not the prettiest part of the codebase, but it's explicit and easy on the compiler.
An accumulator is an object that may change its value at every tilde_assume!! or
tilde_observe!! call based on the random variable in question. The obvious examples of
accumulators are the log prior and log likelihood. Others examples might be a variable that
Suggested change: `Others examples` → `Other examples`
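To make the docstring's definition concrete, here is a minimal sketch of the interface under simplifying assumptions: hypothetical `Toy*` types, a precomputed log density passed in instead of a distribution, and a `NumProduce`-like counter assumed to count observe statements only. It is not DynamicPPL's implementation.

```julia
abstract type ToyAccumulator end

struct ToyLogPrior{T<:Real} <: ToyAccumulator
    logp::T
end
struct ToyLogLikelihood{T<:Real} <: ToyAccumulator
    logp::T
end
struct ToyNumProduce{T<:Integer} <: ToyAccumulator
    n::T
end

# By default an accumulator ignores a statement.
accumulate_assume(acc::ToyAccumulator, logpdf_val, logjac) = acc
accumulate_observe(acc::ToyAccumulator, logpdf_val) = acc

# The log prior collects assumed variables, including the log-Jacobian of any transform.
accumulate_assume(acc::ToyLogPrior, logpdf_val, logjac) = ToyLogPrior(acc.logp + logpdf_val + logjac)
# The log likelihood collects observations.
accumulate_observe(acc::ToyLogLikelihood, logpdf_val) = ToyLogLikelihood(acc.logp + logpdf_val)
# Counter of observe statements (an assumption for this sketch).
accumulate_observe(acc::ToyNumProduce, logpdf_val) = ToyNumProduce(acc.n + 1)

# One assume and one observe, applied to every accumulator in a tuple.
accs = (ToyLogPrior(0.0), ToyLogLikelihood(0.0), ToyNumProduce(0))
accs = map(a -> accumulate_assume(a, -1.0, 0.25), accs)
accs = map(a -> accumulate_observe(a, -2.0), accs)
```

Mapping over a tuple keeps each accumulator's concrete type, which is one reason to store them in a tuple rather than a vector.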
    return AccumulatorTuple(new_nt)
end

# END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS
Would you consider splitting the file up into two to make it clear where the interface ends and where implementations of them begin?
I've long thought that each context should be in its own file, just never got round to actually separating them. I feel we could start accumulators off on a good footing by doing this.
Oh yes, this was on my mind at some point but slipped. Would you keep `AccumulatorTuple` in the same file as `AbstractAccumulator`? I would, but I'm not committed to it.
# Fields
$(TYPEDFIELDS)
"""
struct LogPrior{T} <: AbstractAccumulator
julia> using DynamicPPL; DynamicPPL.LogPrior("wat")
LogPrior("wat")
Maybe a type bound would be useful to prevent me from doing silly things? 😄
Suggested change: `struct LogPrior{T} <: AbstractAccumulator` → `struct LogPrior{T<:Real} <: AbstractAccumulator`
(Same for the remainder of the methods in this file)
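To illustrate the suggested bound, two hypothetical stand-ins (not DynamicPPL's `LogPrior`) show that `{T<:Real}` rejects a `String` at construction time while still accepting floats:

```julia
# Hypothetical stand-ins contrasting an unconstrained and a bounded type parameter.
struct LooseLogPrior{T}
    logp::T
end

struct BoundedLogPrior{T<:Real}
    logp::T
end

LooseLogPrior("wat")      # accepted: LooseLogPrior{String}("wat")
BoundedLogPrior(-1.0)     # fine: BoundedLogPrior{Float64}(-1.0)
# BoundedLogPrior("wat")  # throws a MethodError
```

Dual numbers from ForwardDiff.jl are subtypes of `Real`, so such a bound would not by itself exclude forward-mode AD types.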
split(::LogLikelihood{T}) where {T} = LogLikelihood(zero(T))
split(acc::NumProduce) = acc

combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp)
Suggested change: `combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp)` → `combine(acc::LogPrior{T}, acc2::LogPrior{T}) where {T} = LogPrior(acc.logp + acc2.logp)`
I think it only makes sense to combine two `LogPrior`s of the same type (.....?). If so, then this would make for a better error message. If not (i.e., if you think we need to add `LogPrior`s of different types), then feel free to ignore.
Unsure about this one. I wonder if there might be cases where e.g. on one thread some type conversion happens and you would rather either have it just work (if the two types can be added) or error with a message that comes from the underlying types. It does seem unlikely to run into this though.
Not what you asked but related: for `+`, I would definitely like to support mixed sums to the same degree that the underlying types support mixed sums. It is also needed, for instance, for adding scalars to AD tracker types.
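The trade-off between the two `combine` signatures can be sketched with a hypothetical `SCLogPrior` type, using the names `split_acc` and `combine_acc` to avoid clashing with `Base.split`. `split_acc` hands each thread a zero accumulator, as in the `split` methods quoted above; the untyped `combine_acc` promotes mixed element types through `+`, whereas the strict version rejects them.

```julia
struct SCLogPrior{T<:Real}
    logp::T
end

# Each thread starts from a zero accumulator of the same element type.
split_acc(::SCLogPrior{T}) where {T} = SCLogPrior(zero(T))

# Untyped combine: defers to `+`, so mixed element types promote as the numbers do.
combine_acc(a::SCLogPrior, b::SCLogPrior) = SCLogPrior(a.logp + b.logp)

# Same-type-only combine: mismatched element types give a MethodError at this level.
combine_acc_strict(a::SCLogPrior{T}, b::SCLogPrior{T}) where {T} = SCLogPrior(a.logp + b.logp)

acc0 = SCLogPrior(1.0)
chunks = [split_acc(acc0) for _ in 1:2]   # one fresh accumulator per (simulated) thread
chunks = [SCLogPrior(c.logp + x) for (c, x) in zip(chunks, (0.5, 0.25))]
total = foldl(combine_acc, chunks; init = acc0)   # SCLogPrior{Float64}(1.75)

combine_acc(SCLogPrior(1.0f0), SCLogPrior(2.0))   # SCLogPrior{Float64}(3.0)
```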
Thanks both for the comments.
This isn't fully reflected in what's implemented in this PR, but I would like to get to a point where the meaning of If I actually manage to implement that, would you then be in favour of the renaming? I say "if", because there may be complications when interfacing with samplers that I don't see now, regarding what needs to happen before/after accumulation.
This is starting to take shape. It's too early for a review: Everything is undocumented, uncleaned, and some things are still broken. The base design is there though, and most tests pass (pointwiseloglikelihood and doctests being the exceptions), so @penelopeysm, @torfjelde, if you want to have an early look at where this is going, feel free. The most interesting files are accumulators.jl, abstract_varinfo.jl, and context_implementations.jl.
In addition to obvious things that still need doing (documentation, clean-up, new tests, adding deprecations, fixing `pointwiseloglikelihood`), a few things I have on my mind:
- `getacc` and similar functions should take the type of the accumulator as the index, or rather the symbol returned by `accumulator_name`. Leaning towards the latter, but the former is what's currently implemented.
- `DefaultContext` to `AccumulationContext`. Or something else? I'm not fixated on the term "accumulator".
- `(tilde_)assume` and `(tilde_)observe` has changed (they no longer return `logp`), the whole stack of calls within `tilde_obssume!!` should be revisited. In particular, I'm thinking of splitting anything sampling-related to a call of `tilde_obssume` with `SamplingContext`, that then at the end calls `tilde_obssume` with `DefaultContext`. This might be a separate PR though.
- `metadata.order` be an accumulator as well. Probably needs to actually be in the same accumulator with `NumProduce`, since the two go together. Probably a separate PR though.