Accumulators, stage 1 #885

Open

wants to merge 53 commits into base: breaking


Member

@mhauru mhauru commented Apr 10, 2025

This is starting to take shape. It's too early for a review: Everything is undocumented, uncleaned, and some things are still broken. The base design is there though, and most tests pass (pointwiseloglikelihood and doctests being the exceptions), so @penelopeysm, @torfjelde, if you want to have an early look at where this is going, feel free. The most interesting files are accumulators.jl, abstract_varinfo.jl, and context_implementations.jl.

In addition to obvious things that still need doing (documentation, clean-up, new tests, adding deprecations, fixing pointwiseloglikelihood), a few things I have on my mind:

  • Need to decide whether to keep the LogLikelihood and LogPrior accumulators immutable like they are now.
  • Whether getacc and similar functions should take the type of the accumulator as the index, or rather the symbol returned by accumulator_name. I'm leaning towards the latter, but the former is what's currently implemented.
  • Maybe rename DefaultContext to AccumulationContext. Or something else? I'm not fixated on the term "accumulator".
  • Since the signatures of (tilde_)assume and (tilde_)observe have changed (they no longer return logp), the whole stack of calls within tilde_obssume!! should be revisited. In particular, I'm thinking of splitting anything sampling-related into a call of tilde_obssume with SamplingContext, which then at the end calls tilde_obssume with DefaultContext. This might be a separate PR though.
  • Benchmark
  • There are a few places where we are now unnecessarily accumulating all of log prior, log likelihood, and num produce. I should clean these up to benefit from being able to do one but not the others.
  • Make metadata.order be an accumulator as well. Probably needs to actually be in the same accumulator with NumProduce, since the two go together. Probably a separate PR though.
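Two of the bullets above are about the shape of the interface: whether the built-in accumulators stay immutable, and whether lookups are keyed by the accumulator type or by the symbol from accumulator_name. A minimal, self-contained sketch of the second option, assuming a hypothetical `ToyVarInfo` that stores a NamedTuple of immutable accumulators keyed by name; `ToyVarInfo`, `LogPriorAcc`, `LogLikelihoodAcc` and `acc_add` are illustrative stand-ins, not DynamicPPL's API.

```julia
# Hypothetical, self-contained sketch of the accumulator pattern discussed in this PR.
# None of these names are DynamicPPL's actual API.

abstract type AbstractAccumulator end

# Immutable accumulators: "updating" one returns a new object.
struct LogPriorAcc{T<:Real} <: AbstractAccumulator
    logp::T
end
struct LogLikelihoodAcc{T<:Real} <: AbstractAccumulator
    logp::T
end

accumulator_name(::Type{<:LogPriorAcc}) = :LogPrior
accumulator_name(::Type{<:LogLikelihoodAcc}) = :LogLikelihood
accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc))

acc_add(acc::LogPriorAcc, x) = LogPriorAcc(acc.logp + x)
acc_add(acc::LogLikelihoodAcc, x) = LogLikelihoodAcc(acc.logp + x)

# Toy stand-in for a VarInfo that holds only accumulators, keyed by name.
struct ToyVarInfo{A<:NamedTuple}
    accs::A
end
ToyVarInfo(accs::Tuple) = ToyVarInfo(NamedTuple{map(accumulator_name, accs)}(accs))

# The two primitives a VarInfo type would implement...
getaccs(vi::ToyVarInfo) = vi.accs
setaccs!!(vi::ToyVarInfo, accs::NamedTuple) = ToyVarInfo(accs)

# ...and generic functions built on them, indexed by Val(name) so that the
# field lookup can be resolved at compile time.
hasacc(vi, ::Val{name}) where {name} = haskey(getaccs(vi), name)
getacc(vi, ::Val{name}) where {name} = getaccs(vi)[name]
function map_accumulator!!(f, vi, ::Val{name}) where {name}
    accs = getaccs(vi)
    return setaccs!!(vi, merge(accs, NamedTuple{(name,)}((f(accs[name]),))))
end

# Only accumulate into the likelihood if that accumulator is present.
function accloglikelihood!!(vi, x)
    hasacc(vi, Val(:LogLikelihood)) || return vi
    return map_accumulator!!(acc -> acc_add(acc, x), vi, Val(:LogLikelihood))
end

vi = ToyVarInfo((LogPriorAcc(0.0), LogLikelihoodAcc(0.0)))
vi = accloglikelihood!!(vi, -1.5)
getacc(vi, Val(:LogLikelihood)).logp   # -1.5

# A prior-only VarInfo: likelihood terms are skipped rather than accumulated.
prior_only = setaccs!!(vi, (LogPrior = LogPriorAcc(0.0),))
accloglikelihood!!(prior_only, -1.5) === prior_only   # true
```

In this sketch only `getaccs` and `setaccs!!` are type-specific; everything else is generic, which is one way to keep the per-VarInfo interface small.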

penelopeysm and others added 9 commits March 5, 2025 10:34
* AbstractPPL 0.11; change prefixing behaviour

* Use DynamicPPL.prefix rather than overloading
* Unify {Untyped,Typed}{Vector,}VarInfo constructors

* Update invocations

* NTVarInfo

* Fix tests

* More fixes

* Fixes

* Fixes

* Fixes

* Use lowercase functions, don't deprecate VarInfo

* Rewrite VarInfo docstring

* Fix methods

* Fix methods (really)
Contributor

github-actions bot commented Apr 10, 2025

Benchmark Report for Commit 0b08237

Computer Information

Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                 69.8 |                 1.2 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                921.6 |                30.2 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                483.7 |                44.2 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1361.5 |                26.5 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               7071.4 |                22.2 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1639.6 |                25.3 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1082.3 |                 5.8 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               6031.8 |                 4.2 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1110.4 |                 9.4 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              66543.4 |                 3.6 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               9018.3 |                 9.8 |
|               Dynamic |        10 |    mooncake |             typed |   true |                205.3 |                21.5 |
|              Submodel |         1 |    mooncake |             typed |   true |                 73.1 |                13.6 |
|                   LDA |        12 | reversediff |             typed |   true |               1283.9 |                 1.8 |

Member

@penelopeysm penelopeysm left a comment

Not reviewing actual code, just one high-level thought that struck me.

Comment on lines 121 to 125
function setlogp!!(vi::AbstractVarInfo, logp)
    vi = setlogprior!!(vi, zero(logp))
    vi = setloglikelihood!!(vi, logp)
    return vi
end
Member

@penelopeysm penelopeysm Apr 10, 2025

I was thinking about this the other day and thought I may as well post it now. The ...logp() family of functions is no longer well-defined in a world where everything is cleanly split into prior and likelihood (only getlogp and resetlogp still make sense). I think last time we chatted about it the decision was to maybe forward the others to the likelihood methods, but I was wondering if it's actually safer to remove them (or make them error informatively) and force people to use likelihood or prior, as otherwise it risks introducing subtle bugs. Backward compatibility is important, but if it comes at the cost of correctness I feel kinda uneasy.

Member Author

My hope was that we could deprecate them but provide the same functionality through the new functions, like above. It's a good question as to whether there are edge cases where they do not provide the same functionality. I think this is helped by the fact that PriorContext and LikelihoodContext won't exist, and hence one can't be running code where the expectation would be that ...logp() would be referring to logprior or loglikelihood in particular. And I think as long as one expects to get the logjoint out of ...logp(), we can do things like above, shoving things into the likelihood, and get the same results. Do you think that solves it and lets us use deprecations rather than straight-up removals, or do you see other edge cases?

Member

@penelopeysm penelopeysm Apr 11, 2025

Something like this is a case where setlogp is ill-defined:

lp = getlogp(vi_typed_metadata)
varinfos = map((
    vi_untyped_metadata,
    vi_untyped_vnv,
    vi_typed_metadata,
    vi_typed_vnv,
    svi_typed,
    svi_untyped,
    svi_vnv,
    svi_typed_ref,
    svi_untyped_ref,
    svi_vnv_ref,
)) do vi
    # Set them all to the same values.
    DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end

The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.

Of course, we can fix this on our end - we would get and set logprior and loglikelihood manually, and we can grep the codebase to make sure that there are no other ill-defined calls to setlogp. We can't guarantee that other people will be similarly careful, though (and us or anyone being careful also doesn't guarantee that everything will be fixed correctly).
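The inconsistency described above can be reproduced with a minimal toy model of the two components. This assumes a hypothetical `LogDensities` type (not DynamicPPL's VarInfo) and the single-number `setlogp!!` semantics from the excerpt under discussion: the joint survives the round trip, but the prior/likelihood split does not.

```julia
# Hypothetical toy type holding only the two log-density components;
# it is not DynamicPPL's VarInfo.
struct LogDensities
    logprior::Float64
    loglikelihood::Float64
end

getlogprior(vi::LogDensities) = vi.logprior
getloglikelihood(vi::LogDensities) = vi.loglikelihood
getlogp(vi::LogDensities) = getlogprior(vi) + getloglikelihood(vi)

# Single-number setter with the semantics of the excerpt in this thread:
# zero the prior and put the whole joint into the likelihood.
setlogp!!(vi::LogDensities, logp::Real) = LogDensities(zero(logp), logp)

vi = LogDensities(-2.0, -3.0)
vi2 = setlogp!!(vi, getlogp(vi))

getlogp(vi2) == getlogp(vi)   # true: the joint is preserved
getlogprior(vi2)              # 0.0, although the prior term was -2.0
getloglikelihood(vi2)         # -5.0, now including the prior term
```

Code that only ever reads getlogp cannot tell the difference; code that later reads getlogprior or getloglikelihood can.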

Member

@penelopeysm penelopeysm Apr 11, 2025

While looking for other uses of setlogp, I encountered this:

https://github.com/TuringLang/Turing.jl/blob/fc32e10bc17ae3fda4d7e825b6fde45dc7bdb179/src/mcmc/hmc.jl#L201-L234

AdvancedHMC.Transition only contains a single notion of log density, so it's not obvious to me how we're going to extract the prior and likelihood components from it 😓 This might require upstream changes to AdvancedHMC. Since the contexts will be removed, I suspect LogDensityFunction also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).

(For the record, I'd be quite happy with making all of these changes!)

Member Author

> The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.

It is inconsistent, but as long as the user only uses getlogp, they would never see the difference, right? If some of logprior is accidentally stored in loglikelihood or vice versa, as long as one is using getlogp and DefaultContext that should be undetectable. What would be trouble is if someone mixes using e.g. setlogp!! and getlogprior, which would require adding calls to getlogprior after upgrading to a version that has deprecated setlogp!!, but probably people would end up doing that. Maybe the deprecation warning could say something about this?

> Since the contexts will be removed, I suspect LogDensityFunction also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).

Yeah, this sort of stuff will come up (and is coming up) in multiple places. Anything that explicitly uses PriorContext or LikelihoodContext would need to be changed to use LogPrior and LogLikelihood accumulators instead. I'm currently doing this for pointwiselogdensities.

Member Author

mhauru commented Apr 11, 2025

pointwise_logdensities now works and uses its own accumulator type rather than a context. This leaves only a handful of tests failing, for quite trivial reasons. Plenty of clean-up to do though: In fixing pointwise_logprobability I had to make substantial changes to the tilde_observe pipeline, because accumulate_observe needed to get the varname as an argument and thus had to be moved higher in the call stack. I'll have to see how to best reorganise tilde_observe in such a way that making ParticleGibbs work with it wouldn't be horrible.
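A minimal sketch of why the variable name has to reach the accumulation step: a hypothetical pointwise log-likelihood accumulator that stores each observation's log density under its name. The type and function names, the use of `String` in place of `VarName`, and the hand-written standard normal log density (to avoid a Distributions.jl dependency) are all assumptions for illustration.

```julia
# Hypothetical sketch: an accumulator recording the log density of each observation
# under its variable name. A String stands in for a VarName; names are illustrative.
struct PointwiseLogLikelihoodAcc
    logps::Dict{String,Float64}
end
PointwiseLogLikelihoodAcc() = PointwiseLogLikelihoodAcc(Dict{String,Float64}())

# The accumulation step needs `vn`, so the varname must be available wherever this
# is called from in the tilde pipeline. Mutates `acc` in place and returns it.
function accumulate_observe!!(acc::PointwiseLogLikelihoodAcc, logp::Real, vn::String)
    acc.logps[vn] = logp
    return acc
end

# Standard normal log density, written out to avoid a Distributions.jl dependency.
normal_logpdf(x) = -0.5 * x^2 - 0.5 * log(2π)

acc = PointwiseLogLikelihoodAcc()
for (i, y) in enumerate([0.0, 1.0])
    accumulate_observe!!(acc, normal_logpdf(y), "y[$i]")
end
sort(collect(keys(acc.logps)))   # ["y[1]", "y[2]"]
```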

Comment on lines +129 to +133
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
return x, vi
Member

How do we deal with tempering of logpdf and such now that it happens in the leaf of the call stack?

Member

In the past, we would do this by altering the logpdf coming from the assume higher up in the call tree.

Member Author

What are the needs of tempering? What does it need to alter?
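If what tempering needs is to rescale likelihood terms by an inverse temperature β (so that each log p(y | θ) contributes β · log p(y | θ)), one option would be a custom accumulator rather than altering the value returned further up the call stack. A minimal sketch, assuming a hypothetical `TemperedLogLikelihoodAcc` type and `accumulate_observe` hook; whether this covers every tempering use case is the open question in this thread.

```julia
# Hypothetical sketch: tempering the likelihood with a custom accumulator.
# `β` is an inverse temperature; all names are illustrative only.
struct TemperedLogLikelihoodAcc
    β::Float64
    logp::Float64
end
TemperedLogLikelihoodAcc(β::Real) = TemperedLogLikelihoodAcc(β, 0.0)

# Observation terms are scaled by β before being added; prior (assume) terms
# would simply not be routed to this accumulator.
accumulate_observe(acc::TemperedLogLikelihoodAcc, logp::Real) =
    TemperedLogLikelihoodAcc(acc.β, acc.logp + acc.β * logp)

terms = [-1.0, -2.0, -3.0]   # per-observation log-likelihood values
acc = foldl(accumulate_observe, terms; init = TemperedLogLikelihoodAcc(0.5))
acc.logp   # -3.0, i.e. 0.5 * (-6.0)
```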

@mhauru mhauru changed the base branch from main to breaking April 24, 2025 16:50
@coveralls

Pull Request Test Coverage Report for Build 14646630936

Details

  • 306 of 451 (67.85%) changed or added relevant lines in 18 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall first build on mhauru/custom-accumulators at 55.812%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| ext/DynamicPPLMCMCChainsExt.jl | 0 | 1 | 0.0% |
| src/test_utils/contexts.jl | 0 | 1 | 0.0% |
| src/test_utils/models.jl | 14 | 15 | 93.33% |
| src/transforming.jl | 18 | 19 | 94.74% |
| src/values_as_in_model.jl | 2 | 4 | 50.0% |
| src/context_implementations.jl | 24 | 27 | 88.89% |
| src/logdensityfunction.jl | 7 | 10 | 70.0% |
| src/pointwise_logdensities.jl | 39 | 44 | 88.64% |
| src/debug_utils.jl | 11 | 18 | 61.11% |
| src/accumulators.jl | 64 | 72 | 88.89% |

| Totals | Coverage Status |
|---|---|
| Change from base Build 14646769416: | 55.8% |
| Covered Lines: | 2228 |
| Relevant Lines: | 3992 |

💛 - Coveralls

@@ -1,11 +0,0 @@
@testset "Turing independence" begin
Member Author

This whole test seemed very redundant given the rest of the test suite.

docs/src/api.md Outdated
getlogp
setlogp!!
acclogp!!
getlogprior
Member Author

I've implemented the NamedTuple versions of get/set/acclogp, made setlogp!!(vi, x::Number) error, and made acclogp!!(vi, x::Number) fall back on accloglikelihood!! with a deprecation warning.
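A minimal sketch of that behaviour, assuming a hypothetical `ToyVI` type with plain `logprior`/`loglikelihood` fields rather than DynamicPPL's VarInfo and accumulators: the NamedTuple methods route each component to its own term, the single-number setter errors, and the single-number accumulator falls back to the likelihood via `Base.depwarn`. The field-name check mirrors the explicit list of allowed combinations used in the PR.

```julia
# Hypothetical toy VarInfo with separate prior and likelihood terms (not DynamicPPL's type).
struct ToyVI
    logprior::Float64
    loglikelihood::Float64
end

# The allowed field combinations, listed explicitly.
const LOGP_FIELDS = (
    (:logprior, :loglikelihood), (:loglikelihood, :logprior), (:logprior,), (:loglikelihood,)
)

function check_logp_fields(::NamedTuple{names}) where {names}
    names in LOGP_FIELDS ||
        error("logp must have fields logprior and/or loglikelihood and no other fields.")
    return nothing
end

# NamedTuple versions: each component goes to its own term; absent ones are left alone.
function setlogp!!(vi::ToyVI, logp::NamedTuple)
    check_logp_fields(logp)
    return ToyVI(get(logp, :logprior, vi.logprior), get(logp, :loglikelihood, vi.loglikelihood))
end

function acclogp!!(vi::ToyVI, logp::NamedTuple)
    check_logp_fields(logp)
    return ToyVI(
        vi.logprior + get(logp, :logprior, 0.0),
        vi.loglikelihood + get(logp, :loglikelihood, 0.0),
    )
end

# A bare number cannot be split between prior and likelihood, so setting errors...
function setlogp!!(::ToyVI, ::Number)
    return error("setlogp!!(vi, ::Number) is not supported; pass a NamedTuple " *
                 "with fields logprior and/or loglikelihood.")
end

# ...while accumulating falls back to the likelihood, with a deprecation warning.
function acclogp!!(vi::ToyVI, logp::Number)
    Base.depwarn("acclogp!!(vi, ::Number) is deprecated; use accloglikelihood!! instead.", :acclogp!!)
    return acclogp!!(vi, (loglikelihood = logp,))
end
```

Note that `Base.depwarn` only prints when deprecation warnings are enabled, for example under `Pkg.test` or with `julia --depwarn=yes`.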

Member Author

mhauru commented Apr 25, 2025

This is ready for review. Apologies for the huge line count. I did try to not make edits that weren't core to the new features or required to make tests pass.

Open questions:

  • The deprecations/changes to/removals of addlogprob!, setlogp!!, acclogp!!, and getlogp discussed in two threads above.
  • Is "accumulator" a good name? In the issue @yebai wrote: "InferenceState / InferenceAccumulator might be more precise and readable than Accumulators. However, Accumulator is also a lovely name." State is in some sense more accurate, because accumulators are free to change in any way they want at every ~ statement. I worry a bit that "InferenceState" could be interpreted to mean many other things too. Ideas/thoughts welcome.
  • Should DefaultContext be renamed, to AccumulationContext or something else?

Things I would like to get to, but not in this PR:

  • Replace more contexts with accumulators.
  • Revisit the whole tilde_observe!!/tilde_assume!! stack. I made some edits here that were necessary to get pointwise_logdensities implemented with accumulators, but no more. This should probably be done while simultaneously working on Turing.jl compatibility.
  • Move NumProduce to Turing.jl's AdvancedPS wrapper.
  • See how close this brings SimpleVarInfo to having feature parity with TypedVarInfo, and what simplifications open up there.

@mhauru mhauru marked this pull request as ready for review April 25, 2025 16:59
@mhauru mhauru requested a review from penelopeysm April 25, 2025 16:59
Member Author

mhauru commented Apr 25, 2025

I don't know why the benchmarks are failing; they work for me locally. I'll look into it next week.

Results on my laptop, this branch:

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                 45.8 |                 1.2 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                631.4 |                39.1 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                274.7 |                67.9 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1052.8 |                29.7 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               5543.6 |                17.7 |
|           Smorgasbord |       201 | reversediff |             typed |   true |                966.0 |                23.1 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1037.1 |                 5.3 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5662.0 |                 5.0 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                779.8 |                 8.4 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              58560.3 |                 5.0 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               6468.6 |                 9.1 |
|               Dynamic |        10 |    mooncake |             typed |   true |                216.2 |                 9.3 |
|              Submodel |         1 |    mooncake |             typed |   true |                 63.2 |                12.8 |
|                   LDA |        12 | reversediff |             typed |   true |                937.6 |                 1.5 |

Current main:

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                 12.6 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                737.1 |                36.2 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                244.6 |                74.2 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1161.7 |                30.6 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               3284.7 |                22.4 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1108.1 |                22.0 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1157.0 |                 3.6 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5365.1 |                 5.0 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                749.8 |                 7.9 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              56613.4 |                 5.0 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               6427.4 |                 8.8 |
|               Dynamic |        10 |    mooncake |             typed |   true |                145.2 |                 7.7 |
|              Submodel |         1 |    mooncake |             typed |   true |                 17.4 |                 7.8 |
|                   LDA |        12 | reversediff |             typed |   true |                460.9 |                 2.8 |

I'll look into this a bit more, see what might be happening with LDA and if something could be done about the very smallest models taking a substantial hit. I also want to try making LogPrior, LogLikelihood, and NumProduce mutable and see what that does.

Member

yebai commented Apr 25, 2025

> Should DefaultContext be renamed, to AccumulationContext or something else?

Yes, I never manage to remember what DefaultContext does.

Comment on lines +359 to +361
LogPrior
LogLikelihood
NumProduce
Member

For clarity

Suggested change
LogPrior
LogLikelihood
NumProduce
LogPriorAccumulator
LogLikelihoodAccumulator
NumProduceAccumulator

Member Author

Yes, good idea. I had this in mind as well but then forgot about it. I may hold off on implementing it until a final decision has been made on the term "accumulator".

@@ -427,9 +450,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`.
```@docs
SamplingContext
DefaultContext
Member

For clarity

Suggested change
DefaultContext
AccumulatorContext

- `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPrior(),))`.
- `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LogLikelihood` accumulator. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself.
- `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any context that needs to modify observation behaviour should overload it. We may further rework the call stack under `tilde_observe!!` in the near future.
- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.
Member

Any reason why we can't remove tilde_assume like tilde_observe and observe?

Member Author

It's on my list to do, but I didn't want to put it in the same PR since I didn't have to. (I had to do tilde_observe because of complications with PointwiseLogdensityAccumulator.) Or, more precisely, what's on my list is to revisit the whole call stack for both tilde_assume!! and tilde_observe!! and see what the best way to do things is.

Member

penelopeysm commented Apr 27, 2025

> The deprecations/changes to/removals of addlogprob!, setlogp!!, acclogp!!, and getlogp discussed in two threads above.

I'm mostly happy with what's been done now, but I am strongly in favour of going further. I think acclogp!!(::Float64) should be removed entirely and @addlogprob! should at least be deprecated (if not removed).

Markus and I discussed this on a call last week, but I'll explain my rationale again for the benefit of everybody reading.

  1. I consider correctness to be, by far, the most important part of the codebase. Part of how we ensure correctness is by codifying the knowledge that logp consists of prior and likelihood components, which is what this very PR seeks to do by removing the logp field and replacing it with separate accumulators. While the NamedTuple versions of getlogp and setlogp at least adhere to this notion that there are two different components, anything that takes a single Float doesn't. I therefore find it, at least from a philosophical perspective, inconsistent to retain the old interface - it is a halfway, noncommittal solution.

  2. The cost of adjusting to the new patch, for the vast majority of people, is a trivial substitution of @addlogprob! with @addloglikelihood! everywhere. Any cases where this is not a trivial fix would have involved something weird, like the PriorContext check, and is already broken by this PR.

Basically, I don't understand why we are committing to support something that we don't really support, it just means that the next time we visit the acclogp!!(::Float64) line of code we'll have to debate its semantics all over again, or worse still track down some bug that relates to it.

> Is "accumulator" a good name?

I don't really mind either way. I think I have mentioned before that formally these are monoidal structures which would be the most precise term, but also rather impenetrable to someone who doesn't know it, so I'm quite happy to go with either Accumulator or State.

> Should DefaultContext be renamed, to AccumulationContext or something else?

What is the motivation for changing DefaultContext?

I find AccumulationContext to be a misleading name:

  1. Calling it AccumulationContext implies that other contexts don't have accumulators, which is false.
  2. Accumulators and contexts are supposed to be entirely orthogonal concepts, and AccumulatorContext conflates the two. To use some slight hyperbole, you probably wouldn't be in favour of calling it MetadataContext, because the context has nothing to do with the Metadata. The same principle applies to accumulators, IMO.

If it were called JointContext, and we wanted to rename it away from that given that LikelihoodContext and PriorContext were removed, then I'd understand more - but I do think DefaultContext is quite a sensible name for a context that doesn't do anything special.

- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.
- For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`.
- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
- `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
Member

Adding to the log likelihood is a sensible default. For advanced use by programmable inference, let's also support

@addlogprob! (logprior=0., loglikelihood=0.)

Member

yebai commented Apr 27, 2025

> @addlogprob! should at least be deprecated (if not removed).

It doesn't help to remove @addlogprob!. It doesn't cause any confusion or correctness issue from a statistical perspective. I suggest we keep it, but add support for a NamedTuple of log prior and log likelihoods.

> I find AccumulationContext to be a misleading name:

I think I see why @penelopeysm wants to avoid AccumulationContext and tend to agree with her. However, the name DefaultContext is mysterious, and as I explained above, I have never been able to remember what DefaultContext means. Of course, we can add clear documentation or rename it to accurately reflect its current functionality.

Member

@penelopeysm penelopeysm left a comment

I wouldn't even call this 'request changes', just a bunch of thoughts. I haven't reviewed all the code yet, only the interface bits (abstract_varinfo.jl and accumulators.jl - helpfully they're first alphabetically!), but I didn't want to put like 50 comments at once. Happy to continue once we've worked some of these out.

Comment on lines 59 to 65
macro addlogprob!(ex)
return quote
$(esc(:(__varinfo__))) = acclogp!!(
$(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex))
)
if $hasacc($(esc(:(__varinfo__))), Val(:LogLikelihood))
$(esc(:(__varinfo__))) = $accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex)))
end
end
end
Member

This appears to silently do nothing if the varinfo has no likelihood accumulator. One might ask what it should do instead. The problem is that it's simply not clear what addlogprob should do when it can't find a likelihood accumulator. Should it attempt to add to the prior instead? Should it silently do nothing? Should it error? Personally, I would argue that given that the expected behaviour is not clear, it should error.

It's precisely because of cases like these that I think these logp functions should just be removed -- they simply don't have well-defined behaviour.

But if we insist on having them, we should at least make it error when there's no loglikelihood. I guess the easiest way is to remove the hasacc check.

Member Author

The definition of addlogprob! would, after this PR, be that it adds to the log likelihood. (It's in the docstring too.) So it should either silently do nothing or error, but not e.g. add to the log prior.

The reason I think they should silently do nothing is that

  • it's in line with what tilde_observe!! statements do. They should add to the log likelihood, but if there's no accumulator for it they simply skip that step.
  • erroring would make it impossible to execute a model with @addlogprob! calls in it without accumulating log likelihood (which would then mean executing logpdf calls at every tilde_observe!!). This goes against the "choose what you want to accumulate" idea that having the separate accumulators (and PriorContext and LikelihoodContext) is largely about.
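The two candidate behaviours can be compared on a toy NamedTuple of accumulated values. This is a minimal sketch with hypothetical helper names (`add_loglikelihood_skip`, `add_loglikelihood_strict`), not the @addlogprob! macro itself: skipping keeps prior-only evaluation possible, while erroring makes a missing likelihood accumulator loud.

```julia
# Hypothetical sketch of the two candidate policies for adding to the log likelihood,
# acting on a NamedTuple of accumulated values keyed by accumulator name.

# Policy 1: skip silently when there is no LogLikelihood accumulator,
# mirroring how observe statements behave.
function add_loglikelihood_skip(accs::NamedTuple, x::Real)
    haskey(accs, :LogLikelihood) || return accs
    return merge(accs, (LogLikelihood = accs.LogLikelihood + x,))
end

# Policy 2: error when there is no LogLikelihood accumulator.
function add_loglikelihood_strict(accs::NamedTuple, x::Real)
    haskey(accs, :LogLikelihood) || error("no LogLikelihood accumulator to add to")
    return merge(accs, (LogLikelihood = accs.LogLikelihood + x,))
end

joint = (LogPrior = 0.0, LogLikelihood = 0.0)
prior_only = (LogPrior = 0.0,)

add_loglikelihood_skip(joint, -1.0)        # (LogPrior = 0.0, LogLikelihood = -1.0)
add_loglikelihood_skip(prior_only, -1.0)   # (LogPrior = 0.0,): prior-only evaluation still runs
# add_loglikelihood_strict(prior_only, -1.0) would throw an ErrorException.
```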

Comment on lines 59 to 60
macro addlogprob!(ex)
return quote
Member

I think at the very least we should provide @addloglikelihood! and @addlogprior! macros as well

Comment on lines +156 to +158
# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo
# has to implement. Not sure what the full list is for parameters values, but for
# accumulators we only need `getaccs` and `setaccs!!`.
Member

I put a list of AbstractVarInfo interface methods together the other day, feel free to ask me for it on Slack or something. We'd have to add the accumulator bits in, of course.

Member Author

Could be a good contribution to #899

Return a boolean for whether `vi` has an accumulator with name `accname`.
"""
hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname)
function hassacc(vi::AbstractVarInfo, accname::Symbol)
Member

Suggested change
function hassacc(vi::AbstractVarInfo, accname::Symbol)
function hasacc(vi::AbstractVarInfo, accname::Symbol)

Comment on lines +288 to +296
function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)
    return error(
        """
        The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)
        does not exist. For type stability reasons use
        map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead.
        """
    )
end
Member

Are these mostly reminders for ourselves, or is it meant to be user-facing?

If it's the former, I'd opine that we don't need these 'reminder methods' and we could just add a note to the docstring, because if you get a MethodError it'd be fairly easy to bring up the docstring and go 'ah, yes, I know what I need to do'.

Member Author

These are meant to be user-facing to the same degree as map_accumulator!! is user-facing (which is not really user-facing, but e.g. inference-algorithm-developer-facing). The reason I put these in is that morally map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) is exactly the right thing to call, and it would be weird to get a MethodError when doing the right thing. Having to wrap the accname argument in a Val is an annoying technicality, needed only for type stability, and no one should waste brain cycles trying to understand its deeper meaning. I myself made the mistake of calling these with a raw Symbol multiple times while developing this code, and would have appreciated this reminder when trying to understand why my tests failed.
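Here is a self-contained sketch of the pattern defended here. It again uses a plain NamedTuple rather than a real AbstractVarInfo, which is an assumption. The Val method does the work. The Symbol method exists only to replace a bare MethodError with an error message that explains the fix.

```julia
# Apply `func` to the accumulator called `name` and return the updated NamedTuple.
function map_accumulator!!(func::Function, accs::NamedTuple, ::Val{name}) where {name}
    return merge(accs, NamedTuple{(name,)}((func(accs[name]),)))
end

# Morally the right call, but type-unstable; explain the fix instead of a MethodError.
function map_accumulator!!(::Function, ::NamedTuple, accname::Symbol)
    return error(
        "Wrap the accumulator name in a Val for type stability: " *
        "map_accumulator!!(func, accs, Val(:$accname))",
    )
end

accs = (logprior=1.0, loglikelihood=2.0)
map_accumulator!!(x -> x + 1, accs, Val(:logprior))
```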

Comment on lines +328 to +334
if !(
names == (:logprior, :loglikelihood) ||
names == (:loglikelihood, :logprior) ||
names == (:logprior,) ||
names == (:loglikelihood,)
)
error("logp must have fields logprior and/or loglikelihood and no other fields.")
Member

We're lucky there are only two types of logp, eh...

Member Author

It's not the prettiest part of the codebase, but it's explicit and easy on the compiler.
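For reference, the explicit check above can be exercised in isolation, as in the sketch below. The field names are read from the NamedTuple's type parameter, so each comparison is between compile-time constants. This is presumably what keeps the check easy on the compiler. `check_logp_names` is a hypothetical wrapper, not a function from the PR.

```julia
# Accept a NamedTuple whose fields are exactly logprior and/or loglikelihood, in any order.
function check_logp_names(logp::NamedTuple{names}) where {names}
    if !(
        names == (:logprior, :loglikelihood) ||
        names == (:loglikelihood, :logprior) ||
        names == (:logprior,) ||
        names == (:loglikelihood,)
    )
        error("logp must have fields logprior and/or loglikelihood and no other fields.")
    end
    return nothing
end

check_logp_names((loglikelihood=-1.0, logprior=0.5))  # passes
```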


An accumulator is an object that may change its value at every tilde_assume!! or
tilde_observe!! call based on the random variable in question. The obvious examples of
accumulators are the log prior and log likelihood. Others examples might be a variable that
Member

Suggested change
accumulators are the log prior and log likelihood. Others examples might be a variable that
accumulators are the log prior and log likelihood. Other examples might be a variable that
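To make the definition concrete, here is a sketch of two accumulators updated at every observe statement. One accumulates the log likelihood and the other merely counts observations, loosely in the spirit of the NumProduce accumulator. `Acc`, `LogLik`, `ObsCount`, `observe` and `run_observations` are hypothetical names. Each observation is represented only by its precomputed log density, which keeps the example free of dependencies.

```julia
# Hypothetical stand-in for AbstractAccumulator.
abstract type Acc end

struct LogLik{T<:Real} <: Acc
    logp::T
end

struct ObsCount <: Acc
    n::Int
end

# Hypothetical per-observation hook; `lp` is the log density of the observed value.
observe(acc::LogLik, lp::Real) = LogLik(acc.logp + lp)
observe(acc::ObsCount, ::Real) = ObsCount(acc.n + 1)  # ignores the density entirely

# Feed every observation to every accumulator, as a model evaluation would.
function run_observations(accs::Tuple, lps)
    for lp in lps
        accs = map(acc -> observe(acc, lp), accs)
    end
    return accs
end

run_observations((LogLik(0.0), ObsCount(0)), [-1.0, -0.5, -2.0])
```

The driver loop never needs to know what each accumulator tracks; every accumulator decides for itself how to react to a given observation.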

return AccumulatorTuple(new_nt)
end

# END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS
Member

Would you consider splitting the file up into two to make it clear where the interface ends and where implementations of them begin?

I've long thought that each context should be in its own file, but I never got round to actually separating them. I feel we could start accumulators off on a good footing by doing this.

Member Author

Oh yes, this was on my mind at some point but slipped. Would you keep AccumulatorTuple in the same file as AbstractAccumulator? I would, but I'm not committed to it.

# Fields
$(TYPEDFIELDS)
"""
struct LogPrior{T} <: AbstractAccumulator
Member

julia> using DynamicPPL; DynamicPPL.LogPrior("wat")
LogPrior("wat")

Maybe a type bound would be useful to prevent me from doing silly things? 😄

Suggested change
struct LogPrior{T} <: AbstractAccumulator
struct LogPrior{T<:Real} <: AbstractAccumulator

(Same for the remainder of the methods in this file)
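The sketch below shows what the suggested bound buys, using a standalone copy of the struct. The abstract type is redefined here only to keep the block self-contained. Real-valued inputs, including integers, still work, while a String is rejected at construction time.

```julia
abstract type AbstractAccumulator end

# With the bound, T can only be a subtype of Real.
struct LogPrior{T<:Real} <: AbstractAccumulator
    logp::T
end

LogPrior(-1.5)   # a LogPrior{Float64}
LogPrior(0)      # a LogPrior{Int}; integers are still Real
# LogPrior("wat") now throws instead of silently constructing a LogPrior{String}
```

The bound should not get in the way of forward-mode AD, since e.g. ForwardDiff.Dual is a subtype of Real.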

split(::LogLikelihood{T}) where {T} = LogLikelihood(zero(T))
split(acc::NumProduce) = acc

combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp)
Member

Suggested change
combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp)
combine(acc::LogPrior{T}, acc2::LogPrior{T}) where {T} = LogPrior(acc.logp + acc2.logp)

I think it only makes sense to combine two LogPriors of the same type (.....?). If so, then this would make for a better error message. If not (i.e., if you think we need to add LogPriors of different types), then feel free to ignore.

Member Author

Unsure about this one. I wonder if there might be cases where e.g. some type conversion happens on one thread, and you would rather either have it just work (if the two types can be added) or error with a message that comes from the underlying types. It does seem unlikely that we would run into this, though.

Not what you asked, but related: for + I would definitely like to support mixed sums to the same degree that the underlying types support them. This is also needed, for instance, for adding scalars to AD tracker types.
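The following sketch shows how split and combine could fit together in a chunked, multithreaded evaluation. It assumes each task gets its own split accumulator and that the per-task results are folded back with combine. `LogLikelihood` is redefined locally, `split_acc` is used instead of `split` to avoid clashing with `Base.split`, and `chunked_loglik` is hypothetical. Because `combine` is left unrestricted, mixed element types work as far as `+` supports them, which is the behaviour argued for above.

```julia
struct LogLikelihood{T<:Real}
    logp::T
end

# A fresh, zero-valued accumulator of the same type, for use on another task.
split_acc(::LogLikelihood{T}) where {T} = LogLikelihood(zero(T))

# Unrestricted combine: any two LogLikelihoods whose values support `+`.
combine(a::LogLikelihood, b::LogLikelihood) = LogLikelihood(a.logp + b.logp)

# Each chunk of log densities is accumulated on its own task into a split
# accumulator; the per-task results are then combined into `acc`.
function chunked_loglik(acc::LogLikelihood, lps::AbstractVector{<:Real}; nchunks::Int=4)
    chunks = Iterators.partition(lps, cld(length(lps), nchunks))
    tasks = map(chunks) do chunk
        Threads.@spawn begin
            local_acc = split_acc(acc)
            for lp in chunk
                local_acc = LogLikelihood(local_acc.logp + lp)
            end
            local_acc
        end
    end
    return reduce(combine, fetch.(tasks); init=acc)
end

chunked_loglik(LogLikelihood(0.0), [-1.0, -0.5, -2.0, -0.25, -1.25])
combine(LogLikelihood(1.0f0), LogLikelihood(2.0))  # mixed Float32/Float64 promotes to Float64
```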

@penelopeysm penelopeysm self-requested a review April 27, 2025 21:47
@mhauru
Member Author

mhauru commented Apr 28, 2025

Thanks both for the comments.

@penelopeysm:

What is the motivation for changing DefaultContext?

I find AccumulationContext to be a misleading name:

  1. Calling it AccumulationContext implies that other contexts don't have accumulators, which is false.
  2. Accumulators and contexts are supposed to be entirely orthogonal concepts, and AccumulatorContext conflates the two. To use some slight hyperbole, you probably wouldn't be in favour of calling it MetadataContext, because the context has nothing to do with the Metadata. The same principle applies to accumulators, IMO.

This isn't fully reflected in what's implemented in this PR, but I would like to get to a point where the meaning of AccumulationContext would be "call accumulate_obssume!!". (I take the meaning of a context to be an answer to the question "what should tilde_obssume!! calls do?") I wonder whether we would then also need a NoopContext that does nothing, in case someone doesn't want to do any accumulation. As implemented in this PR, we have a few accumulate_obssume!! calls in various places, e.g. I think inside tilde_obssume!! for SamplingContext. I would like to get to a point where all of these are replaced by calls to tilde_obssume!! with AccumulationContext. So e.g. SamplingContext with PriorSampler or UniformSampler would sample the new values and then call tilde_obssume!! with AccumulationContext.

If I actually manage to implement that, would you then be in favour of the renaming? I say "if", because there may be complications when interfacing with samplers that I don't see now, regarding what needs to happen before/after accumulation.
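Here is a rough sketch of the proposed layering. It uses hypothetical stand-in types (`Ctx`, `AccumulationCtx`, `NoopCtx`, `SamplingCtx`) and a single running log density instead of real accumulators. The sampling context only draws a new value and then delegates to its child context. The child context alone decides whether anything is accumulated.

```julia
# Hypothetical stand-ins for AbstractContext and the proposed contexts.
abstract type Ctx end
struct AccumulationCtx <: Ctx end    # meaning: "run the accumulation step"
struct NoopCtx <: Ctx end            # meaning: "do nothing"
struct SamplingCtx{F,C<:Ctx} <: Ctx  # meaning: "draw a value, then defer to `child`"
    sampler::F
    child::C
end

# Hypothetical accumulation step: add the log density of `val` to `logp`.
accumulate_obssume(logp, val, logdensity) = logp + logdensity(val)

# Each method returns the (possibly new) value and the updated log density.
tilde_obssume(::AccumulationCtx, logp, val, logdensity) =
    (val, accumulate_obssume(logp, val, logdensity))
tilde_obssume(::NoopCtx, logp, val, logdensity) = (val, logp)
function tilde_obssume(ctx::SamplingCtx, logp, _oldval, logdensity)
    newval = ctx.sampler()  # sample first...
    return tilde_obssume(ctx.child, logp, newval, logdensity)  # ...then delegate
end

logdensity(x) = -x^2 / 2  # unnormalised standard normal
tilde_obssume(SamplingCtx(() -> 2.0, AccumulationCtx()), 0.0, 0.0, logdensity)
```

Under this layering, swapping the child for `NoopCtx()` would give sampling without any accumulation, without touching the sampling code.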


5 participants