diff --git a/HISTORY.md b/HISTORY.md index 9a70e8d1f..68650f9d1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,23 @@ # DynamicPPL Changelog +## 0.37.0 + +**Breaking changes** + +### Accumulators + +This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes: + + - `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`. + - `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself. + - `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future. + - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well. + - For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. 
Rather, we now call `tilde_observe!!` with `vn` set to `nothing`. + - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value. + - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`. + - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`. + - Correspondingly, `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.
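The API changes above can be sketched in one snippet. This is illustrative only: the function names are taken from this release's docs, while the `demo` model and its values are made up for the example, and the snippet assumes DynamicPPL 0.37 and Distributions are loaded.

```julia
using DynamicPPL, Distributions

@model function demo(x)
    μ ~ Normal()   # assume statement: contributes to the log prior
    x ~ Normal(μ)  # observe statement: contributes to the log likelihood
end

vi = VarInfo(demo(0.5))

# The two log-density components are now tracked separately:
getlogprior(vi)       # log prior only
getloglikelihood(vi)  # log likelihood only
getlogjoint(vi)       # their sum; replaces what `getlogp` used to return
getlogp(vi)           # (; logprior=..., loglikelihood=...)

# Setters and accumulators take NamedTuples with these keys:
vi = setlogp!!(vi, (; logprior=0.0, loglikelihood=0.0))
vi = acclogp!!(vi, (; loglikelihood=-1.2))

# To track only the log prior (the old `PriorContext` use case),
# keep a single accumulator on the VarInfo:
vi = setaccs!!(vi, (LogPriorAccumulator(),))
```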
+ ## 0.36.0 **Breaking changes** diff --git a/Project.toml b/Project.toml index 01e2cb612..25c6acd24 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -68,6 +69,7 @@ MCMCChains = "6" MacroTools = "0.5.6" Mooncake = "0.4.95" OrderedCollections = "1" +Printf = "1.10" Random = "1.6" Requires = "1" Statistics = "1" diff --git a/docs/src/api.md b/docs/src/api.md index 08522e2ce..e104193f2 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -160,10 +160,12 @@ returned(::Model) ## Utilities -It is possible to manually increase (or decrease) the accumulated log density from within a model function. +It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @addlogprob! +@addloglikelihood! +@addlogprior! ``` Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref). @@ -328,9 +330,9 @@ The following functions were used for sequential Monte Carlo methods. ```@docs get_num_produce -set_num_produce! -increment_num_produce! -reset_num_produce! +set_num_produce!! +increment_num_produce!! +reset_num_produce!! setorder! set_retained_vns_del! ``` @@ -345,6 +347,22 @@ Base.empty! SimpleVarInfo ``` +### Accumulators + +The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during execution, in what are called accumulators. + +```@docs +AbstractAccumulator +``` + +DynamicPPL provides the following default accumulators.
+ +```@docs +LogPriorAccumulator +LogLikelihoodAccumulator +NumProduceAccumulator +``` + ### Common API #### Accumulation of log-probabilities @@ -353,6 +371,13 @@ SimpleVarInfo getlogp setlogp!! acclogp!! +getlogjoint +getlogprior +setlogprior!! +acclogprior!! +getloglikelihood +setloglikelihood!! +accloglikelihood!! resetlogp!! ``` @@ -427,9 +452,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs SamplingContext DefaultContext -LikelihoodContext -PriorContext -MiniBatchContext PrefixContext ConditionContext ``` @@ -476,7 +498,3 @@ DynamicPPL.Experimental.is_suitable_varinfo ```@docs tilde_assume ``` - -```@docs -tilde_observe -``` diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7fcbd6a7c..70f0f0182 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -48,10 +48,10 @@ end Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample in `chain`, and return the resulting `Chains`. -The `model` passed to `predict` is often different from the one used to generate `chain`. -Typically, the model from which `chain` originated treats certain variables as observed (i.e., -data points), while the model you pass to `predict` may mark these same variables as missing -or unobserved. Calling `predict` then leverages the previously inferred parameter values to +The `model` passed to `predict` is often different from the one used to generate `chain`. +Typically, the model from which `chain` originated treats certain variables as observed (i.e., +data points), while the model you pass to `predict` may mark these same variables as missing +or unobserved. Calling `predict` then leverages the previously inferred parameter values to simulate what new, unobserved data might look like, given your posterior beliefs. For each parameter configuration in `chain`: @@ -59,7 +59,7 @@ For each parameter configuration in `chain`: 2. 
Any variables not included in `chain` are sampled from their prior distributions. If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by -the samples in `chain`. This is useful when you want to sample only new variables from the posterior +the samples in `chain`. This is useful when you want to sample only new variables from the posterior predictive distribution. # Examples @@ -124,7 +124,7 @@ function DynamicPPL.predict( map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), ) - return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo)) + return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) end chain_result = reduce( diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c1c613d08..c8bbda020 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Bijectors using Compat using Distributions using OrderedCollections: OrderedCollections, OrderedDict +using Printf: Printf using AbstractMCMC: AbstractMCMC using ADTypes: ADTypes @@ -46,17 +47,28 @@ import Base: export AbstractVarInfo, VarInfo, SimpleVarInfo, + AbstractAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + NumProduceAccumulator, push!!, empty!!, subset, getlogp, + getlogjoint, + getlogprior, + getloglikelihood, setlogp!!, + setlogprior!!, + setloglikelihood!!, acclogp!!, + acclogprior!!, + accloglikelihood!!, resetlogp!!, get_num_produce, - set_num_produce!, - reset_num_produce!, - increment_num_produce!, + set_num_produce!!, + reset_num_produce!!, + increment_num_produce!!, set_retained_vns_del!, is_flagged, set_flag!, @@ -92,15 +104,10 @@ export AbstractVarInfo, # Contexts SamplingContext, DefaultContext, - LikelihoodContext, - PriorContext, - MiniBatchContext, PrefixContext, ConditionContext, assume, - observe, tilde_assume, - tilde_observe, # Pseudo distributions NamedDist, NoDist, @@ -120,6 +127,8 @@ export AbstractVarInfo, to_submodel, # Convenience macros 
@addlogprob!, + @addlogprior!, + @addloglikelihood!, @submodel, value_iterator_from_chain, check_model, @@ -146,6 +155,9 @@ macro prob_str(str) )) end +# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo +# has to implement. Not sure what the full list is for parameter values, but for +# accumulators we only need `getaccs` and `setaccs!!`. """ AbstractVarInfo @@ -166,6 +178,7 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varnamedvector.jl") +include("accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index f11b8a3ec..2f5da2c31 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -90,45 +90,289 @@ Return the `AbstractTransformation` related to `vi`. function transformation end # Accumulation of log-probabilities. +""" + getlogjoint(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters in `vi`. + +See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). +""" +getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) + """ getlogp(vi::AbstractVarInfo) -Return the log of the joint probability of the observed data and parameters sampled in -`vi`. +Return a NamedTuple of the log prior and log likelihood probabilities. + +The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an +error will be thrown. +""" +function getlogp(vi::AbstractVarInfo) + return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi)) +end + +""" + setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple) + setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator} where {N}) + +Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense. + +`setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple)` should be implemented by each subtype +of `AbstractVarInfo`.
+""" +function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) where {N} + return setaccs!!(vi, AccumulatorTuple(accs)) +end + +""" + getaccs(vi::AbstractVarInfo) + +Return the `AccumulatorTuple` of `vi`. + +This should be implemented by each subtype of `AbstractVarInfo`. +""" +function getaccs end + +""" + hasacc(vi::AbstractVarInfo, ::Val{accname}) where {accname} + +Return a boolean for whether `vi` has an accumulator with name `accname`. +""" +hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname) +function hasacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method hasacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use hasacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + +""" + acckeys(vi::AbstractVarInfo) + +Return the names of the accumulators in `vi`. +""" +acckeys(vi::AbstractVarInfo) = keys(getaccs(vi)) + """ -function getlogp end + getlogprior(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters in `vi`. +See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ref). """ - setlogp!!(vi::AbstractVarInfo, logp) +getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp -Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`, mutating if it makes sense. """ -function setlogp!! end + getloglikelihood(vi::AbstractVarInfo) +Return the log of the likelihood probability of the observed data in `vi`. + +See also: [`getlogjoint`](@ref), [`getlogprior`](@ref), [`setloglikelihood!!`](@ref). """ - acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp) +getloglikelihood(vi::AbstractVarInfo) = getacc(vi, Val(:LogLikelihood)).logp + +""" + setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + +Add `acc` to the `AccumulatorTuple` of `vi`, mutating if it makes sense. 
+ +If an accumulator with the same [`accumulator_name`](@ref) already exists, it will be +replaced. -Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`, mutating if it makes sense. +See also: [`getaccs`](@ref). """ -function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(NodeTrait(context), context, vi, logp) +function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + return setaccs!!(vi, setacc!!(getaccs(vi), acc)) end -function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(vi, logp) + +""" + setlogprior!!(vi::AbstractVarInfo, logp) + +Set the log of the prior probability of the parameters sampled in `vi` to `logp`. + +See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@ref). +""" +setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) + +""" + setloglikelihood!!(vi::AbstractVarInfo, logp) + +Set the log of the likelihood probability of the observed data sampled in `vi` to `logp`. + +See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@ref). +""" +setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihoodAccumulator(logp)) + +""" + setlogp!!(vi::AbstractVarInfo, logp::NamedTuple) + +Set both the log prior and the log likelihood probabilities in `vi`. + +`logp` should have fields `logprior` and `loglikelihood` and no other fields. + +See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). 
+""" +function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} + if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior)) + error("logp must have the fields logprior and loglikelihood and no other fields.") + end + vi = setlogprior!!(vi, logp.logprior) + vi = setloglikelihood!!(vi, logp.loglikelihood) + return vi end -function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(childcontext(context), vi, logp) + +function setlogp!!(vi::AbstractVarInfo, logp::Number) + return error(""" + `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use + `setloglikelihood!!` and/or `setlogprior!!` instead. + """) +end + +""" + getacc(vi::AbstractVarInfo, ::Val{accname}) + +Return the `AbstractAccumulator` of `vi` with name `accname`. +""" +function getacc(vi::AbstractVarInfo, accname::Val) + return getacc(getaccs(vi), accname) +end +function getacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method getacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use getacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + +""" + accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + +Update all the accumulators of `vi` by calling `accumulate_assume!!` on them. +""" +function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + return map_accumulators!!(acc -> accumulate_assume!!(acc, val, logjac, vn, right), vi) +end + +""" + accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + +Update all the accumulators of `vi` by calling `accumulate_observe!!` on them. 
+""" +function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi) +end + +""" + map_accumulators!!(func::Function, vi::AbstractVarInfo) + +Update all accumulators of `vi` by calling `func` on them and replacing them with the return +values. +""" +function map_accumulators!!(func::Function, vi::AbstractVarInfo) + return setaccs!!(vi, map(func, getaccs(vi))) +end + +""" + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname} + +Update the accumulator `accname` of `vi` by calling `func` on it and replacing it with the +return value. +""" +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Val) + return setaccs!!(vi, map_accumulator(func, getaccs(vi), accname)) +end + +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) + does not exist. For type stability reasons use + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead. + """ + ) +end + +""" + acclogprior!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the prior probability in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). +""" +function acclogprior!!(vi::AbstractVarInfo, logp) + return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) +end + +""" + accloglikelihood!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the likelihood in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). 
+""" +function accloglikelihood!!(vi::AbstractVarInfo, logp) + return map_accumulator!!( + acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood) + ) +end + +""" + acclogp!!(vi::AbstractVarInfo, logp::NamedTuple; ignore_missing_accumulator::Bool=false) + +Add to both the log prior and the log likelihood probabilities in `vi`. + +`logp` should have fields `logprior` and/or `loglikelihood`, and no other fields. + +By default if the necessary accumulators are not in `vi` an error is thrown. If +`ignore_missing_accumulator` is set to `true` then this is silently ignored instead. +""" +function acclogp!!( + vi::AbstractVarInfo, logp::NamedTuple{names}; ignore_missing_accumulator=false +) where {names} + if !( + names == (:logprior, :loglikelihood) || + names == (:loglikelihood, :logprior) || + names == (:logprior,) || + names == (:loglikelihood,) + ) + error("logp must have fields logprior and/or loglikelihood and no other fields.") + end + if haskey(logp, :logprior) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogPrior))) + vi = acclogprior!!(vi, logp.logprior) + end + if haskey(logp, :loglikelihood) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogLikelihood))) + vi = accloglikelihood!!(vi, logp.loglikelihood) + end + return vi +end + +function acclogp!!(vi::AbstractVarInfo, logp::Number) + Base.depwarn( + "`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.", + :acclogp, + ) + return accloglikelihood!!(vi, logp) end """ resetlogp!!(vi::AbstractVarInfo) -Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0, mutating if it makes sense. +Reset the values of the log probabilities (prior and likelihood) in `vi` to zero. 
""" -resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) +function resetlogp!!(vi::AbstractVarInfo) + if hasacc(vi, Val(:LogPrior)) + vi = map_accumulator!!(zero, vi, Val(:LogPrior)) + end + if hasacc(vi, Val(:LogLikelihood)) + vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) + end + return vi +end # Variables and their realizations. @doc """ @@ -566,8 +810,8 @@ function link!!( x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, y), lp_new) + lp_new = getlogprior(vi) - logjac + vi_new = setlogprior!!(unflatten(vi, y), lp_new) return settrans!!(vi_new, t) end @@ -578,8 +822,8 @@ function invlink!!( y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, x), lp_new) + lp_new = getlogprior(vi) + logjac + vi_new = setlogprior!!(unflatten(vi, x), lp_new) return settrans!!(vi_new, NoTransformation()) end @@ -725,7 +969,7 @@ end # Legacy code that is currently overloaded for the sake of simplicity. # TODO: Remove when possible. -increment_num_produce!(::AbstractVarInfo) = nothing +increment_num_produce!!(::AbstractVarInfo) = nothing """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/accumulators.jl b/src/accumulators.jl new file mode 100644 index 000000000..e241abf1c --- /dev/null +++ b/src/accumulators.jl @@ -0,0 +1,338 @@ +""" + AbstractAccumulator + +An abstract type for accumulators. + +An accumulator is an object that may change its value at every tilde_assume!! or +tilde_observe!! call based on the random variable in question. The obvious examples of +accumulators are the log prior and log likelihood. Other examples might be a variable that +counts the number of observations in a trace, or a list of the names of random variables +seen so far. 
+ +An accumulator type `T <: AbstractAccumulator` must implement the following methods: +- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` +- `accumulate_observe!!(acc::T, right, left, vn)` +- `accumulate_assume!!(acc::T, val, logjac, vn, right)` + +To be able to work with multi-threading, it should also implement: +- `split(acc::T)` +- `combine(acc::T, acc2::T)` + +See the documentation for each of these functions for more details. +""" +abstract type AbstractAccumulator end + +""" + accumulator_name(acc::AbstractAccumulator) + +Return a Symbol which can be used as a name for `acc`. + +The name has to be unique in the sense that a `VarInfo` can only have one accumulator for +each name. The most typical case, and the default implementation, is that the name only +depends on the type of `acc`, not on its value. +""" +accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) + +""" + accumulate_observe!!(acc::AbstractAccumulator, right, left, vn) + +Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`. + +`vn` is the name of the variable being observed, `left` is the value of the variable, and +`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case +of literal observations like `0.0 ~ Normal()`. + +`accumulate_observe!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_assume!!`](@ref) +""" +function accumulate_observe!! end + +""" + accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right) + +Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. + +`vn` is the name of the variable being assumed, `val` is the value of the variable, and +`right` is the distribution on the RHS of the tilde statement. `logjac` is the log +determinant of the Jacobian of the transformation that was done to convert the value of `vn` +as it was given (e.g. by a sampler operating in linked space) to `val`.
+ +`accumulate_assume!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_observe!!`](@ref) +""" +function accumulate_assume!! end + +""" + split(acc::AbstractAccumulator) + +Return a new accumulator like `acc` but empty. + +The precise meaning of "empty" is that the returned value should be such that +`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading +where different threads may accumulate independently and the results are then combined. + +See also: [`combine`](@ref) +""" +function split end + +""" + combine(acc::AbstractAccumulator, acc2::AbstractAccumulator) + +Combine two accumulators of the same type. Returns a new accumulator. + +See also: [`split`](@ref) +""" +function combine end + +# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in +# src/varinfo.jl. +""" + convert_eltype(::Type{T}, acc::AbstractAccumulator) + +Convert `acc` to use element type `T`. + +What "element type" means depends on the type of `acc`. By default this function does +nothing. Accumulator types that need to hold differentiable values, such as dual numbers +used by various AD backends, should implement a method for this function. +""" +convert_eltype(::Type, acc::AbstractAccumulator) = acc + +# END ABSTRACT ACCUMULATOR, BEGIN ACCUMULATOR TUPLE + +""" + AccumulatorTuple{N,T<:NamedTuple} + +A collection of accumulators, stored as a `NamedTuple` of length `N`. + +This is defined as a separate type to be able to dispatch on it cleanly and without method +ambiguities or conflicts with other `NamedTuple` types. We also use this type to enforce the +constraint that the name in the tuple for each accumulator `acc` must be +`accumulator_name(acc)`, and these names must be unique. + +The constructor can be called with a tuple or varargs of `AbstractAccumulator`s. The +names will be generated automatically.
One can also call the constructor with a `NamedTuple` +but the names in the argument will be discarded in favour of the generated ones. +""" +struct AccumulatorTuple{N,T<:NamedTuple} + nt::T + + function AccumulatorTuple(t::T) where {N,T<:NTuple{N,AbstractAccumulator}} + names = map(accumulator_name, t) + nt = NamedTuple{names}(t) + return new{N,typeof(nt)}(nt) + end +end + +AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) +AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) + +# When showing with text/plain, leave out information about the wrapper AccumulatorTuple. +Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt) +Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] +Base.length(::AccumulatorTuple{N}) where {N} = N +Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) +function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} + # @inline to ensure constant propagation can resolve this to a compile-time constant. + @inline return haskey(at.nt, accname) +end +Base.keys(at::AccumulatorTuple) = keys(at.nt) + +function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T} + return AccumulatorTuple(convert(T, accs.nt)) +end + +""" + setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + +Add `acc` to `at`. Returns a new `AccumulatorTuple`. + +If an `AbstractAccumulator` with the same `accumulator_name` already exists in `at` it is +replaced. `at` will never be mutated, but the name has the `!!` for consistency with the +corresponding function for `AbstractVarInfo`. +""" +function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + accname = accumulator_name(acc) + new_nt = merge(at.nt, NamedTuple{(accname,)}((acc,))) + return AccumulatorTuple(new_nt) +end + +""" + getacc(at::AccumulatorTuple, ::Val{accname}) + +Get the accumulator with name `accname` from `at`. 
+""" +function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} + return at[accname] +end + +function Base.map(func::Function, at::AccumulatorTuple) + return AccumulatorTuple(map(func, at.nt)) +end + +""" + map_accumulator(func::Function, at::AccumulatorTuple, ::Val{accname}) + +Update the accumulator with name `accname` in `at` by calling `func` on it. + +Returns a new `AccumulatorTuple`. +""" +function map_accumulator( + func::Function, at::AccumulatorTuple, ::Val{accname} +) where {accname} + # Would like to write this as + # return Accessors.@set at.nt[accname] = func(at[accname], args...) + # for readability, but that one isn't type stable due to + # https://github.com/JuliaObjects/Accessors.jl/issues/198 + new_val = func(at[accname]) + new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,))) + return AccumulatorTuple(new_nt) +end + +# END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS + +""" + LogPriorAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log prior during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator + "the scalar log prior value" + logp::T +end + +""" + LogPriorAccumulator{T}() + +Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero. +""" +LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) +LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() + +""" + LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log likelihood during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + "the scalar log likelihood value" + logp::T +end + +""" + LogLikelihoodAccumulator{T}() + +Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero. 
+""" +LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)) +LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() + +""" + NumProduceAccumulator{T} <: AbstractAccumulator + +An accumulator that tracks the number of observations during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator + "the number of observations" + num::T +end + +""" + NumProduceAccumulator{T<:Integer}() + +Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero. +""" +NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T)) +NumProduceAccumulator() = NumProduceAccumulator{Int}() + +function Base.show(io::IO, acc::LogPriorAccumulator) + return print(io, "LogPriorAccumulator($(repr(acc.logp)))") +end +function Base.show(io::IO, acc::LogLikelihoodAccumulator) + return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") +end +function Base.show(io::IO, acc::NumProduceAccumulator) + return print(io, "NumProduceAccumulator($(repr(acc.num)))") +end + +accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior +accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood +accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce + +split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) +split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) +split(acc::NumProduceAccumulator) = acc + +function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) + return LogPriorAccumulator(acc.logp + acc2.logp) +end +function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return LogLikelihoodAccumulator(acc.logp + acc2.logp) +end +function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator) + return NumProduceAccumulator(max(acc.num, acc2.num)) +end + +function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) + return 
LogPriorAccumulator(acc1.logp + acc2.logp) +end +function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return LogLikelihoodAccumulator(acc1.logp + acc2.logp) +end +increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num)) + +Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) +Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) +Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num)) + +function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) + return acc + LogPriorAccumulator(logpdf(right, val) + logjac) +end +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc + +accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) + # Note that it's important to use the loglikelihood function here, not logpdf, because + # they handle vectors differently: + # https://github.com/JuliaStats/Distributions.jl/issues/1972 + return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) +end + +accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc +accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc) + +function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} + return LogPriorAccumulator(convert(T, acc.logp)) +end +function Base.convert( + ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator +) where {T} + return LogLikelihoodAccumulator(convert(T, acc.logp)) +end +function Base.convert( + ::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator +) where {T} + return NumProduceAccumulator(convert(T, acc.num)) +end + +# TODO(mhauru) +# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on +# convert_eltype(::AbstractAccumulator, ::Type). 
This is because they are only used to +# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is +# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. +function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} + return LogPriorAccumulator(convert(T, acc.logp)) +end +function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} + return LogLikelihoodAccumulator(convert(T, acc.logp)) +end diff --git a/src/compiler.jl b/src/compiler.jl index 6f7489b8e..9eb4835d3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -418,7 +418,7 @@ function generate_tilde_literal(left, right) @gensym value return quote $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__ ) $value end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index eb025dec8..bb1c66c8e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,27 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline. 
-function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(NodeTrait(acclogp_assume!!, context), context, vi, logp) -end -function acclogp_assume!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(childcontext(context), vi, logp) -end -function acclogp_assume!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - -function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(NodeTrait(acclogp_observe!!, context), context, vi, logp) -end -function acclogp_observe!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(childcontext(context), vi, logp) -end -function acclogp_observe!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - # assume """ tilde_assume(context::SamplingContext, right, vn, vi) @@ -52,36 +31,18 @@ function tilde_assume(context::SamplingContext, right, vn, vi) return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) end -# Leaf contexts function tilde_assume(context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), context, args...) + return tilde_assume(childcontext(context), args...) end -function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) +function tilde_assume(::DefaultContext, right, vn, vi) return assume(right, vn, vi) end -function tilde_assume(::IsParent, context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) -end function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) 
-end -function tilde_assume( - ::IsLeaf, rng::Random.AbstractRNG, context::AbstractContext, sampler, right, vn, vi -) - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume( - ::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args... -) return tilde_assume(rng, childcontext(context), args...) end - -function tilde_assume(::LikelihoodContext, right, vn, vi) - return assume(nodist(right), vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) - return assume(rng, sampler, nodist(right), vn, vi) +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::PrefixContext, right, vn, vi) @@ -137,55 +98,37 @@ function tilde_assume!!(context, right, vn, vi) end rand_like!!(right, context, vi) else - value, logp, vi = tilde_assume(context, right, vn, vi) - value, acclogp_assume!!(context, vi, logp) + value, vi = tilde_assume(context, right, vn, vi) + return value, vi end end # observe """ - tilde_observe(context::SamplingContext, right, left, vi) + tilde_observe!!(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. +Falls back to `tilde_observe!!(context.context, right, left, vi)`. """ -function tilde_observe(context::SamplingContext, right, left, vi) - return tilde_observe(context.context, context.sampler, right, left, vi) +function tilde_observe!!(context::SamplingContext, right, left, vn, vi) + return tilde_observe!!(context.context, right, left, vn, vi) end -# Leaf contexts -function tilde_observe(context::AbstractContext, args...) - return tilde_observe(NodeTrait(tilde_observe, context), context, args...) -end -tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...) 
-function tilde_observe(::IsParent, context::AbstractContext, args...) - return tilde_observe(childcontext(context), args...) -end - -tilde_observe(::PriorContext, right, left, vi) = 0, vi -tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi - -# `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - logp, vi = tilde_observe(context.context, sampler, right, left, vi) - return context.loglike_scalar * logp, vi +function tilde_observe!!(context::AbstractContext, right, left, vn, vi) + return tilde_observe!!(childcontext(context), right, left, vn, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end -function tilde_observe(context::PrefixContext, sampler, right, left, vi) - return tilde_observe(context.context, sampler, right, left, vi) +function tilde_observe!!(context::PrefixContext, right, left, vn, vi) + # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal + # value. + prefixed_varname = vn !== nothing ? prefix(context, vn) : vn + return tilde_observe!!(context.context, right, left, prefixed_varname, vi) end """ - tilde_observe!!(context, right, left, vname, vi) + tilde_observe!!(context, right, left, vn, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value and updated `vi`. @@ -193,46 +136,27 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. 
""" -function tilde_observe!!(context, right, left, vname, vi) +function tilde_observe!!(context::DefaultContext, right, left, vn, vi) is_rhs_model(right) && throw( ArgumentError( "`~` with a model on the right-hand side of an observe statement is not supported", ), ) - return tilde_observe!!(context, right, left, vi) -end - -""" - tilde_observe(context, right, left, vi) - -Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and -return the observed value. - -By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log -probability of `vi` with the returned value. -""" -function tilde_observe!!(context, right, left, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi end function assume(rng::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end -function observe(spl::Sampler, weight) - return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") -end - # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r, logp = invlink_with_logpdf(vi, vn, dist) - return r, logp, vi + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, dist) + x, logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + return x, vi end # TODO: Remove this thing. @@ -254,8 +178,7 @@ function assume( r = init(rng, dist, sampler) f = to_maybe_linked_internal_transform(vi, vn, dist) # TODO(mhauru) This should probably be call a function called setindex_internal! - # Also, if we use !! we shouldn't ignore the return value. 
- BangBang.setindex!!(vi, f(r), vn) + vi = BangBang.setindex!!(vi, f(r), vn) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. @@ -265,22 +188,16 @@ function assume( r = init(rng, dist, sampler) if istrans(vi) f = to_linked_internal_transform(vi, vn, dist) - push!!(vi, vn, f(r), dist) + vi = push!!(vi, vn, f(r), dist) # By default `push!!` sets the transformed flag to `false`. - settrans!!(vi, true, vn) + vi = settrans!!(vi, true, vn) else - push!!(vi, vn, r, dist) + vi = push!!(vi, vn, r, dist) end end # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - return r, logpdf(dist, r) - logjac, vi -end - -# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) -observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi) -function observe(right::Distribution, left, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(right, left), vi + vi = accumulate_assume!!(vi, r, -logjac, vn, dist) + return r, vi end diff --git a/src/contexts.jl b/src/contexts.jl index 8ac085663..addadfa1a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -45,15 +45,17 @@ effectively updating the child context. # Examples ```jldoctest +julia> using DynamicPPL: DynamicTransformationContext + julia> ctx = SamplingContext(); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}()); julia> DynamicPPL.childcontext(ctx_prior) -PriorContext() +DynamicTransformationContext{true}() ``` """ setchildcontext @@ -78,7 +80,7 @@ original leaf context of `left`. 
# Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext julia> struct ParentContext{C} <: AbstractContext context::C @@ -96,8 +98,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext())) ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. - leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext() + leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}())) +DynamicTransformationContext{true}() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) @@ -129,7 +131,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. -See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`DefaultContext`](@ref) """ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext rng::R @@ -189,52 +191,11 @@ getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") """ struct DefaultContext <: AbstractContext end -The `DefaultContext` is used by default to compute the log joint probability of the data -and parameters when running the model. +The `DefaultContext` is used by default to accumulate values like the log joint probability +when running the model. """ struct DefaultContext <: AbstractContext end -NodeTrait(context::DefaultContext) = IsLeaf() - -""" - PriorContext <: AbstractContext - -A leaf context resulting in the exclusion of likelihood terms when running the model. 
-""" -struct PriorContext <: AbstractContext end -NodeTrait(context::PriorContext) = IsLeaf() - -""" - LikelihoodContext <: AbstractContext - -A leaf context resulting in the exclusion of prior terms when running the model. -""" -struct LikelihoodContext <: AbstractContext end -NodeTrait(context::LikelihoodContext) = IsLeaf() - -""" - struct MiniBatchContext{Tctx, T} <: AbstractContext - context::Tctx - loglike_scalar::T - end - -The `MiniBatchContext` enables the computation of -`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the -`loglike_scalar` field, typically equal to `the number of data points / batch size`. -This is useful in batch-based stochastic gradient descent algorithms to be optimizing -`log(prior) + log(likelihood of all the data points)` in the expectation. -""" -struct MiniBatchContext{Tctx,T} <: AbstractContext - context::Tctx - loglike_scalar::T -end -function MiniBatchContext(context=DefaultContext(); batch_size, npoints) - return MiniBatchContext(context, npoints / batch_size) -end -NodeTrait(context::MiniBatchContext) = IsParent() -childcontext(context::MiniBatchContext) = context.context -function setchildcontext(parent::MiniBatchContext, child) - return MiniBatchContext(child, parent.loglike_scalar) -end +NodeTrait(::DefaultContext) = IsLeaf() """ PrefixContext(vn::VarName[, context::AbstractContext]) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 15ef8fb01..238cd422d 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -76,7 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt varname right value - logp varinfo = nothing end @@ -89,16 +88,12 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, " ") print(io, RESULT_SYMBOL) print(io, " ") - print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") + return print(io, stmt.value) end Base.@kwdef struct ObserveStmt <: Stmt left right - logp varinfo = nothing end @@ -107,10 +102,7 @@ function 
Base.show(io::IO, stmt::ObserveStmt) print(io, "observe: ") show_right(io, stmt.left) print(io, " ~ ") - show_right(io, stmt.right) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") + return show_right(io, stmt.right) end # Some utility methods for extracting information from a trace. @@ -252,12 +244,11 @@ function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) return nothing end -function record_post_tilde_assume!(context::DebugContext, vn, dist, value, logp, varinfo) +function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varinfo) stmt = AssumeStmt(; varname=vn, right=dist, value=value, - logp=logp, varinfo=context.record_varinfo ? varinfo : nothing, ) if context.record_statements @@ -268,19 +259,17 @@ end function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi + value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) + record_post_tilde_assume!(context, vn, right, value, vi) + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi ) record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume( - rng, childcontext(context), sampler, right, vn, vi - ) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi + value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + record_post_tilde_assume!(context, vn, right, value, vi) + return value, vi end # observe @@ -304,12 +293,9 @@ function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) end end -function record_post_tilde_observe!(context::DebugContext, left, right, logp, varinfo) +function 
record_post_tilde_observe!(context::DebugContext, left, right, varinfo) stmt = ObserveStmt(; - left=left, - right=right, - logp=logp, - varinfo=context.record_varinfo ? varinfo : nothing, + left=left, right=right, varinfo=context.record_varinfo ? varinfo : nothing ) if context.record_statements push!(context.statements, stmt) @@ -317,17 +303,17 @@ function record_post_tilde_observe!(context::DebugContext, left, right, logp, va return nothing end -function DynamicPPL.tilde_observe(context::DebugContext, right, left, vi) +function DynamicPPL.tilde_observe!!(context::DebugContext, right, left, vn, vi) record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + vi = DynamicPPL.tilde_observe!!(childcontext(context), right, left, vn, vi) + record_post_tilde_observe!(context, left, right, vi) + return vi end -function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, vi) +function DynamicPPL.tilde_observe!!(context::DebugContext, sampler, right, left, vn, vi) record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + vi = DynamicPPL.tilde_observe!!(childcontext(context), sampler, right, left, vn, vi) + record_post_tilde_observe!(context, left, right, vi) + return vi end _conditioned_varnames(d::AbstractDict) = keys(d) @@ -413,7 +399,7 @@ julia> issuccess true julia> print(trace) - assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 (logprob = -1.14356) + assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); @@ -421,7 +407,7 @@ julia> issuccess true julia> print(trace) -observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) (logprob = -1.41894) +observe: 1.0 ~ 
Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index a42855f05..1b5e9b8c4 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -51,7 +51,7 @@ $(FIELDS) ```jldoctest julia> using Distributions -julia> using DynamicPPL: LogDensityFunction, contextualize +julia> using DynamicPPL: LogDensityFunction, setaccs!! julia> @model function demo(x) m ~ Normal() @@ -78,8 +78,8 @@ julia> # By default it uses `VarInfo` under the hood, but this is not necessary. julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 -julia> # This also respects the context in `model`. - f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model)); +julia> # LogDensityFunction respects the accumulators in VarInfo: + f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),))); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -174,14 +174,26 @@ end Evaluate the log density of the given `model` at the given parameter values `x`, using the given `varinfo` and `context`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` are inserted into -it, and its own parameters are discarded. +only for its structure, in the sense that the parameters from the vector `x` are inserted +into it, and its own parameters are discarded. It does, however, determine whether the log +prior, likelihood, or joint is returned, based on which accumulators are set in it. 
""" function logdensity_at( x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext ) varinfo_new = unflatten(varinfo, x) - return getlogp(last(evaluate!!(model, varinfo_new, context))) + varinfo_eval = last(evaluate!!(model, varinfo_new, context)) + has_prior = hasacc(varinfo_eval, Val(:LogPrior)) + has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) + if has_prior && has_likelihood + return getlogjoint(varinfo_eval) + elseif has_prior + return getlogprior(varinfo_eval) + elseif has_likelihood + return getloglikelihood(varinfo_eval) + else + error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") + end end ### LogDensityProblems interface diff --git a/src/model.jl b/src/model.jl index c7c4bdf57..e8f2f3528 100644 --- a/src/model.jl +++ b/src/model.jl @@ -900,7 +900,7 @@ See also: [`evaluate_threadunsafe!!`](@ref) function evaluate_threadsafe!!(model, varinfo, context) wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper, context) - return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new)) + return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) end """ @@ -1010,7 +1010,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogjoint(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -1057,7 +1057,14 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, PriorContext()))) + # Remove other accumulators from varinfo, since they are unnecessary. 
+ logprior = if hasacc(varinfo, Val(:LogPrior)) + getacc(varinfo, Val(:LogPrior)) + else + LogPriorAccumulator() + end + varinfo = setaccs!!(deepcopy(varinfo), (logprior,)) + return getlogprior(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -1104,7 +1111,14 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext()))) + # Remove other accumulators from varinfo, since they are unnecessary. + loglikelihood = if hasacc(varinfo, Val(:LogLikelihood)) + getacc(varinfo, Val(:LogLikelihood)) + else + LogLikelihoodAccumulator() + end + varinfo = setaccs!!(deepcopy(varinfo), (loglikelihood,)) + return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -1358,7 +1372,7 @@ We can check that the log joint probability of the model accumulated in `vi` is ```jldoctest submodel-to_submodel julia> x = vi[@varname(a.x)]; -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true ``` @@ -1422,7 +1436,7 @@ julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); -julia> getlogp(vi) ≈ logprior + loglikelihood +julia> getlogjoint(vi) ≈ logprior + loglikelihood true ``` diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index cb9ea4894..b6b97c8f9 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,142 +1,117 @@ -# Context version -struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext - logdensities::A - context::Ctx -end +""" + PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator -function PointwiseLogdensityContext( 
- likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=DefaultContext(), -) - return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( - likelihoods, context - ) -end +An accumulator that stores the log-probabilities of each variable in a model. -NodeTrait(::PointwiseLogdensityContext) = IsParent() -childcontext(context::PointwiseLogdensityContext) = context.context -function setchildcontext(context::PointwiseLogdensityContext, child) - return PointwiseLogdensityContext(context.logdensities, child) -end +Internally this accumulator stores the log-probabilities in a dictionary, where the keys are +the variable names and the values are vectors of log-probabilities. Each element in a vector +corresponds to one execution of the model. -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +`whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies +which log-probabilities to store in the accumulator. `KeyType` is the type by which variable +names are stored, and should be `String` or `VarName`. `D` is the type of the dictionary +used internally to store the log-probabilities, by default +`OrderedDict{KeyType, Vector{LogProbType}}`. 
+""" +struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: + AbstractAccumulator + logps::D end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[vn] = logp +function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) +function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,VarName}() end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[string(vn)] = logp +function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType} + logps = OrderedDict{KeyType,Vector{LogProbType}}() + return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) + logps = acc.logps + # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. 
+ T = last(fieldtypes(eltype(logps))) + logpvec = get!(logps, vn, T()) + return push!(logpvec, logp) end function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.logdensities[vn] = logp + acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp +) where {whichlogprob} + return push!(acc, string(vn), logp) end -function _include_prior(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{PriorContext,DefaultContext} -end -function _include_likelihood(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +function accumulator_name( + ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} +) where {whichlogprob} + return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!!(context.context, right, left, vi) +function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. - if !(_include_likelihood(context)) - return tilde_observe!!(context.context, right, left, vn, vi) - end - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. - logp, vi = tilde_observe(context.context, right, left, vi) - - # Track loglikelihood value. 
- push!(context, vn, logp) - - return left, acclogp!!(vi, logp) +function combine( + acc::PointwiseLogProbAccumulator{whichlogprob}, + acc2::PointwiseLogProbAccumulator{whichlogprob}, +) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps)) end -# Note on submodels (penelopeysm) -# -# We don't need to overload tilde_observe!! for Sampleables (yet), because it -# is currently not possible to evaluate a model with a Sampleable on the RHS -# of an observe statement. -# -# Note that calling tilde_assume!! on a Sampleable does not necessarily imply -# that there are no observe statements inside the Sampleable. There could well -# be likelihood terms in there, which must be included in the returned logp. -# See e.g. the `demo_dot_assume_observe_submodel` demo model. -# -# This is handled by passing the same context to rand_like!!, which figures out -# which terms to include using the context, and also mutates the context and vi -# appropriately. Thus, we don't need to check against _include_prior(context) -# here. -function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi) - value, vi = DynamicPPL.rand_like!!(right, context, vi) - return value, vi +function accumulate_assume!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right +) where {whichlogprob} + if whichlogprob == :both || whichlogprob == :prior + # T is the element type of the vectors that are the values of `acc.logps`. Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) + push!(acc, vn, subacc.logp) + end + return acc end -function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) - !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) - value, logp, vi = tilde_assume(context.context, right, vn, vi) - # Track loglikelihood value. 
- push!(context, vn, logp) - return value, acclogp!!(vi, logp) +function accumulate_observe!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, right, left, vn +) where {whichlogprob} + # If `vn` is nothing the LHS of ~ is a literal and we don't have a name to attach this + # acc to, and thus do nothing. + if vn === nothing + return acc + end + if whichlogprob == :both || whichlogprob == :likelihood + # T is the element type of the vectors that are the values of `acc.logps`. Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn) + push!(acc, vn, subacc.logp) + end + return acc +end  """ - pointwise_logdensities(model::Model, chain::Chains, keytype = String) + pointwise_logdensities( + model::Model, + chain::Chains, + keytype=String, + context=DefaultContext(), + ::Val{whichlogprob}=Val(:both), + ) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. +Currently, only `String` and `VarName` are supported. `context` is the evaluation context, +and `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, +`:prior`, or `:likelihood`. + +See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_prior_logdensities`](@ref). # Notes Say `y` is a `Vector` of `n` i.i.d. 
`Normal(μ, σ)` variables, with `μ` and `σ` @@ -234,14 +209,19 @@ julia> m = demo([1.0; 1.0]); julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` - """ function pointwise_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() -) where {T} + model::Model, + chain, + ::Type{KeyType}=String, + context::AbstractContext=DefaultContext(), + ::Val{whichlogprob}=Val(:both), +) where {KeyType,whichlogprob} # Get the data by executing the model once vi = VarInfo(model) - point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) + + AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} + vi = setaccs!!(vi, (AccType(),)) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -249,26 +229,28 @@ function pointwise_logdensities( setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, point_context) + vi = last(evaluate!!(model, vi, context)) end + logps = getacc(vi, Val(accumulator_name(AccType))).logps niters = size(chain, 1) nchains = size(chain, 3) logdensities = OrderedDict( - varname => reshape(logliks, niters, nchains) for - (varname, logliks) in point_context.logdensities + varname => reshape(vals, niters, nchains) for (varname, vals) in logps ) return logdensities end function pointwise_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - point_context = PointwiseLogdensityContext( - OrderedDict{VarName,Vector{Float64}}(), context - ) - model(varinfo, point_context) - return point_context.logdensities + model::Model, + varinfo::AbstractVarInfo, + context::AbstractContext=DefaultContext(), + ::Val{whichlogprob}=Val(:both), +) where {whichlogprob} + AccType = PointwiseLogProbAccumulator{whichlogprob} + varinfo = setaccs!!(varinfo, (AccType(),)) + varinfo = last(evaluate!!(model, varinfo, 
context)) + return getacc(varinfo, Val(accumulator_name(AccType))).logps end """ @@ -277,29 +259,19 @@ end Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the likelihood terms. -See also: [`pointwise_logdensities`](@ref). + +See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). """ function pointwise_loglikelihoods( - model::Model, - chain, - keytype::Type{T}=String, - context::AbstractContext=LikelihoodContext(), + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - - return pointwise_logdensities(model, chain, T, context) + return pointwise_logdensities(model, chain, T, context, Val(:likelihood)) end function pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() ) - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - - return pointwise_logdensities(model, varinfo, context) + return pointwise_logdensities(model, varinfo, context, Val(:likelihood)) end """ @@ -308,24 +280,17 @@ end Compute the pointwise log-prior-densities of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the prior terms. -See also: [`pointwise_logdensities`](@ref). + +See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). 
""" function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, chain, T, context) + return pointwise_logdensities(model, chain, T, context, Val(:prior)) end function pointwise_prior_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() ) - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, varinfo, context) + return pointwise_logdensities(model, varinfo, context, Val(:prior)) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index abf14b8fc..257ccb004 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -125,18 +125,18 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), 0.0) +Transformed SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihoodAccumulator(0.0), LogPrior = LogPriorAccumulator(0.0))) julia> # (✓) Positive probability mass on negative numbers! - getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), 0.0) +SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihoodAccumulator(0.0), LogPrior = LogPriorAccumulator(0.0))) julia> # (✓) No probability mass on negative numbers! 
- getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -Inf ``` @@ -188,41 +188,39 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo +struct SimpleVarInfo{NT,Accs<:AccumulatorTuple,C<:AbstractTransformation} <: + AbstractVarInfo "underlying representation of the realization represented" values::NT - "holds the accumulated log-probability" - logp::T + "tuple of accumulators for things like log prior and log likelihood" + accs::Accs "represents whether it assumes variables to be transformed" transformation::C end transformation(vi::SimpleVarInfo) = vi.transformation -# Makes things a bit more readable vs. putting `Float64` everywhere. -const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64 - -function SimpleVarInfo{NT,T}(values, logp) where {NT,T} - return SimpleVarInfo{NT,T,NoTransformation}(values, logp, NoTransformation()) +function SimpleVarInfo(values, accs) + return SimpleVarInfo(values, accs, NoTransformation()) end -function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T)) +function SimpleVarInfo{T}(values) where {T<:Real} + return SimpleVarInfo( + values, AccumulatorTuple(LogLikelihoodAccumulator{T}(), LogPriorAccumulator{T}()) + ) end - -# Constructors without type-specification. -SimpleVarInfo(θ) = SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) -function SimpleVarInfo(θ::Union{<:NamedTuple,<:AbstractDict}) - return if isempty(θ) +function SimpleVarInfo(values) + return SimpleVarInfo{LogProbType}(values) +end +function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict}) + return if isempty(values) # Can't infer from values, so we just use default. - SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) + SimpleVarInfo{LogProbType}(values) else # Infer from `values`.
- SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ) + SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) end end -SimpleVarInfo(values, logp) = SimpleVarInfo{typeof(values),typeof(logp)}(values, logp) - # Using `kwargs` to specify the values. function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -235,7 +233,7 @@ end function SimpleVarInfo( model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... ) - return SimpleVarInfo{Float64}(model, args...) + return SimpleVarInfo{LogProbType}(model, args...) end function SimpleVarInfo{T}( model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... @@ -244,14 +242,14 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} - return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) +function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} + values = values_as(vi, D) + return SimpleVarInfo(values, deepcopy(getaccs(vi))) end -function SimpleVarInfo{T}( - vi::VarInfo{<:NamedTuple{names}}, ::Type{D} -) where {T<:Real,names,D} +function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} values = values_as(vi, D) - return SimpleVarInfo(values, convert(T, getlogp(vi))) + accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) + return SimpleVarInfo(values, accs) end function untyped_simple_varinfo(model::Model) @@ -265,12 +263,12 @@ function typed_simple_varinfo(model::Model) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) - logp = getlogp(svi) vals = unflatten(svi.values, x) - T = eltype(x) - return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}( - vals, T(logp), svi.transformation - ) + # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is + # required but undesirable.
+ T = float_type_with_fallback(eltype(x)) + accs = map(acc -> convert_eltype(T, acc), getaccs(svi)) + return SimpleVarInfo(vals, accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) @@ -278,21 +276,8 @@ function BangBang.empty!!(vi::SimpleVarInfo) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) -getlogp(vi::SimpleVarInfo) = vi.logp -getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[] - -setlogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = logp -acclogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = getlogp(vi) + logp - -function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] += logp - return vi -end +getaccs(vi::SimpleVarInfo) = vi.accs +setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs """ keys(vi::SimpleVarInfo) @@ -302,12 +287,12 @@ Return an iterator of keys present in `vi`. Base.keys(vi::SimpleVarInfo) = keys(vi.values) Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) -function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) +function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) if !(svi.transformation isa NoTransformation) print(io, "Transformed ") end - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") end function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) @@ -454,11 +439,11 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - logp = getlogp(varinfo_right) + accs = deepcopy(getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) - return SimpleVarInfo(values, logp, 
transformation) + return SimpleVarInfo(values, accs, transformation) end # Context implementations @@ -473,9 +458,11 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - value_raw = to_maybe_linked_internal(vi, vn, dist, value) + f = to_maybe_linked_internal_transform(vi, vn, dist) + value_raw, logjac = with_logabsdet_jacobian(f, value) vi = BangBang.push!!(vi, vn, value_raw, dist) - return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi + vi = accumulate_assume!!(vi, value, -logjac, vn, dist) + return value, vi end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. @@ -497,8 +484,8 @@ islinked(vi::SimpleVarInfo) = istrans(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values -function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} - isempty(vi) && return T[] +function values_as(vi::SimpleVarInfo, ::Type{Vector}) + isempty(vi) && return Any[] return mapreduce(tovec, vcat, values(vi.values)) end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} @@ -613,12 +600,11 @@ function link!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) x = vi.values y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(Accessors.@set(vi.values = y), lp_new) + vi_new = Accessors.@set(vi.values = y) + vi_new = acclogprior!!(vi_new, -logjac) return settrans!!(vi_new, t) end @@ -627,12 +613,11 @@ function invlink!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. 
b = t.bijector y = vi.values x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(Accessors.@set(vi.values = x), lp_new) + vi_new = Accessors.@set(vi.values = x) + vi_new = acclogprior!!(vi_new, logjac) return settrans!!(vi_new, NoTransformation()) end @@ -645,13 +630,4 @@ function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) return invlink_transform(dist) end -# Threadsafe stuff. -# For `SimpleVarInfo` we don't really need `Ref` so let's not use it. -function ThreadSafeVarInfo(vi::SimpleVarInfo) - return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads())) -end -function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) - return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) -end - has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 5f1ec95ec..bd08b427e 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -45,7 +45,7 @@ We can check that the log joint probability of the model accumulated in `vi` is ```jldoctest submodel julia> x = vi[@varname(x)]; -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true ``` """ @@ -124,7 +124,7 @@ julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); -julia> getlogp(vi) ≈ logprior + loglikelihood +julia> getlogjoint(vi) ≈ logprior + loglikelihood true ``` diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 7404a9af7..08acdfada 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -3,34 +3,6 @@ # # Utilities for testing contexts. -""" -Context that multiplies each log-prior by mod -used to test whether varwise_logpriors respects child-context. 
-""" -struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext - mod::T - context::Ctx -end -function TestLogModifyingChildContext( - mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext() -) - return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) -end - -DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context -function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) - return TestLogModifyingChildContext(context.mod, child) -end -function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, logp * context.mod, vi -end -function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return logp * context.mod, vi -end - # Dummy context to test nested behaviors. struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext context::C @@ -61,7 +33,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod # To see change, let's make sure we're using a different leaf context than the current. 
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - PriorContext() + DynamicPPL.DynamicTransformationContext{false}() else DefaultContext() end diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index e29614982..12f88acad 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -148,7 +148,7 @@ Simple model for which [`default_transformation`](@ref) returns a [`StaticTransf 1.5 ~ Normal(m, sqrt(s)) 2.0 ~ Normal(m, sqrt(s)) - return (; s, m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s, m, x=[1.5, 2.0]) end function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)}) @@ -194,7 +194,7 @@ end m ~ product_distribution(Normal.(0, sqrt.(s))) x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -225,7 +225,7 @@ end end x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -248,7 +248,7 @@ end m ~ MvNormal(zero(x), Diagonal(s)) x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -279,7 +279,7 @@ end x[i] ~ Normal(m[i], sqrt(s[i])) end - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -304,7 +304,7 @@ end m ~ Normal(0, sqrt(s)) x 
.~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -327,7 +327,7 @@ end m ~ MvNormal(zeros(2), Diagonal(s)) [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -358,7 +358,7 @@ end 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -384,7 +384,7 @@ end 1.5 ~ Normal(m, sqrt(s)) 2.0 ~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -407,7 +407,7 @@ end m ~ Normal(0, sqrt(s)) [1.5, 2.0] .~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -440,7 +440,7 @@ end 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m @@ -476,9 +476,9 @@ end # Submodel likelihood # With to_submodel, we 
have to have a left-hand side variable to # capture the result, so we just use a dummy variable - _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x)) + _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x), false) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -505,7 +505,7 @@ end x[:, 1] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -535,7 +535,7 @@ end x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m) n = length(model.args.x) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 539872143..07a308c7a 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -37,12 +37,6 @@ function setup_varinfos( svi_untyped = SimpleVarInfo(OrderedDict()) svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - # SimpleVarInfo{<:Any,<:Ref} - svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed))) - svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) - svi_vnv_ref = SimpleVarInfo(DynamicPPL.VarNamedVector(), Ref(getlogp(svi_vnv))) - - lp = getlogp(vi_typed_metadata) varinfos = map(( vi_untyped_metadata, vi_untyped_vnv, @@ -51,12 +45,10 @@ function setup_varinfos( svi_typed, svi_untyped, svi_vnv, - svi_typed_ref, - svi_untyped_ref, - svi_vnv_ref, )) do vi - # Set them all to the same values. 
- DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) + # Set them all to the same values and evaluate logp. + vi = update_values!!(vi, example_values, varnames) + last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) end if include_threadsafe diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 2dc2645de..7d2d768a6 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -2,69 +2,79 @@ ThreadSafeVarInfo A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an -array of log probabilities for thread-safe execution of a probabilistic model. +array of accumulators for thread-safe execution of a probabilistic model. """ -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo +struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo varinfo::V - logps::L + accs_by_thread::Vector{L} end function ThreadSafeVarInfo(vi::AbstractVarInfo) - return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) + accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.nthreads()] + return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi -const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{ - V,<:AbstractArray{<:Ref} -} - transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) -# Instead of updating the log probability of the underlying variables we -# just update the array of log probabilities. -function acclogp!!(vi::ThreadSafeVarInfo, logp) - vi.logps[Threads.threadid()] += logp - return vi +# Set the accumulator in question in vi.varinfo, and set the thread-specific +# accumulators of the same type to be empty. 
+function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator) + inner_vi = setacc!!(vi.varinfo, acc) + new_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread) + return ThreadSafeVarInfo(inner_vi, new_accs_by_thread) end -function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp) - vi.logps[Threads.threadid()][] += logp - return vi + +# Get both the main accumulator and the thread-specific accumulators of the same type and +# combine them. +function getacc(vi::ThreadSafeVarInfo, accname::Val) + main_acc = getacc(vi.varinfo, accname) + other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread) + return foldl(combine, other_accs; init=main_acc) end -# The current log probability of the variables has to be computed from -# both the wrapped variables and the thread-specific log probabilities. -getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps) -getlogp(vi::ThreadSafeVarInfoWithRef) = getlogp(vi.varinfo) + sum(getindex, vi.logps) +hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) +acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) -# TODO: Make remaining methods thread-safe. -function resetlogp!!(vi::ThreadSafeVarInfo) - return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), zero(vi.logps)) +function getaccs(vi::ThreadSafeVarInfo) + # This method is a bit finicky to maintain type stability. For instance, moving the + # accname -> Val(accname) part in the main `map` call makes constant propagation fail + # and this becomes unstable. Do check the effects if you make edits.
+ accnames = acckeys(vi) + accname_vals = map(Val, accnames) + return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals)) end -function resetlogp!!(vi::ThreadSafeVarInfoWithRef) - for x in vi.logps - x[] = zero(x[]) - end - return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps) -end -function setlogp!!(vi::ThreadSafeVarInfo, logp) - return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), zero(vi.logps)) + +# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that +# should _not_ be thread-specific, a dedicated method has to be written. +function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val) + tid = Threads.threadid() + vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname) + return vi end -function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) - for x in vi.logps - x[] = zero(x[]) - end - return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) + +function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) + tid = Threads.threadid() + vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid]) + return vi end -has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) +has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) end +# TODO(mhauru) Why these short-circuits? Why not use the thread-specific ones?
get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) -increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo) -reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo) -set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n) +function increment_num_produce!!(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(increment_num_produce!!(vi.varinfo), vi.accs_by_thread) +end +function reset_num_produce!!(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(reset_num_produce!!(vi.varinfo), vi.accs_by_thread) +end +function set_num_produce!!(vi::ThreadSafeVarInfo, n::Int) + return ThreadSafeVarInfo(set_num_produce!!(vi.varinfo, n), vi.accs_by_thread) +end syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) @@ -94,8 +104,8 @@ end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates -# to define `getlogp(vi)`. +# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates +# to define `getacc(vi)`. function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end @@ -130,9 +140,9 @@ end function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the - # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in - # the `getlogp(vi)`. + # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the + # `getacc(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in + # the `getlogprior(vi)`. 
return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) end @@ -169,6 +179,23 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo) return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end +function resetlogp!!(vi::ThreadSafeVarInfo) + vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) + for i in eachindex(vi.accs_by_thread) + if hasacc(vi, Val(:LogPrior)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogPrior) + ) + end + if hasacc(vi, Val(:LogLikelihood)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogLikelihood) + ) + end + end + return vi +end + values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) diff --git a/src/transforming.jl b/src/transforming.jl index 429562ec8..ddd1ab59f 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -27,18 +27,47 @@ function tilde_assume( # Only transform if `!isinverse` since `vi[vn, right]` # already performs the inverse transformation if it's transformed. r_transformed = isinverse ? 
r : link_transform(right)(r) - return r, lp, setindex!!(vi, r_transformed, vn) + if hasacc(vi, Val(:LogPrior)) + vi = acclogprior!!(vi, lp) + end + return r, setindex!!(vi, r_transformed, vn) +end + +function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) + return _transform!!(t, DynamicTransformationContext{false}(), vi, model) end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return settrans!!( - last(evaluate!!(model, vi, DynamicTransformationContext{true}())), - NoTransformation(), - ) + return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model) +end + +function _transform!!( + t::AbstractTransformation, + ctx::DynamicTransformationContext, + vi::AbstractVarInfo, + model::Model, +) + # To transform using DynamicTransformationContext, we evaluate the model, but we do not + # need to use any accumulators other than LogPriorAccumulator (which is affected by the Jacobian of + # the transformation). + accs = getaccs(vi) + has_logprior = haskey(accs, Val(:LogPrior)) + if has_logprior + old_logprior = getacc(accs, Val(:LogPrior)) + vi = setaccs!!(vi, (old_logprior,)) + end + vi = settrans!!(last(evaluate!!(model, vi, ctx)), t) + # Restore the accumulators. + if has_logprior + new_logprior = getacc(vi, Val(:LogPrior)) + accs = setacc!!(accs, new_logprior) + end + vi = setaccs!!(vi, accs) + return vi end function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) diff --git a/src/utils.jl b/src/utils.jl index 71919480c..a141148a0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,7 +18,85 @@ const LogProbType = float(Real) """ @addlogprob!(ex) -Add the result of the evaluation of `ex` to the joint log probability. 
+Add a term to the log joint. + +If `ex` evaluates to a `NamedTuple` with keys `:loglikelihood` and/or `:logprior`, the +values are added to the log likelihood and log prior respectively. + +If `ex` evaluates to a number it is added to the log likelihood. This use is deprecated +and should be replaced with either the `NamedTuple` version or calls to +[`@addloglikelihood!`](@ref). + +See also [`@addloglikelihood!`](@ref), [`@addlogprior!`](@ref). + +# Examples + +```jldoctest; setup = :(using Distributions) +julia> mylogjoint(x, μ) = (; loglikelihood=loglikelihood(Normal(μ, 1), x), logprior=1.0); + +julia> @model function demo(x) + μ ~ Normal() + @addlogprob! mylogjoint(x, μ) + end; + +julia> x = [1.3, -2.1]; + +julia> loglikelihood(demo(x), (μ=0.2,)) ≈ mylogjoint(x, 0.2).loglikelihood +true + +julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogjoint(x, 0.2).logprior +true +``` + +and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328): + +```jldoctest; setup = :(using Distributions, LinearAlgebra) +julia> @model function demo(x) + m ~ MvNormal(zero(x), I) + if dot(m, x) < 0 + @addlogprob! (; loglikelihood=-Inf) + # Exit the model evaluation early + return + end + x ~ MvNormal(m, I) + return + end; + +julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf +true +``` +""" +macro addlogprob!(ex) + return quote + val = $(esc(ex)) + vi = $(esc(:(__varinfo__))) + if val isa Number + Base.depwarn( + """ + @addlogprob! with a single number argument is deprecated. Please use + @addlogprob! (; loglikelihood=x) or @addloglikelihood! instead. 
+ """, + :addlogprob!, + ) + if hasacc(vi, Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), val) + end + elseif !isa(val, NamedTuple) + error("logp must be a NamedTuple.") + else + $(esc(:(__varinfo__))) = acclogp!!( + $(esc(:(__varinfo__))), val; ignore_missing_accumulator=true + ) + end + end +end + +""" + @addloglikelihood!(ex) + +Add the result of the evaluation of `ex` to the log likelihood. + +See also [`@addlogprob!`](@ref), [`@addlogprior!`](@ref). # Examples @@ -29,7 +107,7 @@ julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); julia> @model function demo(x) μ ~ Normal() - @addlogprob! myloglikelihood(x, μ) + @addloglikelihood! myloglikelihood(x, μ) end; julia> x = [1.3, -2.1]; @@ -44,7 +122,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328): julia> @model function demo(x) m ~ MvNormal(zero(x), I) if dot(m, x) < 0 - @addlogprob! -Inf + @addloglikelihood! -Inf # Exit the model evaluation early return end @@ -55,37 +133,45 @@ julia> @model function demo(x) julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf true ``` +""" +macro addloglikelihood!(ex) + return quote + if hasacc($(esc(:(__varinfo__))), Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex))) + end + end +end -!!! note - The `@addlogprob!` macro increases the accumulated log probability regardless of the evaluation context, - i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density. - If you would like to avoid this behaviour you should check the evaluation context. - It can be accessed with the internal variable `__context__`. 
- For instance, in the following example the log density is not accumulated when only the log prior is computed: - ```jldoctest; setup = :(using Distributions) - julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); +""" + @addlogprior!(ex) + +Add the result of the evaluation of `ex` to the log prior. + +See also [`@addloglikelihood!`](@ref), [`@addlogprob!`](@ref). + +# Examples + +This macro allows you to include arbitrary terms in the prior. - julia> @model function demo(x) - μ ~ Normal() - if DynamicPPL.leafcontext(__context__) !== PriorContext() - @addlogprob! myloglikelihood(x, μ) - end - end; +```jldoctest; setup = :(using Distributions) +julia> mylogpriorextraterm(μ) = μ > 0 ? -1.0 : 0.0; - julia> x = [1.3, -2.1]; +julia> @model function demo(x) + μ ~ Normal() + @addlogprior! mylogpriorextraterm(μ) + end; - julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) - true +julia> x = [1.3, -2.1]; - julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2) - true - ``` +julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogpriorextraterm(0.2) +true +``` """ -macro addlogprob!(ex) +macro addlogprior!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!!( - $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex)) - ) + if hasacc($(esc(:(__varinfo__))), Val(:LogPrior)) + $(esc(:(__varinfo__))) = acclogprior!!($(esc(:(__varinfo__))), $(esc(ex))) + end end end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index d3bfd697a..3ec474940 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -65,29 +65,24 @@ end function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) if is_tracked_value(right) value = right.value - logp = zero(getlogp(vi)) else - value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) + value, vi = tilde_assume(childcontext(context), right, vn, vi) end - # Save the value. push!(context, vn, value) - # Save the value. - # Pass on. 
- return value, logp, vi + return value, vi end function tilde_assume( rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi ) if is_tracked_value(right) value = right.value - logp = zero(getlogp(vi)) else - value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) end # Save the value. push!(context, vn, value) # Pass on. - return value, logp, vi + return value, vi end """ diff --git a/src/varinfo.jl b/src/varinfo.jl index 360857ef7..ec55f6476 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -69,10 +69,9 @@ end ########### """ - struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end A light wrapper over some kind of metadata. @@ -98,12 +97,21 @@ Note that for NTVarInfo, it is the user's responsibility to ensure that each symbol is visited at least once during model evaluation, regardless of any stochastic branching. 
""" -struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo +struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +function VarInfo(meta=Metadata()) + return VarInfo( + meta, + AccumulatorTuple( + LogPriorAccumulator{LogProbType}(), + LogLikelihoodAccumulator{LogProbType}(), + NumProduceAccumulator{Int}(), + ), + ) +end + """ VarInfo([rng, ]model[, sampler, context]) @@ -285,10 +293,8 @@ function typed_varinfo(vi::UntypedVarInfo) ), ) end - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, deepcopy(vi.accs)) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -349,8 +355,7 @@ single `VarNamedVector` as its metadata field. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) + return VarInfo(md, deepcopy(vi.accs)) end function untyped_vector_varinfo( rng::Random.AbstractRNG, @@ -393,15 +398,12 @@ NamedTuple of `VarNamedVector`s as its metadata field. 
""" function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) + return VarInfo(md, deepcopy(vi.accs)) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, deepcopy(vi.accs)) end function typed_vector_varinfo( rng::Random.AbstractRNG, @@ -441,13 +443,17 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases - # where e.g. x is a type gradient of some AD backend. - return VarInfo( - md, - Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)), - Ref(get_num_produce(vi)), - ) + # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is + # a gradient type of some AD backend. + # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! + # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but + # the accumulators in the VarInfo are plain floats, we error since we can't change the + # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here + # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just + # plain ugly and hacky. 
+ T = float_type_with_fallback(eltype(x)) + accs = map(acc -> convert_eltype(T, acc), deepcopy(getaccs(vi))) + return VarInfo(md, accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in @@ -529,7 +535,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) + return VarInfo(metadata, deepcopy(varinfo.accs)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -618,9 +624,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo( - metadata, Ref(getlogp(varinfo_right)), Ref(get_num_produce(varinfo_right)) - ) + return VarInfo(metadata, deepcopy(varinfo_right.accs)) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -973,8 +977,8 @@ end function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!!(vi) - reset_num_produce!(vi) + vi = resetlogp!!(vi) + vi = reset_num_produce!!(vi) return vi end @@ -1008,46 +1012,37 @@ end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") -getlogp(vi::VarInfo) = vi.logp[] - -function setlogp!!(vi::VarInfo, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::VarInfo, logp) - vi.logp[] += logp - return vi -end +getaccs(vi::VarInfo) = vi.accs +setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs """ get_num_produce(vi::VarInfo) Return the `num_produce` of `vi`. """ -get_num_produce(vi::VarInfo) = vi.num_produce[] +get_num_produce(vi::VarInfo) = getacc(vi, Val(:NumProduce)).num """ - set_num_produce!(vi::VarInfo, n::Int) + set_num_produce!!(vi::VarInfo, n::Int) Set the `num_produce` field of `vi` to `n`. 
""" -set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n +set_num_produce!!(vi::VarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n)) """ - increment_num_produce!(vi::VarInfo) + increment_num_produce!!(vi::VarInfo) Add 1 to `num_produce` in `vi`. """ -increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 +increment_num_produce!!(vi::VarInfo) = map_accumulator!!(increment, vi, Val(:NumProduce)) """ - reset_num_produce!(vi::VarInfo) + reset_num_produce!!(vi::VarInfo) Reset the value of `num_produce` the log of the joint probability of the observed data and parameters sampled in `vi` to 0. """ -reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) +reset_num_produce!!(vi::VarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce)) # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). isempty(vi::VarInfo) = _isempty(vi.metadata) @@ -1061,7 +1056,7 @@ function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1069,7 +1064,7 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model) vns = keys(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1082,8 +1077,7 @@ end function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - # Call `_link!` instead of `link!` to avoid deprecation warning. 
-    _link!(vi, vns)
+    vi = _link!!(vi, vns)
    return vi
end

@@ -1098,27 +1092,29 @@ function link!!(
    return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model)
end

-function _link!(vi::UntypedVarInfo, vns)
+function _link!!(vi::UntypedVarInfo, vns)
    # TODO: Change to a lazy iterator over `vns`
    if ~istrans(vi, vns[1])
        for vn in vns
            f = internal_to_linked_internal_transform(vi, vn)
-            _inner_transform!(vi, vn, f)
-            settrans!!(vi, true, vn)
+            vi = _inner_transform!(vi, vn, f)
+            vi = settrans!!(vi, true, vn)
        end
+        return vi
    else
        @warn("[DynamicPPL] attempt to link a linked vi")
+        return vi
    end
end

-# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a
+# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a
# NamedTuple that matches the structure of the NTVarInfo.
-function _link!(vi::NTVarInfo, vns::VarNameTuple)
-    return _link!(vi, group_varnames_by_symbol(vns))
+function _link!!(vi::NTVarInfo, vns::VarNameTuple)
+    return _link!!(vi, group_varnames_by_symbol(vns))
end

-function _link!(vi::NTVarInfo, vns::NamedTuple)
-    return _link!(vi.metadata, vi, vns)
+function _link!!(vi::NTVarInfo, vns::NamedTuple)
+    return _link!!(vi.metadata, vi, vns)
end

"""
@@ -1130,7 +1125,7 @@ function filter_subsumed(filter_vns, filtered_vns)
    return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns)
end

-@generated function _link!(
+@generated function _link!!(
    ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
) where {metadata_names,vns_names}
    expr = Expr(:block)
@@ -1148,8 +1143,8 @@ end
                # Iterate over all `f_vns` and transform
                for vn in f_vns
                    f = internal_to_linked_internal_transform(vi, vn)
-                    _inner_transform!(vi, vn, f)
-                    settrans!!(vi, true, vn)
+                    vi = _inner_transform!(vi, vn, f)
+                    vi = settrans!!(vi, true, vn)
            else
                @warn("[DynamicPPL] attempt to link a linked vi")
@@ -1158,6 +1153,7 @@ end
            end,
        )
    end
+    push!(expr.args, :(return vi))
    return expr
end

@@ -1165,8 +1161,7 @@ function
invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model)
    vns = all_varnames_grouped_by_symbol(vi)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    # Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
    return vi
end

@@ -1174,7 +1169,7 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model)
    vns = keys(vi)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
    return vi
end

@@ -1187,8 +1182,7 @@ end
function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    # Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
    return vi
end

@@ -1211,29 +1205,31 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model)
    return maybe_invlink_before_eval!!(t, vi, model)
end

-function _invlink!(vi::UntypedVarInfo, vns)
+function _invlink!!(vi::UntypedVarInfo, vns)
    if istrans(vi, vns[1])
        for vn in vns
            f = linked_internal_to_internal_transform(vi, vn)
-            _inner_transform!(vi, vn, f)
-            settrans!!(vi, false, vn)
+            vi = _inner_transform!(vi, vn, f)
+            vi = settrans!!(vi, false, vn)
        end
+        return vi
    else
        @warn("[DynamicPPL] attempt to invlink an invlinked vi")
+        return vi
    end
end

-# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a
+# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a
# NamedTuple that matches the structure of the NTVarInfo.
-function _invlink!(vi::NTVarInfo, vns::VarNameTuple) - return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) +function _invlink!!(vi::NTVarInfo, vns::VarNameTuple) + return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::NTVarInfo, vns::NamedTuple) - return _invlink!(vi.metadata, vi, vns) +function _invlink!!(vi::NTVarInfo, vns::NamedTuple) + return _invlink!!(vi.metadata, vi, vns) end -@generated function _invlink!( +@generated function _invlink!!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) @@ -1251,8 +1246,8 @@ end # Iterate over all `f_vns` and transform for vn in f_vns f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1260,6 +1255,7 @@ end end, ) end + push!(expr.args, :(return vi)) return expr end @@ -1276,7 +1272,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. 
setval!(md, yvec, vn) - acclogp!!(vi, -logjac) + vi = acclogprior!!(vi, -logjac) return vi end @@ -1311,8 +1307,10 @@ end function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end # If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1323,8 +1321,10 @@ end function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _link_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end @generated function _link_metadata!( @@ -1333,20 +1333,39 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if f in vns_names - push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return :(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + NamedTuple{$metadata_names}($mds), cumulative_logjac + end, + ) + return expr end function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = 
zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1364,7 +1383,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ # Vectorize value. yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. - acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as transformed. settrans!!(varinfo, true, vn) # Return the vectorized transformed value. @@ -1389,7 +1408,8 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ metadata.dists, metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _link_metadata!!( @@ -1397,6 +1417,7 @@ function _link_metadata!!( ) vns = target_vns === nothing ? keys(metadata) : target_vns dists = extract_priors(model, varinfo) + cumulative_logjac = zero(LogProbType) for vn in vns # First transform from however the variable is stored in vnv to the model # representation. @@ -1409,11 +1430,11 @@ function _link_metadata!!( val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig) # TODO(mhauru) We are calling a !! function but ignoring the return value. # Fix this when attending to issue #653. 
- acclogp!!(varinfo, -logjac1 - logjac2) + cumulative_logjac += logjac1 + logjac2 metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) settrans!(metadata, true, vn) end - return metadata + return metadata, cumulative_logjac end function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) @@ -1449,11 +1470,10 @@ end function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end # If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1464,8 +1484,10 @@ end function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end @generated function _invlink_metadata!( @@ -1474,20 +1496,41 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if (f in vns_names) - push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _invlink_metadata!!( + model, varinfo, metadata.$f, vns.$f + ) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return 
:(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + (NamedTuple{$metadata_names}($mds), cumulative_logjac) + end, + ) + return expr end function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1506,7 +1549,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. - acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. @@ -1531,24 +1574,26 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ metadata.dists, metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _invlink_metadata!!( ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns + cumulative_logjac = zero(LogProbType) for vn in vns transform = gettransform(metadata, vn) old_val = getindex_internal(metadata, vn) new_val, logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. 
- acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) settrans!(metadata, false, vn) end - return metadata + return metadata, cumulative_logjac end # TODO(mhauru) The treatment of the case when some variables are linked and others are not @@ -1705,19 +1750,35 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) end function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - vi_str = """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - | Varnames : $(string(vi.metadata.vns)) - | Range : $(vi.metadata.ranges) - | Vals : $(vi.metadata.vals) - | Orders : $(vi.metadata.orders) - | Logp : $(getlogp(vi)) - | #produce : $(get_num_produce(vi)) - | flags : $(vi.metadata.flags) - \\======================================================================= - """ + lines = Tuple{String,Any}[ + ("VarNames", vi.metadata.vns), + ("Range", vi.metadata.ranges), + ("Vals", vi.metadata.vals), + ("Orders", vi.metadata.orders), + ] + for accname in acckeys(vi) + push!(lines, (string(accname), getacc(vi, Val(accname)))) + end + push!(lines, ("flags", vi.metadata.flags)) + max_name_length = maximum(map(length ∘ first, lines)) + fmt = Printf.Format("%-$(max_name_length)s") + vi_str = ( + """ + /======================================================================= + | VarInfo + |----------------------------------------------------------------------- + """ * + prod( + map(lines) do (name, value) + """ + | $(Printf.format(fmt, name)) : $(value) + """ + end, + ) * + """ + \\======================================================================= + """ + ) return print(io, vi_str) end @@ -1747,7 +1808,11 @@ end function Base.show(io::IO, vi::UntypedVarInfo) print(io, "VarInfo (") _show_varnames(io, vi) - print(io, "; logp: ", round(getlogp(vi); 
digits=3))
+    print(io, "; accumulators: ")
+    # TODO(mhauru) This uses "text/plain" because we are doing quite a condensed representation
+    # of vi anyway. However, technically `show(io, x)` should give full details of x and
+    # preferably output valid Julia code.
+    show(io, MIME"text/plain"(), getaccs(vi))
    return print(io, ")")
end

diff --git a/test/accumulators.jl b/test/accumulators.jl
new file mode 100644
index 000000000..36bb95e46
--- /dev/null
+++ b/test/accumulators.jl
@@ -0,0 +1,176 @@
+module AccumulatorTests
+
+using Test
+using Distributions
+using DynamicPPL
+using DynamicPPL:
+    AccumulatorTuple,
+    LogLikelihoodAccumulator,
+    LogPriorAccumulator,
+    NumProduceAccumulator,
+    accumulate_assume!!,
+    accumulate_observe!!,
+    combine,
+    convert_eltype,
+    getacc,
+    increment,
+    map_accumulator,
+    setacc!!,
+    split
+
+@testset "accumulators" begin
+    @testset "individual accumulator types" begin
+        @testset "constructors" begin
+            @test LogPriorAccumulator(0.0) ==
+                LogPriorAccumulator() ==
+                LogPriorAccumulator{Float64}() ==
+                LogPriorAccumulator{Float64}(0.0) ==
+                zero(LogPriorAccumulator(1.0))
+            @test LogLikelihoodAccumulator(0.0) ==
+                LogLikelihoodAccumulator() ==
+                LogLikelihoodAccumulator{Float64}() ==
+                LogLikelihoodAccumulator{Float64}(0.0) ==
+                zero(LogLikelihoodAccumulator(1.0))
+            @test NumProduceAccumulator(0) ==
+                NumProduceAccumulator() ==
+                NumProduceAccumulator{Int}() ==
+                NumProduceAccumulator{Int}(0) ==
+                zero(NumProduceAccumulator(1))
+        end
+
+        @testset "addition and incrementation" begin
+            @test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) ==
+                LogPriorAccumulator(2.0f0)
+            @test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) ==
+                LogPriorAccumulator(2.0)
+            @test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) ==
+                LogLikelihoodAccumulator(2.0f0)
+            @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) ==
+                LogLikelihoodAccumulator(2.0)
+            @test increment(NumProduceAccumulator()) ==
NumProduceAccumulator(1)
+            @test increment(NumProduceAccumulator{UInt8}()) ==
+                NumProduceAccumulator{UInt8}(1)
+        end
+
+        @testset "split and combine" begin
+            for acc in [
+                LogPriorAccumulator(1.0),
+                LogLikelihoodAccumulator(1.0),
+                NumProduceAccumulator(1),
+                LogPriorAccumulator(1.0f0),
+                LogLikelihoodAccumulator(1.0f0),
+                NumProduceAccumulator(UInt8(1)),
+            ]
+                @test combine(acc, split(acc)) == acc
+            end
+        end
+
+        @testset "conversions" begin
+            @test convert(LogPriorAccumulator{Float32}, LogPriorAccumulator(1.0)) ==
+                LogPriorAccumulator{Float32}(1.0f0)
+            @test convert(
+                LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0)
+            ) == LogLikelihoodAccumulator{Float32}(1.0f0)
+            @test convert(NumProduceAccumulator{UInt8}, NumProduceAccumulator(1)) ==
+                NumProduceAccumulator{UInt8}(1)
+
+            @test convert_eltype(Float32, LogPriorAccumulator(1.0)) ==
+                LogPriorAccumulator{Float32}(1.0f0)
+            @test convert_eltype(Float32, LogLikelihoodAccumulator(1.0)) ==
+                LogLikelihoodAccumulator{Float32}(1.0f0)
+        end
+
+        @testset "accumulate_assume" begin
+            val = 2.0
+            logjac = pi
+            vn = @varname(x)
+            dist = Normal()
+            @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) ==
+                LogPriorAccumulator(1.0 + logjac + logpdf(dist, val))
+            @test accumulate_assume!!(
+                LogLikelihoodAccumulator(1.0), val, logjac, vn, dist
+            ) == LogLikelihoodAccumulator(1.0)
+            @test accumulate_assume!!(NumProduceAccumulator(1), val, logjac, vn, dist) ==
+                NumProduceAccumulator(1)
+        end
+
+        @testset "accumulate_observe" begin
+            right = Normal()
+            left = 2.0
+            vn = @varname(x)
+            @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) ==
+                LogPriorAccumulator(1.0)
+            @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) ==
+                LogLikelihoodAccumulator(1.0 + logpdf(right, left))
+            @test accumulate_observe!!(NumProduceAccumulator(1), right, left, vn) ==
+                NumProduceAccumulator(2)
+        end
+    end
+
+    @testset "accumulator tuples" begin
+        # Some accumulators we'll use for testing
+        lp_f64 = LogPriorAccumulator(1.0)
+        lp_f32 = LogPriorAccumulator(1.0f0)
+        ll_f64 = LogLikelihoodAccumulator(1.0)
+        ll_f32 = LogLikelihoodAccumulator(1.0f0)
+        np_i64 = NumProduceAccumulator(1)
+
+        @testset "constructors" begin
+            @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64))
+            # Names in NamedTuple arguments are ignored
+            @test AccumulatorTuple((; a=lp_f64)) == AccumulatorTuple(lp_f64)
+
+            # Can't have two accumulators of the same type.
+            @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f64)
+            # Not even if their element types differ.
+            @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f32)
+        end
+
+        @testset "basic operations" begin
+            at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64)
+
+            @test at_all64[:LogPrior] == lp_f64
+            @test at_all64[:LogLikelihood] == ll_f64
+            @test at_all64[:NumProduce] == np_i64
+
+            @test haskey(AccumulatorTuple(np_i64), Val(:NumProduce))
+            @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior))
+            @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3
+            @test keys(at_all64) == (:LogPrior, :LogLikelihood, :NumProduce)
+            @test collect(at_all64) == [lp_f64, ll_f64, np_i64]
+
+            # Replace the existing LogPriorAccumulator
+            @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32
+            # Check that setacc!! didn't modify the original
+            @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64)
+            # Add a new accumulator type.
+            @test setacc!!(AccumulatorTuple(lp_f64), ll_f64) ==
+                AccumulatorTuple(lp_f64, ll_f64)
+
+            @test getacc(at_all64, Val(:LogPrior)) == lp_f64
+        end
+
+        @testset "map_accumulator(s)!!" begin
+            # map over all accumulators
+            accs = AccumulatorTuple(lp_f32, ll_f32)
+            @test map(zero, accs) == AccumulatorTuple(
+                LogPriorAccumulator(0.0f0), LogLikelihoodAccumulator(0.0f0)
+            )
+            # Test that the original wasn't modified.
+            @test accs == AccumulatorTuple(lp_f32, ll_f32)
+
+            # A map with a closure that changes the types of the accumulators.
+            @test map(acc -> convert_eltype(Float64, acc), accs) ==
+                AccumulatorTuple(LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0))
+
+            # only apply to a particular accumulator
+            @test map_accumulator(zero, accs, Val(:LogLikelihood)) ==
+                AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(0.0f0))
+            @test map_accumulator(
+                acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood)
+            ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0))
+        end
+    end
+end
+
+end
diff --git a/test/compiler.jl b/test/compiler.jl
index a0286d405..81c018111 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -189,12 +189,12 @@ module Issue537 end
             global model_ = __model__
             global context_ = __context__
             global rng_ = __context__.rng
-            global lp = getlogp(__varinfo__)
+            global lp = getlogjoint(__varinfo__)
             return x
         end
         model = testmodel_missing3([1.0])
         varinfo = VarInfo(model)
-        @test getlogp(varinfo) == lp
+        @test getlogjoint(varinfo) == lp
         @test varinfo_ isa AbstractVarInfo
         @test model_ === model
         @test context_ isa SamplingContext
@@ -208,13 +208,13 @@ module Issue537 end
             global model_ = __model__
             global context_ = __context__
             global rng_ = __context__.rng
-            global lp = getlogp(__varinfo__)
+            global lp = getlogjoint(__varinfo__)
             return x
         end false
         lpold = lp
         model = testmodel_missing4([1.0])
         varinfo = VarInfo(model)
-        @test getlogp(varinfo) == lp == lpold
+        @test getlogjoint(varinfo) == lp == lpold

         # test DPPL#61
         @model function testmodel_missing5(z)
@@ -333,14 +333,14 @@ module Issue537 end
         function makemodel(p)
             @model function testmodel(x)
                 x[1] ~ Bernoulli(p)
-                global lp = getlogp(__varinfo__)
+                global lp = getlogjoint(__varinfo__)
                 return x
             end
             return testmodel
         end
         model = makemodel(0.5)([1.0])
         varinfo = VarInfo(model)
-        @test getlogp(varinfo) == lp
+        @test getlogjoint(varinfo) == lp
     end
     @testset "user-defined variable name" begin
         @model f1() = x ~ NamedDist(Normal(), :y)
@@ -364,9 +364,9 @@ module Issue537 end
         # TODO(torfjelde): We need conditioning for `Dict`.
         @test_broken f2_c() == 1
         @test_broken f3_c() == 1
-        @test_broken getlogp(VarInfo(f1_c)) ==
-            getlogp(VarInfo(f2_c)) ==
-            getlogp(VarInfo(f3_c))
+        @test_broken getlogjoint(VarInfo(f1_c)) ==
+            getlogjoint(VarInfo(f2_c)) ==
+            getlogjoint(VarInfo(f3_c))
     end
     @testset "custom tilde" begin
         @model demo() = begin
diff --git a/test/context_implementations.jl b/test/context_implementations.jl
index 0ec88c07c..ac6321d69 100644
--- a/test/context_implementations.jl
+++ b/test/context_implementations.jl
@@ -10,7 +10,7 @@
             end
        end

-        test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext())
+        test([1, 1, -1])(VarInfo(), SampleFromPrior(), DefaultContext())
     end

     @testset "dot tilde with varying sizes" begin
@@ -18,13 +18,14 @@
         @model function test(x, size)
             y = Array{Float64,length(size)}(undef, size...)
             y .~ Normal(x)
-            return y, getlogp(__varinfo__)
+            return y
         end

         for ysize in ((2,), (2, 3), (2, 3, 4))
             x = randn()
             model = test(x, ysize)

-            y, lp = model()
+            y = model()
+            lp = logjoint(model, (; y=y))
             @test lp ≈ sum(logpdf.(Normal.(x), y))

             ys = [first(model()) for _ in 1:10_000]
diff --git a/test/contexts.jl b/test/contexts.jl
index 1ba099a37..5f22b75eb 100644
--- a/test/contexts.jl
+++ b/test/contexts.jl
@@ -9,7 +9,6 @@ using DynamicPPL:
     NodeTrait,
     IsLeaf,
     IsParent,
-    PointwiseLogdensityContext,
     contextual_isassumption,
     FixedContext,
     ConditionContext,
@@ -47,18 +46,11 @@ Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown()
 Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()

 @testset "contexts.jl" begin
-    child_contexts = Dict(
+    contexts = Dict(
         :default => DefaultContext(),
-        :prior => PriorContext(),
-        :likelihood => LikelihoodContext(),
-    )
-
-    parent_contexts = Dict(
         :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()),
         :sampling => SamplingContext(),
-        :minibatch => MiniBatchContext(DefaultContext(), 0.0),
         :prefix => PrefixContext(@varname(x)),
-        :pointwiselogdensity => PointwiseLogdensityContext(),
         :condition1 => ConditionContext((x=1.0,)),
         :condition2 => ConditionContext(
             (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,)))
@@ -70,8 +62,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
         :condition4 => ConditionContext((x=[1.0, missing],)),
     )

-    contexts = merge(child_contexts, parent_contexts)
-
     @testset "$(name)" for (name, context) in contexts
         @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
             DynamicPPL.TestUtils.test_context(context, model)
@@ -235,7 +225,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
         # Values from outer context should override inner one
         ctx1 = ConditionContext(n1, ConditionContext(n2))
         @test ctx1.values == (x=1, y=2)
-        # Check that the two ConditionContexts are collapsed
+        # Check that the two ConditionContexts are collapsed
         @test childcontext(ctx1) isa DefaultContext
         # Then test the nesting the other way round
         ctx2 = ConditionContext(n2, ConditionContext(n1))
diff --git a/test/independence.jl b/test/independence.jl
deleted file mode 100644
index a4a834a61..000000000
--- a/test/independence.jl
+++ /dev/null
@@ -1,11 +0,0 @@
-@testset "Turing independence" begin
-    @model coinflip(y) = begin
-        p ~ Beta(1, 1)
-        N = length(y)
-        for i in 1:N
-            y[i] ~ Bernoulli(p)
-        end
-    end
-    model = coinflip([1, 1, 0])
-    model(SampleFromPrior(), LikelihoodContext())
-end
diff --git a/test/linking.jl b/test/linking.jl
index d424a9c2d..4f1707263 100644
--- a/test/linking.jl
+++ b/test/linking.jl
@@ -85,7 +85,7 @@ end
             DynamicPPL.link(vi, model)
         end
         # Difference should just be the log-absdet-jacobian "correction".
-        @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2)
+        @test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) ≈ log(2)
         @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist])
         # Linked one should be working with a lower-dimensional representation.
         @test length(vi_linked[:]) < length(vi[:])
@@ -98,7 +98,7 @@ end
         end
         @test length(vi_invlinked[:]) == length(vi[:])
         @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist])
-        @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi)
+        @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi)
     end
 end
@@ -130,7 +130,7 @@ end
         end
         @test length(vi_linked[:]) == d * (d - 1) ÷ 2
         # Should now include the log-absdet-jacobian correction.
-        @test !(getlogp(vi_linked) ≈ lp)
+        @test !(getlogjoint(vi_linked) ≈ lp)
         # Invlinked.
         vi_invlinked = if mutable
             DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@@ -138,7 +138,7 @@ end
             DynamicPPL.invlink(vi_linked, model)
         end
         @test length(vi_invlinked[:]) == d^2
-        @test getlogp(vi_invlinked) ≈ lp
+        @test getlogjoint(vi_invlinked) ≈ lp
     end
 end
 end
@@ -164,7 +164,7 @@ end
         end
         @test length(vi_linked[:]) == d - 1
         # Should now include the log-absdet-jacobian correction.
-        @test !(getlogp(vi_linked) ≈ lp)
+        @test !(getlogjoint(vi_linked) ≈ lp)
         # Invlinked.
         vi_invlinked = if mutable
             DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@@ -172,7 +172,7 @@ end
             DynamicPPL.invlink(vi_linked, model)
         end
         @test length(vi_invlinked[:]) == d
-        @test getlogp(vi_invlinked) ≈ lp
+        @test getlogjoint(vi_invlinked) ≈ lp
     end
 end
 end
diff --git a/test/model.jl b/test/model.jl
index dd5a35fe6..6e4a24ae6 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -41,7 +41,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         m = vi[@varname(m)]

         # extract log pdf of variable object
-        lp = getlogp(vi)
+        lp = getlogjoint(vi)

         # log prior probability
         lprior = logprior(model, vi)
@@ -494,7 +494,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             varinfo_linked_result = last(
                 DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext())
             )
-            @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result)
+            @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result)
         end
     end
diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl
index 61c842638..cfb222b66 100644
--- a/test/pointwise_logdensities.jl
+++ b/test/pointwise_logdensities.jl
@@ -1,6 +1,4 @@
 @testset "logdensities_likelihoods.jl" begin
-    mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2)
-    mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
     @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
         example_values = DynamicPPL.TestUtils.rand_prior_true(model)
@@ -37,11 +35,6 @@
         lps = pointwise_logdensities(model, vi)
         logp = sum(sum, values(lps))
         @test logp ≈ (logprior_true + loglikelihood_true)
-
-        # Test that modifications of Setup are picked up
-        lps = pointwise_logdensities(model, vi, mod_ctx2)
-        logp = sum(sum, values(lps))
-        @test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4
     end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 72f33f2d0..4a9acf4e1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -49,13 +49,13 @@
             include("test_util.jl")
             include("Aqua.jl")
         end
         include("utils.jl")
+        include("accumulators.jl")
         include("compiler.jl")
         include("varnamedvector.jl")
         include("varinfo.jl")
         include("simple_varinfo.jl")
         include("model.jl")
         include("sampler.jl")
-        include("independence.jl")
         include("distribution_wrappers.jl")
         include("logdensityfunction.jl")
         include("linking.jl")
diff --git a/test/sampler.jl b/test/sampler.jl
index 8c4f1ed96..fe9fd331a 100644
--- a/test/sampler.jl
+++ b/test/sampler.jl
@@ -84,7 +84,7 @@
             let inits = (; p=0.2)
                 chain = sample(model, sampler, 1; initial_params=inits, progress=false)
                 @test chain[1].metadata.p.vals == [0.2]
-                @test getlogp(chain[1]) == lptrue
+                @test getlogjoint(chain[1]) == lptrue

                 # parallel sampling
                 chains = sample(
@@ -98,7 +98,7 @@
                 )
                 for c in chains
                     @test c[1].metadata.p.vals == [0.2]
-                    @test getlogp(c[1]) == lptrue
+                    @test getlogjoint(c[1]) == lptrue
                 end
             end
@@ -113,7 +113,7 @@
                 chain = sample(model, sampler, 1; initial_params=inits, progress=false)
                 @test chain[1].metadata.s.vals == [4]
                 @test chain[1].metadata.m.vals == [-1]
-                @test getlogp(chain[1]) == lptrue
+                @test getlogjoint(chain[1]) == lptrue

                 # parallel sampling
                 chains = sample(
@@ -128,7 +128,7 @@
                 for c in chains
                     @test c[1].metadata.s.vals == [4]
                     @test c[1].metadata.m.vals == [-1]
-                    @test getlogp(c[1]) == lptrue
+                    @test getlogjoint(c[1]) == lptrue
                 end
             end
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
index 380c24e7d..6f2f39a64 100644
--- a/test/simple_varinfo.jl
+++ b/test/simple_varinfo.jl
@@ -2,12 +2,12 @@
     @testset "constructor & indexing" begin
         @testset "NamedTuple" begin
             svi = SimpleVarInfo(; m=1.0)
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test !haskey(svi, @varname(m[1]))

             svi = SimpleVarInfo(; m=[1.0])
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test haskey(svi, @varname(m[1]))
             @test !haskey(svi, @varname(m[2]))
@@ -21,20 +21,21 @@
             @test !haskey(svi, @varname(m.a.b))

             svi = SimpleVarInfo{Float32}(; m=1.0)
-            @test getlogp(svi) isa Float32
+            @test getlogjoint(svi) isa Float32

-            svi = SimpleVarInfo((m=1.0,), 1.0)
-            @test getlogp(svi) == 1.0
+            svi = SimpleVarInfo((m=1.0,))
+            svi = accloglikelihood!!(svi, 1.0)
+            @test getlogjoint(svi) == 1.0
         end

         @testset "Dict" begin
             svi = SimpleVarInfo(Dict(@varname(m) => 1.0))
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test !haskey(svi, @varname(m[1]))

             svi = SimpleVarInfo(Dict(@varname(m) => [1.0]))
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test haskey(svi, @varname(m[1]))
             @test !haskey(svi, @varname(m[2]))
@@ -59,12 +60,12 @@
         @testset "VarNamedVector" begin
             svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0))
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test !haskey(svi, @varname(m[1]))

             svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0]))
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test haskey(svi, @varname(m[1]))
             @test !haskey(svi, @varname(m[2]))
@@ -98,11 +99,10 @@
             vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn)
         end
         vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
-        lp_orig = getlogp(vi)

         # `link!!`
         vi_linked = link!!(deepcopy(vi), model)
-        lp_linked = getlogp(vi_linked)
+        lp_linked = getlogjoint(vi_linked)
         values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
             model, values_constrained...
         )
@@ -113,7 +113,7 @@

         # `invlink!!`
         vi_invlinked = invlink!!(deepcopy(vi_linked), model)
-        lp_invlinked = getlogp(vi_invlinked)
+        lp_invlinked = getlogjoint(vi_invlinked)
         lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true(
             model, values_constrained...
         )
@@ -152,7 +152,7 @@
             # DynamicPPL.settrans!!(deepcopy(svi_dict), true),
             # DynamicPPL.settrans!!(deepcopy(svi_vnv), true),
         )
-            # RandOM seed is set in each `@testset`, so we need to sample
+            # Random seed is set in each `@testset`, so we need to sample
             # a new realization for `m` here.
             retval = model()
@@ -166,7 +166,7 @@
             end

             # Logjoint should be non-zero wp. 1.
-            @test getlogp(svi_new) != 0
+            @test getlogjoint(svi_new) != 0

             ### Evaluation ###
             values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
@@ -201,7 +201,7 @@
                 svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn)
             end

-            # Reset the logp field.
+            # Reset the logp accumulators.
             svi_eval = DynamicPPL.resetlogp!!(svi_eval)

             # Compute `logjoint` using the varinfo.
@@ -250,7 +250,7 @@
             end

             # `getlogp` should be equal to the logjoint with log-absdet-jac correction.
-            lp = getlogp(svi)
+            lp = getlogjoint(svi)
             # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375
             @test lp ≈ lp_true atol = 1.2e-5
         end
@@ -306,7 +306,7 @@
                 DynamicPPL.tovec(retval_unconstrained.m)

             # The resulting varinfo should hold the correct logp.
-            lp = getlogp(vi_linked_result)
+            lp = getlogjoint(vi_linked_result)
             @test lp ≈ lp_true
         end
     end
diff --git a/test/submodels.jl b/test/submodels.jl
index e79eed2c3..d3a2f17e7 100644
--- a/test/submodels.jl
+++ b/test/submodels.jl
@@ -35,7 +35,7 @@ using Test
             @test model()[1] == x_val
             # Test that the logp was correctly set
             vi = VarInfo(model)
-            @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)])
+            @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)])
             # Check the keys
             @test Set(keys(VarInfo(model))) == Set([@varname(a.y)])
         end
@@ -67,7 +67,7 @@ using Test
             @test model()[1] == x_val
             # Test that the logp was correctly set
             vi = VarInfo(model)
-            @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)])
+            @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(y)])
             # Check the keys
             @test Set(keys(VarInfo(model))) == Set([@varname(y)])
         end
@@ -99,7 +99,7 @@ using Test
             @test model()[1] == x_val
             # Test that the logp was correctly set
             vi = VarInfo(model)
-            @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)])
+            @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)])
             # Check the keys
             @test Set(keys(VarInfo(model))) == Set([@varname(b.y)])
         end
@@ -148,7 +148,7 @@ using Test
             # No conditioning
             vi = VarInfo(h())
             @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)])
-            @test getlogp(vi) ==
+            @test getlogjoint(vi) ==
                 logpdf(Normal(), vi[@varname(a.b.x)]) +
                 logpdf(Normal(), vi[@varname(a.b.y)])
@@ -174,7 +174,7 @@ using Test
             @testset "$name" for (name, model) in models
                 vi = VarInfo(model)
                 @test Set(keys(vi)) == Set([@varname(a.b.y)])
-                @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
+                @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
             end
         end
     end
diff --git a/test/threadsafe.jl b/test/threadsafe.jl
index 72c439db8..5b4f6951f 100644
--- a/test/threadsafe.jl
+++ b/test/threadsafe.jl
@@ -4,9 +4,12 @@
         threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)

         @test threadsafe_vi.varinfo === vi
-        @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))}
-        @test length(threadsafe_vi.logps) == Threads.nthreads()
-        @test all(iszero(x[]) for x in threadsafe_vi.logps)
+        @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple}
+        @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads()
+        expected_accs = DynamicPPL.AccumulatorTuple(
+            (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))...
+        )
+        @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread)
     end

     # TODO: Add more tests of the public API
@@ -14,23 +17,27 @@
         vi = VarInfo(gdemo_default)
         threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi)

-        lp = getlogp(vi)
-        @test getlogp(threadsafe_vi) == lp
+        lp = getlogjoint(vi)
+        @test getlogjoint(threadsafe_vi) == lp

-        acclogp!!(threadsafe_vi, 42)
-        @test threadsafe_vi.logps[Threads.threadid()][] == 42
-        @test getlogp(vi) == lp
-        @test getlogp(threadsafe_vi) == lp + 42
+        threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42)
+        @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42
+        @test getlogjoint(vi) == lp
+        @test getlogjoint(threadsafe_vi) == lp + 42

-        resetlogp!!(threadsafe_vi)
-        @test iszero(getlogp(vi))
-        @test iszero(getlogp(threadsafe_vi))
-        @test all(iszero(x[]) for x in threadsafe_vi.logps)
+        threadsafe_vi = resetlogp!!(threadsafe_vi)
+        @test iszero(getlogjoint(threadsafe_vi))
+        expected_accs = DynamicPPL.AccumulatorTuple(
+            (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))...
+        )
+        @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread)

-        setlogp!!(threadsafe_vi, 42)
-        @test getlogp(vi) == 42
-        @test getlogp(threadsafe_vi) == 42
-        @test all(iszero(x[]) for x in threadsafe_vi.logps)
+        threadsafe_vi = setlogprior!!(threadsafe_vi, 42)
+        @test getlogjoint(threadsafe_vi) == 42
+        expected_accs = DynamicPPL.AccumulatorTuple(
+            (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))...
+        )
+        @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread)
     end

     @testset "model" begin
@@ -48,7 +55,7 @@
         vi = VarInfo()
         wthreads(x)(vi)
-        lp_w_threads = getlogp(vi)
+        lp_w_threads = getlogjoint(vi)
         if Threads.nthreads() == 1
             @test vi_ isa VarInfo
         else
@@ -65,7 +72,7 @@
             vi,
             SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
         )
-        @test getlogp(vi) ≈ lp_w_threads
+        @test getlogjoint(vi) ≈ lp_w_threads
         @test vi_ isa DynamicPPL.ThreadSafeVarInfo

         println("  evaluate_threadsafe!!:")
@@ -85,7 +92,7 @@
         vi = VarInfo()
         wothreads(x)(vi)
-        lp_wo_threads = getlogp(vi)
+        lp_wo_threads = getlogjoint(vi)
         if Threads.nthreads() == 1
             @test vi_ isa VarInfo
         else
@@ -104,7 +111,7 @@
             vi,
             SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
         )
-        @test getlogp(vi) ≈ lp_w_threads
+        @test getlogjoint(vi) ≈ lp_w_threads
         @test vi_ isa VarInfo

         println("  evaluate_threadunsafe!!:")
diff --git a/test/utils.jl b/test/utils.jl
index d683f132d..b85d21c41 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,15 +1,61 @@
 @testset "utils.jl" begin
     @testset "addlogprob!" begin
         @model function testmodel()
-            global lp_before = getlogp(__varinfo__)
+            global lp_before = getlogjoint(__varinfo__)
             @addlogprob!(42)
-            return global lp_after = getlogp(__varinfo__)
+            return global lp_after = getlogjoint(__varinfo__)
         end

-        model = testmodel()
-        varinfo = VarInfo(model)
+        varinfo = VarInfo(testmodel())
         @test iszero(lp_before)
-        @test getlogp(varinfo) == lp_after == 42
+        @test getlogjoint(varinfo) == lp_after == 42
+        @test getloglikelihood(varinfo) == 42
+
+        @model function testmodel_nt()
+            global lp_before = getlogjoint(__varinfo__)
+            @addlogprob! (; logprior=(pi + 1), loglikelihood=42)
+            return global lp_after = getlogjoint(__varinfo__)
+        end
+
+        varinfo = VarInfo(testmodel_nt())
+        @test iszero(lp_before)
+        @test getlogjoint(varinfo) == lp_after == 42 + 1 + pi
+        @test getloglikelihood(varinfo) == 42
+        @test getlogprior(varinfo) == pi + 1
+
+        @model function testmodel_nt2()
+            global lp_before = getlogjoint(__varinfo__)
+            llh_nt = (; loglikelihood=42)
+            @addlogprob! llh_nt
+            return global lp_after = getlogjoint(__varinfo__)
+        end
+
+        varinfo = VarInfo(testmodel_nt2())
+        @test iszero(lp_before)
+        @test getlogjoint(varinfo) == lp_after == 42
+        @test getloglikelihood(varinfo) == 42
+
+        @model function testmodel_likelihood()
+            global lp_before = getlogjoint(__varinfo__)
+            @addloglikelihood! 42
+            return global lp_after = getlogjoint(__varinfo__)
+        end
+
+        varinfo = VarInfo(testmodel_likelihood())
+        @test iszero(lp_before)
+        @test getlogjoint(varinfo) == lp_after == 42
+        @test getloglikelihood(varinfo) == 42
+
+        @model function testmodel_prior()
+            global lp_before = getlogjoint(__varinfo__)
+            @addlogprior! 42
+            return global lp_after = getlogjoint(__varinfo__)
+        end
+
+        varinfo = VarInfo(testmodel_prior())
+        @test iszero(lp_before)
+        @test getlogjoint(varinfo) == lp_after == 42
+        @test getlogprior(varinfo) == 42
     end

     @testset "getargs_dottilde" begin
diff --git a/test/varinfo.jl b/test/varinfo.jl
index 777917aa6..efa8c6e4c 100644
--- a/test/varinfo.jl
+++ b/test/varinfo.jl
@@ -80,7 +80,7 @@ end
     function test_base!!(vi_original)
         vi = empty!!(vi_original)
-        @test getlogp(vi) == 0
+        @test getlogjoint(vi) == 0
         @test isempty(vi[:])

         vn = @varname x
@@ -123,13 +123,25 @@ end
     @testset "get/set/acc/resetlogp" begin
         function test_varinfo_logp!(vi)
-            @test DynamicPPL.getlogp(vi) === 0.0
-            vi = DynamicPPL.setlogp!!(vi, 1.0)
-            @test DynamicPPL.getlogp(vi) === 1.0
-            vi = DynamicPPL.acclogp!!(vi, 1.0)
-            @test DynamicPPL.getlogp(vi) === 2.0
+            @test DynamicPPL.getlogjoint(vi) === 0.0
+            vi = DynamicPPL.setlogprior!!(vi, 1.0)
+            @test DynamicPPL.getlogprior(vi) === 1.0
+            @test DynamicPPL.getloglikelihood(vi) === 0.0
+            @test DynamicPPL.getlogjoint(vi) === 1.0
+            vi = DynamicPPL.acclogprior!!(vi, 1.0)
+            @test DynamicPPL.getlogprior(vi) === 2.0
+            @test DynamicPPL.getloglikelihood(vi) === 0.0
+            @test DynamicPPL.getlogjoint(vi) === 2.0
+            vi = DynamicPPL.setloglikelihood!!(vi, 1.0)
+            @test DynamicPPL.getlogprior(vi) === 2.0
+            @test DynamicPPL.getloglikelihood(vi) === 1.0
+            @test DynamicPPL.getlogjoint(vi) === 3.0
+            vi = DynamicPPL.accloglikelihood!!(vi, 1.0)
+            @test DynamicPPL.getlogprior(vi) === 2.0
+            @test DynamicPPL.getloglikelihood(vi) === 2.0
+            @test DynamicPPL.getlogjoint(vi) === 4.0
             vi = DynamicPPL.resetlogp!!(vi)
-            @test DynamicPPL.getlogp(vi) === 0.0
+            @test DynamicPPL.getlogjoint(vi) === 0.0
         end

         vi = VarInfo()
@@ -140,6 +152,98 @@ end
         test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector()))
     end

+    @testset "accumulators" begin
+        @model function demo()
+            a ~ Normal()
+            b ~ Normal()
+            c ~ Normal()
+            d ~ Normal()
+            return nothing
+        end
+
+        values = (; a=1.0, b=2.0, c=3.0, d=4.0)
+        lp_a = logpdf(Normal(), values.a)
+        lp_b = logpdf(Normal(), values.b)
+        lp_c = logpdf(Normal(), values.c)
+        lp_d = logpdf(Normal(), values.d)
+        m = demo() | (; c=values.c, d=values.d)
+
+        vi = DynamicPPL.reset_num_produce!!(
+            DynamicPPL.unflatten(VarInfo(m), collect(values))
+        )
+
+        vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi)))
+        @test getlogprior(vi) == lp_a + lp_b
+        @test getloglikelihood(vi) == lp_c + lp_d
+        @test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d)
+        @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d
+        @test get_num_produce(vi) == 2
+        @test begin
+            vi = acclogprior!!(vi, 1.0)
+            getlogprior(vi) == lp_a + lp_b + 1.0
+        end
+        @test begin
+            vi = accloglikelihood!!(vi, 1.0)
+            getloglikelihood(vi) == lp_c + lp_d + 1.0
+        end
+        @test begin
+            vi = setlogprior!!(vi, -1.0)
+            getlogprior(vi) == -1.0
+        end
+        @test begin
+            vi = setloglikelihood!!(vi, -1.0)
+            getloglikelihood(vi) == -1.0
+        end
+        @test begin
+            vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0))
+            getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0)
+        end
+        @test begin
+            vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0))
+            getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0)
+        end
+        @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi)
+
+        vi = last(
+            DynamicPPL.evaluate!!(
+                m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorAccumulator(),))
+            ),
+        )
+        @test getlogprior(vi) == lp_a + lp_b
+        @test_throws "has no field LogLikelihoodAccumulator" getloglikelihood(vi)
+        @test_throws "has no field LogLikelihoodAccumulator" getlogp(vi)
+        @test_throws "has no field LogLikelihoodAccumulator" getlogjoint(vi)
+        @test_throws "has no field NumProduceAccumulator" get_num_produce(vi)
+        @test begin
+            vi = acclogprior!!(vi, 1.0)
+            getlogprior(vi) == lp_a + lp_b + 1.0
+        end
+        @test begin
+            vi = setlogprior!!(vi, -1.0)
+            getlogprior(vi) == -1.0
+        end
+
+        vi = last(
+            DynamicPPL.evaluate!!(
+                m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduceAccumulator(),))
+            ),
+        )
+        @test_throws "has no field LogPriorAccumulator" getlogprior(vi)
+        @test_throws "has no field LogLikelihoodAccumulator" getloglikelihood(vi)
+        @test_throws "has no field LogPriorAccumulator" getlogp(vi)
+        @test_throws "has no field LogPriorAccumulator" getlogjoint(vi)
+        @test get_num_produce(vi) == 2
+
+        # Test evaluating without any accumulators.
+        vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ())))
+        @test_throws "has no field LogPriorAccumulator" getlogprior(vi)
+        @test_throws "has no field LogLikelihoodAccumulator" getloglikelihood(vi)
+        @test_throws "has no field LogPriorAccumulator" getlogp(vi)
+        @test_throws "has no field LogPriorAccumulator" getlogjoint(vi)
+        @test_throws "has no field NumProduceAccumulator" get_num_produce(vi)
+        @test_throws "has no field NumProduceAccumulator" reset_num_produce!!(vi)
+    end
+
     @testset "flags" begin
         # Test flag setting:
         #     is_flagged, set_flag!, unset_flag!
@@ -455,12 +559,24 @@ end

         ## `untyped_varinfo`
         vi = DynamicPPL.untyped_varinfo(model)
+
+        ## `untyped_varinfo`
+        vi = DynamicPPL.untyped_varinfo(model)
+        vi = DynamicPPL.settrans!!(vi, true, vn)
+        # Sample in unconstrained space.
+        vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
+        f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
+        x = f(DynamicPPL.getindex_internal(vi, vn))
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+
+        ## `typed_varinfo`
+        vi = DynamicPPL.typed_varinfo(model)
         vi = DynamicPPL.settrans!!(vi, true, vn)
         # Sample in unconstrained space.
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)

         ## `typed_varinfo`
         vi = DynamicPPL.typed_varinfo(model)
@@ -469,7 +585,7 @@ end
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)

         ### `SimpleVarInfo`
         ## `SimpleVarInfo{<:NamedTuple}`
@@ -478,7 +594,7 @@ end
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)

         ## `SimpleVarInfo{<:Dict}`
         vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true)
@@ -486,7 +602,7 @@ end
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)

         ## `SimpleVarInfo{<:VarNamedVector}`
         vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true)
@@ -494,7 +610,7 @@ end
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
     end

     @testset "values_as" begin
@@ -596,8 +712,8 @@ end
             lp = logjoint(model, varinfo)
             @test lp ≈ lp_true
-            @test getlogp(varinfo) ≈ lp_true
-            lp_linked = getlogp(varinfo_linked)
+            @test getlogjoint(varinfo) ≈ lp_true
+            lp_linked = getlogjoint(varinfo_linked)
             @test lp_linked ≈ lp_linked_true

             # TODO: Compare values once we are no longer working with `NamedTuple` for
@@ -609,7 +725,7 @@ end
                 varinfo_linked_unflattened, model
             )
             @test length(varinfo_invlinked[:]) == length(varinfo[:])
-            @test getlogp(varinfo_invlinked) ≈ lp_true
+            @test getlogjoint(varinfo_invlinked) ≈ lp_true
         end
     end
 end
@@ -941,19 +1057,19 @@ end

         # First iteration, variables are added to vi
         # variables samples in order: z1,a1,z2,a2,z3
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z1, dists[1])
         randr(vi, vn_a1, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_b, dists[2])
         randr(vi, vn_z2, dists[1])
         randr(vi, vn_a2, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
        randr(vi, vn_z3, dists[1])
         @test vi.metadata.orders == [1, 1, 2, 2, 2, 3]
         @test DynamicPPL.get_num_produce(vi) == 3

-        DynamicPPL.reset_num_produce!(vi)
+        vi = DynamicPPL.reset_num_produce!!(vi)
         DynamicPPL.set_retained_vns_del!(vi)
         @test DynamicPPL.is_flagged(vi, vn_z1, "del")
         @test DynamicPPL.is_flagged(vi, vn_a1, "del")
@@ -961,12 +1077,12 @@ end
         @test DynamicPPL.is_flagged(vi, vn_a2, "del")
         @test DynamicPPL.is_flagged(vi, vn_z3, "del")

-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z1, dists[1])
         randr(vi, vn_a1, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z2, dists[1])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z3, dists[1])
         randr(vi, vn_a2, dists[2])
         @test vi.metadata.orders == [1, 1, 2, 2, 3, 3]
@@ -975,21 +1091,21 @@ end
         vi = empty!!(DynamicPPL.typed_varinfo(vi))
         # First iteration, variables are added to vi
         # variables samples in order: z1,a1,z2,a2,z3
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z1, dists[1])
         randr(vi, vn_a1, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_b, dists[2])
         randr(vi, vn_z2, dists[1])
         randr(vi, vn_a2, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z3, dists[1])
         @test vi.metadata.z.orders == [1, 2, 3]
         @test vi.metadata.a.orders == [1, 2]
         @test vi.metadata.b.orders == [2]
         @test DynamicPPL.get_num_produce(vi) == 3

-        DynamicPPL.reset_num_produce!(vi)
+        vi = DynamicPPL.reset_num_produce!!(vi)
         DynamicPPL.set_retained_vns_del!(vi)
         @test DynamicPPL.is_flagged(vi, vn_z1, "del")
         @test DynamicPPL.is_flagged(vi, vn_a1, "del")
@@ -997,12 +1113,12 @@ end
         @test DynamicPPL.is_flagged(vi, vn_a2, "del")
         @test DynamicPPL.is_flagged(vi, vn_z3, "del")

-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z1, dists[1])
         randr(vi, vn_a1, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z2, dists[1])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z3, dists[1])
         randr(vi, vn_a2, dists[2])
         @test vi.metadata.z.orders == [1, 2, 3]
@@ -1017,8 +1133,8 @@ end
         n = length(varinfo[:])

         # `Bool`.
-        @test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1))
+        @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1))
         # `Int`.
-        @test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1))
+        @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1))
     end
 end
diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl
index bd3f5553f..f21d458a8 100644
--- a/test/varnamedvector.jl
+++ b/test/varnamedvector.jl
@@ -607,7 +607,7 @@ end
             DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext())
         )
         # Log density should be the same.
-        @test getlogp(varinfo_eval) ≈ logp_true
+        @test getlogjoint(varinfo_eval) ≈ logp_true

         # Values should be the same.
         DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns)
@@ -616,7 +616,7 @@ end
             DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext())
         )
         # Log density should be different.
-        @test getlogp(varinfo_sample) != getlogp(varinfo)
+        @test getlogjoint(varinfo_sample) != getlogjoint(varinfo)
         # Values should be different.
         DynamicPPL.TestUtils.test_values(
             varinfo_sample, value_true, vns; compare=!isequal