diff --git a/HISTORY.md b/HISTORY.md index 6d237e698..a075e5a3b 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -76,7 +76,10 @@ Please see the docstring for details. Each initialisation strategy can decide what kind of `AbstractTransformedValue` to return. This has no impact on whether the log-Jacobian is calculated or not, as that is determined by the *transform strategy* (see below). -### Transform strategies +### `init!!` and transform strategies + +The initialisation strategy argument to `init!!` used to default to `InitFromPrior()`. +It is now mandatory to specify this explicitly. When using `InitContext`, you can (and indeed sometimes must) now specify a *transform strategy* which controls whether values are interpreted as being in transformed space or not. This in turn controls whether: @@ -107,7 +110,7 @@ In its place, you should directly use the accumulator API to: To do so, we now export a convenience function `get_raw_values(::AbstractVarInfo)` that will get the stored `VarNamedTuple` of raw values. This is exactly analogous to how `getlogprior(::AbstractVarInfo)` extracts the log-prior from a `LogPriorAccumulator`. -### Function signature changes +### Function signature changes in tilde-pipeline `tilde_assume!!` and `accumulate_assume!!` now take extra arguments. @@ -122,6 +125,12 @@ In particular `tval` is either the `AbstractTransformedValue` that `DynamicPPL.init` provided (for InitContext), or the `AbstractTransformedValue` found inside the VarInfo (for DefaultContext). - `accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, dist)` is now `accumulate_assume!!(vi, val, tval, logjac, vn, dist, template)`. +### `DynamicPPL.DebugUtils` + +The signature of `DynamicPPL.DebugUtils.check_model` and `DynamicPPL.DebugUtils.check_model_and_trace` are now changed. +Instead of taking a `VarInfo` as the second argument, they now do not need a `VarInfo` at all; they simply sample from the prior of the model. +To make this reproducible you can optionally pass `rng` as a first argument (before the model). + ### Overhaul of `VarInfo` DynamicPPL tracks variable values during model execution using one of the `AbstractVarInfo` types. @@ -214,6 +223,19 @@ For example, carrying on from the above, `conditioned(f() | vnt)` will return `v The underlying code for `ConditionContext` and `FixedContext` is almost completely the same. In this release, to reduce code duplication, they have been merged into a single implementation, `CondFixContext{Condition}` and `CondFixContext{Fix}`, where the type parameter controls whether conditioning or fixing is performed. +### `DynamicPPL.evaluate!!(model, varinfo)` now warns + +This method has very complicated semantics; it's difficult to use properly. +In DynamicPPL we are moving away from trying to encode all the different ways of evaluating a model in the `varinfo` object, and in a future release of DynamicPPL this method will be removed entirely. + +For now, the method still exists, but we would like to strongly encourage users to avoid using this method. +In place you should use `init!!([rng,] model, oavi::OnlyAccsVarInfo, init_strategy, transform_strategy)` instead, which is much more explicit, and more closely matches what DynamicPPL.jl will use exclusively in the future. + +If you are using this function and are unsure how to adapt your code, please: + + 1. Read the documentation! There is a *lot* more documentation at https://turinglang.org/DynamicPPL.jl/v0.40/. + 2. If you can't figure it out, please open an issue. We are happy to help. + ### Accumulator interface exports more functions To define your own accumulator, you have to overload a number of functions. diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 6bb8672c9..04eba4b46 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -22,11 +22,12 @@ export Models, benchmark, model_dimension Return the dimension of `model`, accounting for linking, if any. """ function model_dimension(model, islinked) - vi = VarInfo() - vi = last(DynamicPPL.init!!(StableRNG(23), model, vi)) - if islinked - vi = DynamicPPL.link(vi, model) - end + tfm_strategy = islinked ? DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll() + vi = last( + DynamicPPL.init!!( + StableRNG(23), model, VarInfo(), DynamicPPL.InitFromPrior(), tfm_strategy + ), + ) return length(vi[:]) end diff --git a/docs/src/accumulators.md b/docs/src/accumulators.md index 9f96e825d..efa0b2043 100644 --- a/docs/src/accumulators.md +++ b/docs/src/accumulators.md @@ -227,7 +227,7 @@ Because the accumulation process is not always commutative, you may in general e However, for many accumulators such as log-probability accumulators, this is not an issue. We can see this in action if we step through the internal DynamicPPL calls. -(Note that calling `DynamicPPL.evaluate!!` on a model where thread-safe mode has been enabled will automatically perform these steps for you.) +(Note that calling `DynamicPPL.init!!` on a model where thread-safe mode has been enabled will automatically perform these steps for you.) ```@example 1 Threads.nthreads() diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 8e53d8709..348152d30 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -143,9 +143,21 @@ VarInfo used in the marginalisation. !!! note The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be - updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the - model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model, - vi))`). + updated. If you wish to obtain updated log-probabilities, you should re-evaluate the + model with the values inside the returned VarInfo, for example using: + + ```julia + init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing) + oavi = DynamicPPL.OnlyAccsVarInfo(( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.RawValueAccumulator(false), + # ... whatever else you need + )) + _, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll()) + ``` + + You can then extract all the updated data from `oavi`. ## Example diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 14f43aaef..447aaad43 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -289,21 +289,6 @@ if isdefined(Base.Experimental, :register_error_hint) ) end end - - Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ - is_evaluate_three_arg = - exc.f === AbstractPPL.evaluate!! && - length(argtypes) == 3 && - argtypes[1] <: Model && - argtypes[2] <: AbstractVarInfo && - argtypes[3] <: AbstractContext - if is_evaluate_three_arg - print( - io, - "\n\nThe method `evaluate!!(model, varinfo, new_ctx)` has been removed. Instead, you should store the `new_ctx` in the `model.context` field using `new_model = contextualize(model, new_ctx)`, and then call `evaluate!!(new_model, varinfo)` on the new model. (Note that, if the model already contained a non-default context, you will need to wrap the existing context.)", - ) - end - end end end diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c5cccfd7e..b1b07f35d 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -123,6 +123,15 @@ function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) whe return setaccs!!(vi, AccumulatorTuple(accs)) end +""" + get_values(vi::AbstractVarInfo) + +Return the `VarNamedTuple` in `vi` that stores the variables' values. + +This should be implemented by each subtype of `AbstractVarInfo`. +""" +function get_values end + """ getaccs(vi::AbstractVarInfo) diff --git a/src/accumulators/pointwise_logdensities.jl b/src/accumulators/pointwise_logdensities.jl index d079a3ce8..9e4be7c6a 100644 --- a/src/accumulators/pointwise_logdensities.jl +++ b/src/accumulators/pointwise_logdensities.jl @@ -74,9 +74,10 @@ function pointwise_logdensities( model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) ) where {whichlogprob} AccType = PointwiseLogProbAccumulator{whichlogprob} - varinfo = setaccs!!(varinfo, (AccType(),)) - varinfo = last(evaluate!!(model, varinfo)) - return getacc(varinfo, Val(accumulator_name(AccType))).logps + oavi = OnlyAccsVarInfo((AccType(),)) + init_strategy = InitFromParams(varinfo.values, nothing) + oavi = last(init!!(model, oavi, init_strategy, UnlinkAll())) + return getacc(oavi, Val(accumulator_name(AccType))).logps end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) diff --git a/src/accumulators/priors.jl b/src/accumulators/priors.jl index 1d4c22e0b..8f311edc9 100644 --- a/src/accumulators/priors.jl +++ b/src/accumulators/priors.jl @@ -87,7 +87,8 @@ This is done by evaluating the model at the values present in `varinfo` and recording the distributions that are present at each tilde statement. """ function extract_priors(model::Model, varinfo::AbstractVarInfo) - varinfo = setaccs!!(deepcopy(varinfo), (PriorDistributionAccumulator(),)) - varinfo = last(evaluate!!(model, varinfo)) + oavi = OnlyAccsVarInfo((PriorDistributionAccumulator(),)) + init_strategy = InitFromParams(varinfo.values, nothing) + varinfo = last(init!!(model, oavi, init_strategy, UnlinkAll())) return getacc(varinfo, Val(PRIOR_ACCNAME)).values end diff --git a/src/chains.jl b/src/chains.jl index e80d8b361..9de0fc15c 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -47,16 +47,17 @@ function ParamsWithStats( else (DynamicPPL.RawValueAccumulator(include_colon_eq),) end - varinfo = DynamicPPL.setaccs!!(varinfo, accs) - varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) - params = get_raw_values(varinfo) + oavi = OnlyAccsVarInfo(accs) + init = InitFromParams(varinfo.values, nothing) + oavi = last(DynamicPPL.init!!(model, oavi, init, UnlinkAll())) + params = get_raw_values(oavi) if include_log_probs stats = merge( stats, ( - logprior=DynamicPPL.getlogprior(varinfo), - loglikelihood=DynamicPPL.getloglikelihood(varinfo), - logjoint=DynamicPPL.getlogjoint(varinfo), + logprior=DynamicPPL.getlogprior(oavi), + loglikelihood=DynamicPPL.getloglikelihood(oavi), + logjoint=DynamicPPL.getlogjoint(oavi), ), ) end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 9d16d54b1..1bf0df709 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -335,10 +335,9 @@ function check_model_post_evaluation(acc::DebugAccumulator) end """ - check_model_and_trace(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) + check_model_and_trace([rng::Random.AbstractRNG,] model::Model; error_on_failure=false) -Check that evaluating `model` with the given `varinfo` is valid, warning about any potential -issues. +Check that sampling from the prior of `model`, warning about any potential issues. This will check the model for the following issues: @@ -360,16 +359,14 @@ This will check the model for the following issues: ## Correct model ```jldoctest check-model-and-tracecheck-model-and-trace; setup=:(using Distributions) -julia> using StableRNGs - -julia> rng = StableRNG(42); +julia> using StableRNGs; rng = StableRNG(42); julia> @model demo_correct() = x ~ Normal() demo_correct (generic function with 2 methods) -julia> model = demo_correct(); varinfo = VarInfo(rng, model); +julia> model = demo_correct(); -julia> issuccess, trace = check_model_and_trace(model, varinfo); +julia> issuccess, trace = check_model_and_trace(rng, model); julia> issuccess true @@ -379,7 +376,7 @@ julia> print(trace) julia> cond_model = model | (x = 1.0,); -julia> issuccess, trace = check_model_and_trace(cond_model, VarInfo(cond_model)); +julia> issuccess, trace = check_model_and_trace(cond_model); ┌ Warning: The model does not contain any parameters. └ @ DynamicPPL.DebugUtils DynamicPPL.jl/src/debug_utils.jl:342 @@ -404,26 +401,25 @@ julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't a # alert us to the issue of `x` being sampled twice. model = demo_incorrect(); varinfo = VarInfo(model); -julia> issuccess, trace = check_model_and_trace(model, varinfo; error_on_failure=true); +julia> issuccess, trace = check_model_and_trace(model; error_on_failure=true); ERROR: varname x used multiple times in model ``` """ function check_model_and_trace( - model::Model, varinfo::AbstractVarInfo; error_on_failure=false + rng::Random.AbstractRNG, model::Model; error_on_failure=false ) - # Add debug accumulator to the VarInfo. - varinfo = DynamicPPL.setaccs!!(deepcopy(varinfo), (DebugAccumulator(error_on_failure),)) - # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) # TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a # check on the merged accumulator, rather than checking it in the accumulate_assume # calls. That way we can also correctly support multi-threaded evaluation. - _, varinfo = DynamicPPL.evaluate!!(model, varinfo) + oavi = DynamicPPL.OnlyAccsVarInfo((DebugAccumulator(error_on_failure),)) + init_strategy = InitFromPrior() + _, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, UnlinkAll()) # Perform checks after evaluating the model. - debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) + debug_acc = DynamicPPL.getacc(oavi, Val(_DEBUG_ACC_NAME)) issuccess = issuccess && check_model_post_evaluation(debug_acc) if !issuccess && error_on_failure @@ -433,9 +429,14 @@ function check_model_and_trace( trace = debug_acc.statements return issuccess, trace end +function check_model_and_trace(model::Model; error_on_failure=false) + return check_model_and_trace( + Random.default_rng(), model; error_on_failure=error_on_failure + ) +end """ - check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) + check_model(model::Model; error_on_failure=false) Check that `model` is valid, warning about any potential issues (or erroring if `error_on_failure` is `true`). @@ -443,8 +444,11 @@ Check that `model` is valid, warning about any potential issues (or erroring if # Returns - `issuccess::Bool`: Whether the model check succeeded. """ -check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) = - first(check_model_and_trace(model, varinfo; error_on_failure=error_on_failure)) +check_model(rng::Random.AbstractRNG, model::Model; error_on_failure=false) = + first(check_model_and_trace(rng, model; error_on_failure=error_on_failure)) +function check_model(model::Model; error_on_failure=false) + return check_model(Random.default_rng(), model; error_on_failure=error_on_failure) +end # Convenience method used to check if all elements in a list are the same. function all_the_same(xs) @@ -479,11 +483,8 @@ and checking if the model is consistent across runs. function has_static_constraints( rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) - new_model = DynamicPPL.contextualize( - model, InitContext(rng, InitFromPrior(), UnlinkAll()) - ) results = map(1:num_evals) do _ - check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) + check_model_and_trace(rng, model; error_on_failure=error_on_failure) end # Extract the distributions and the corresponding bijectors for each run. diff --git a/src/model.jl b/src/model.jl index 123c361be..45fd74786 100644 --- a/src/model.jl +++ b/src/model.jl @@ -762,8 +762,10 @@ function (model::Model)(varinfo::AbstractVarInfo) end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. -function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(init!!(rng, model, varinfo)) +function (model::Model)( + rng::Random.AbstractRNG, varinfo::AbstractVarInfo=OnlyAccsVarInfo(()) +) + return first(init!!(rng, model, varinfo, InitFromPrior(), UnlinkAll())) end """ @@ -771,7 +773,7 @@ end [rng::Random.AbstractRNG,] model::Model, varinfo::AbstractVarInfo, - [init_strategy::AbstractInitStrategy=InitFromPrior(),] + init_strategy::AbstractInitStrategy, [transform_strategy::AbstractTransformStrategy=get_transform_strategy(varinfo),] ) @@ -779,14 +781,13 @@ Evaluate the `model` and replace the values of the model's random variables in t `varinfo` with new values, using a specified initialisation strategy. If the values in `varinfo` are not set, they will be added using a specified initialisation strategy. -If `init_strategy` is not provided, defaults to `InitFromPrior()`. - -`transform_strategy` tells the model evaluation whether variables should be interpreted as linked -or unlinked. Right now, it is slightly complicated because the default behaviour depends on -the `varinfo` provided. If `varinfo isa VarInfo`, then the transform strategy is inferred -from the VarInfo, i.e., linked variables in the VarInfo are treated as linked during -evaluation. Conversely, if `varinfo isa OnlyAccsVarInfo`, then all variables are treated as -unlinked. +`transform_strategy` tells the model evaluation whether variables should be interpreted as +linked or unlinked. Right now, it is slightly complicated because the default behaviour +depends on the `varinfo` provided. If `varinfo isa VarInfo`, then the transform strategy is +inferred from the VarInfo, i.e., linked variables in the VarInfo are treated as linked +during evaluation. Conversely, if `varinfo isa OnlyAccsVarInfo`, then you must specify the +transform strategy explicitly, since an `OnlyAccsVarInfo` does not contain any information +about which variables are transformed. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ @@ -794,34 +795,89 @@ function init!!( rng::Random.AbstractRNG, model::Model, vi::AbstractVarInfo, - strategy::AbstractInitStrategy=InitFromPrior(), + init_strategy::AbstractInitStrategy, transform_strategy::AbstractTransformStrategy=get_transform_strategy(vi), ) - ctx = InitContext(rng, strategy, transform_strategy) + ctx = InitContext(rng, init_strategy, transform_strategy) model = DynamicPPL.setleafcontext(model, ctx) - return DynamicPPL.evaluate!!(model, vi) + return DynamicPPL.evaluate_nowarn!!(model, vi) end function init!!( model::Model, vi::AbstractVarInfo, - strategy::AbstractInitStrategy=InitFromPrior(), + init_strategy::AbstractInitStrategy=InitFromPrior(), transform_strategy::AbstractTransformStrategy=get_transform_strategy(vi), ) - return init!!(Random.default_rng(), model, vi, strategy, transform_strategy) + return init!!(Random.default_rng(), model, vi, init_strategy, transform_strategy) end """ evaluate!!(model::Model, varinfo) -Evaluate the `model` with the given `varinfo`. +Evaluate the `model` with the given `varinfo`, wrapping it in a `ThreadSafeVarInfo` if the +model is marked as needing threadsafe evaluation. + +!!! warning + The semantics of this method are complicated. We **strongly** recommend that users do + *not* use this method unless absolutely necessary. In the future this method will be + deprecated and removed. As far as possible (and it should **always** be possible -- + please open an issue if you do not know how to adapt your code!) you should use the + five-argument `init!!([rng,] model, ::OnlyAccsVarInfo, init_strategy, + transform_strategy)` method, which has more explicit semantics and allows you to have + more control over each part of the evaluation process. + +The exact semantics depend on the `model`'s context. Fundamentally, this method executes the +model evaluation function (i.e., the function used to define the model) using the given +`varinfo` as an argument. At each tilde-statement, `tilde_assume!!` or `tilde_observe!!` is +called, whose behaviour depends on the model's context. + +Broadly speaking, if the leaf context is an `InitContext`, then this function: + +- uses the initialisation strategy inside the `InitContext`; +- uses the transform strategy inside the `InitContext`; +- uses the accumulators inside `varinfo` (resetting them before evaluation); +- overwrites the values in `varinfo` with the new values obtained from the initialisation strategy. + +If the leaf context is a `DefaultContext`, then this function: -If the model has been marked as requiring threadsafe evaluation, are available, the varinfo -provided will be wrapped in a `ThreadSafeVarInfo` before evaluation. +- uses the values inside the `varinfo` as the initialisation strategy; +- derives a transform strategy from the `varinfo`'s stored variables (if a linked variable is + stored, then the transform strategy will treat that variable as linked; likewise for + unlinked) +- uses the accumulators inside `varinfo` (resetting them before evaluation); +- does not overwrite the values in the `varinfo` (that is unnecessary since the values used + for evaluation are already stored in `varinfo`). -Returns a tuple of the model's return value, plus the updated `varinfo` -(unwrapped if necessary). +The long-term plan for this method is to: + +- Replace `DefaultContext` with `InitContext` by splitting up the functionality of `DefaultContext` + into its constituent components +- Remove the `VarInfo` argument, and instead use only an `AccumulatorTuple` +- Separate the initialisation and transform strategies into separate arguments, instead of storing + them inside the model's context. """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) + @warn ( + "Calling `evaluate!!(model, varinfo)` directly is not recommended and will be" * + " deprecated in the future. Please switch to using `init!!([rng,] model," * + " ::OnlyAccsVarInfo, init_strategy, transform_strategy)` instead, which" * + " has more explicit semantics and allows you to have more control over each" * + " part of the evaluation process. Please see the DynamicPPL documentation" * + " for more details: https://turinglang.org/DynamicPPL.jl/stable/evaluation" + ) maxlog = 5 + return DynamicPPL.evaluate_nowarn!!(model, varinfo) +end + +""" + evaluate_nowarn!!(model::Model, varinfo) + +This is the same as `evaluate!!(model, varinfo)` but without the deprecation warning. + +!!! warning + This is meant for internal use in DynamicPPL.jl only! If you rely on this method in your + code, please note that it may break at any time. +""" +function evaluate_nowarn!!(model::Model, varinfo::AbstractVarInfo) return if requires_threadsafe(model) # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is # a gradient type of some AD backend. @@ -987,15 +1043,15 @@ julia> # Truth. -9902.33787706641 ``` """ -function logjoint(model::Model, varinfo::AbstractVarInfo) - return getlogjoint(last(evaluate!!(model, varinfo))) -end function logjoint(model::Model, params) vi = OnlyAccsVarInfo( AccumulatorTuple(LogPriorAccumulator(), LogLikelihoodAccumulator()) ) - ctx = InitFromParams(params, nothing) - return getlogjoint(last(init!!(model, vi, ctx, UnlinkAll()))) + init_strategy = InitFromParams(params, nothing) + return getlogjoint(last(init!!(model, vi, init_strategy, UnlinkAll()))) +end +function logjoint(model::Model, varinfo::AbstractVarInfo) + return logjoint(model, get_values(varinfo)) end """ @@ -1033,20 +1089,13 @@ julia> # Truth. -5000.918938533205 ``` """ -function logprior(model::Model, varinfo::AbstractVarInfo) - # Remove other accumulators from varinfo, since they are unnecessary. - logprioracc = if hasacc(varinfo, Val(:LogPrior)) - getacc(varinfo, Val(:LogPrior)) - else - LogPriorAccumulator() - end - varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,)) - return getlogprior(last(evaluate!!(model, varinfo))) -end function logprior(model::Model, params) vi = OnlyAccsVarInfo(AccumulatorTuple(LogPriorAccumulator())) - ctx = InitFromParams(params, nothing) - return getlogprior(last(init!!(model, vi, ctx, UnlinkAll()))) + init_strategy = InitFromParams(params, nothing) + return getlogprior(last(init!!(model, vi, init_strategy, UnlinkAll()))) +end +function logprior(model::Model, varinfo::AbstractVarInfo) + return logprior(model, get_values(varinfo)) end """ @@ -1080,131 +1129,13 @@ julia> # Truth. logpdf(Normal(100.0, 1.0), 1.0) -4901.418938533205 """ -function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - # Remove other accumulators from varinfo, since they are unnecessary. - loglikelihoodacc = if hasacc(varinfo, Val(:LogLikelihood)) - getacc(varinfo, Val(:LogLikelihood)) - else - LogLikelihoodAccumulator() - end - varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,)) - return getloglikelihood(last(evaluate!!(model, varinfo))) -end function Distributions.loglikelihood(model::Model, params) vi = OnlyAccsVarInfo(AccumulatorTuple(LogLikelihoodAccumulator())) - ctx = InitFromParams(params, nothing) - return getloglikelihood(last(init!!(model, vi, ctx, UnlinkAll()))) -end - -""" - logjoint(model::Model, values::Union{NamedTuple,AbstractDict}) - -Return the log joint probability of variables `values` for the probabilistic `model`. - -See [`logprior`](@ref) and [`loglikelihood`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - logjoint(demo([1.0]), (m = 100.0, )) --9902.33787706641 - -julia> # Using a `OrderedDict`. - logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --9902.33787706641 - -julia> # Truth. - logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) --9902.33787706641 -``` -""" -function logjoint(model::Model, values::Union{NamedTuple,AbstractDict}) - accs = AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) - vi = OnlyAccsVarInfo(accs) - _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing), UnlinkAll()) - return getlogjoint(vi) -end - -""" - logprior(model::Model, values::Union{NamedTuple,AbstractDict}) - -Return the log prior probability of variables `values` for the probabilistic `model`. - -See also [`logjoint`](@ref) and [`loglikelihood`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - logprior(demo([1.0]), (m = 100.0, )) --5000.918938533205 - -julia> # Using a `OrderedDict`. - logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --5000.918938533205 - -julia> # Truth. - logpdf(Normal(), 100.0) --5000.918938533205 -``` -""" -function logprior(model::Model, values::Union{NamedTuple,AbstractDict}) - accs = AccumulatorTuple((LogPriorAccumulator(),)) - vi = OnlyAccsVarInfo(accs) - _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing), UnlinkAll()) - return getlogprior(vi) + init_strategy = InitFromParams(params, nothing) + return getloglikelihood(last(init!!(model, vi, init_strategy, UnlinkAll()))) end - -""" - loglikelihood(model::Model, values::Union{NamedTuple,AbstractDict}) - -Return the log likelihood of variables `values` for the probabilistic `model`. - -See also [`logjoint`](@ref) and [`logprior`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - loglikelihood(demo([1.0]), (m = 100.0, )) --4901.418938533205 - -julia> # Using a `OrderedDict`. - loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --4901.418938533205 - -julia> # Truth. - logpdf(Normal(100.0, 1.0), 1.0) --4901.418938533205 -``` -""" -function Distributions.loglikelihood(model::Model, values::Union{NamedTuple,AbstractDict}) - accs = AccumulatorTuple((LogLikelihoodAccumulator(),)) - vi = OnlyAccsVarInfo(accs) - _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing), UnlinkAll()) - return getloglikelihood(vi) +function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) + return loglikelihood(model, get_values(varinfo)) end # Implemented & documented in DynamicPPLMCMCChainsExt diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 19d7b7420..b395df115 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -37,10 +37,13 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP # filled with values. @testset "evaluation" begin # Generate a new filled varinfo - _, vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) + vi = DynamicPPL.VarInfo(model) # Set the test context as the new leaf context new_model = DynamicPPL.setleafcontext(model, context) - _, vi = DynamicPPL.evaluate!!(new_model, vi) + # It might seem a bit ugly that we have to use `evaluate_nowarn!!` here. Essentially + # we want to test that low-level evaluation works with the context, so this is the + # right thing to do. + _, vi = DynamicPPL.evaluate_nowarn!!(new_model, vi) @test vi isa DynamicPPL.VarInfo end end @@ -71,10 +74,15 @@ function test_parent_context(context::DynamicPPL.AbstractContext, model::Dynamic new_model = contextualize(model, context) vi = DynamicPPL.VarInfo() # Initialisation - _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + _, vi = DynamicPPL.init!!( + new_model, + DynamicPPL.VarInfo(), + DynamicPPL.InitFromPrior(), + DynamicPPL.UnlinkAll(), + ) @test vi isa DynamicPPL.VarInfo - # Evaluation - _, vi = DynamicPPL.evaluate!!(new_model, vi) + # Evaluation. See above regarding note about evaluate_nowarn!!. + _, vi = DynamicPPL.evaluate_nowarn!!(new_model, vi) @test vi isa DynamicPPL.VarInfo end end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 50e13f912..7e2322f5d 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -88,13 +88,16 @@ function logprior_true_with_logabsdet_jacobian end Return a collection of `VarName` as they are expected to appear in the model. -Even though it is recommended to implement this by hand for a particular `Model`, -a default implementation using [`VarInfo`](@ref) is provided. +Even though it is recommended to implement this by hand for a particular `Model`, a default +implementation that evaluates the model is provided. """ function varnames(model::Model) - result = collect(keys(last(DynamicPPL.init!!(model, VarInfo())))) + vval_acc = DynamicPPL.VectorValueAccumulator() + oavi = OnlyAccsVarInfo((vval_acc,)) + _, oavi = DynamicPPL.init!!(model, oavi, InitFromPrior(), UnlinkAll()) + vvals = DynamicPPL.getacc(oavi, Val(DynamicPPL.accumulator_name(vval_acc))).values # Concretise the element type. - return [x for x in result] + return [x for x in keys(vvals)] end """ diff --git a/src/threadsafe.jl b/src/threadsafe.jl index b346847aa..1a2711d82 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -28,6 +28,8 @@ function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator) return ThreadSafeVarInfo(inner_vi, news_accs_by_thread) end +get_values(vi::ThreadSafeVarInfo) = get_values(vi.varinfo) + # Get both the main accumulator and the thread-specific accumulators of the same type and # combine them. function getacc(vi::ThreadSafeVarInfo, accname::Val) diff --git a/src/varinfo.jl b/src/varinfo.jl index 522d4d47b..52abcd324 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -148,6 +148,8 @@ function VarInfo(model::Model, initstrat::AbstractInitStrategy=InitFromPrior()) return VarInfo(Random.default_rng(), model, initstrat) end +get_values(vi::VarInfo) = vi.values + getaccs(vi::VarInfo) = vi.accs function setaccs!!(vi::VarInfo{Linked}, accs::AccumulatorTuple) where {Linked} return VarInfo{Linked}(vi.values, accs) diff --git a/test/chains.jl b/test/chains.jl index df5a5aeff..d5a4a77fb 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -59,7 +59,7 @@ using Test @test_throws ErrorException ParamsWithStats(VarInfo(model)) # With VAIM, it should work vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.RawValueAccumulator(true),)) - vi = last(DynamicPPL.evaluate!!(model, vi)) + vi = last(DynamicPPL.init!!(model, vi, InitFromPrior(), UnlinkAll())) ps = ParamsWithStats(vi) @test haskey(ps.params, @varname(x)) @test haskey(ps.params, @varname(y)) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index dfa00f71b..4e364d251 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -9,7 +9,7 @@ using LinearAlgebra: I @testset "check_model" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - issuccess, trace = check_model_and_trace(model, VarInfo(model)) + issuccess, trace = check_model_and_trace(model) # These models should all work. @test issuccess @@ -28,21 +28,22 @@ using LinearAlgebra: I end @testset "multiple usage of same variable" begin + function test_model_can_run_but_fails_check(model) + # Check that it can actually run + @test VarInfo(model) isa VarInfo + # but if you call check_model it should fail + issuccess = check_model(model) + @test !issuccess + @test_throws ErrorException check_model(model; error_on_failure=true) + end + @testset "simple" begin @model function buggy_demo_model() x ~ Normal() x ~ Normal() return y ~ Normal() end - buggy_model = buggy_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) + test_model_can_run_but_fails_check(buggy_demo_model()) end @testset "submodel" begin @@ -52,11 +53,7 @@ using LinearAlgebra: I z ~ to_submodel(ModelInner(), false) return x ~ Normal() end - model = ModelOuterBroken() - varinfo = VarInfo(model) - @test_throws ErrorException check_model( - model, VarInfo(model); error_on_failure=true - ) + test_model_can_run_but_fails_check(ModelOuterBroken()) @model function ModelOuterWorking() # With automatic prefixing => `x` is not duplicated. @@ -65,7 +62,7 @@ using LinearAlgebra: I return z end model = ModelOuterWorking() - @test check_model(model, VarInfo(model); error_on_failure=true) + @test check_model(model) # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 @model function ModelOuterWorking2() @@ -74,7 +71,7 @@ using LinearAlgebra: I return (x1, x2) end model = ModelOuterWorking2() - @test check_model(model, VarInfo(model); error_on_failure=true) + @test check_model(model) end end @@ -86,14 +83,14 @@ using LinearAlgebra: I end end m = demo_nan_in_data([1.0, NaN]) - @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) + @test_throws ErrorException check_model(m; error_on_failure=true) # Test NamedTuples with nested arrays, see #898 @model function demo_nan_complicated(nt) nt ~ product_distribution((x=Normal(), y=Dirichlet([2, 4]))) return x ~ Normal() end m = demo_nan_complicated((x=1.0, y=[NaN, 0.5])) - @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) + @test_throws ErrorException check_model(m; error_on_failure=true) end @testset "incorrect use of condition" begin @@ -102,10 +99,7 @@ using LinearAlgebra: I return x ~ MvNormal(zeros(length(x)), I) end model = demo_missing_in_multivariate([1.0, missing]) - # Have to run this check_model call with an empty varinfo, because actually - # instantiating the VarInfo would cause it to throw a MethodError. - model = contextualize(model, InitContext(InitFromPrior(), UnlinkAll())) - @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) + @test_throws ErrorException check_model(model; error_on_failure=true) end @testset "condition both in args and context" begin @@ -119,9 +113,8 @@ using LinearAlgebra: I OrderedDict(@varname(x[1]) => 2.0), ] conditioned_model = DynamicPPL.condition(model, vals) - varinfo = VarInfo(conditioned_model) @test_throws ErrorException check_model( - conditioned_model, varinfo; error_on_failure=true + conditioned_model; error_on_failure=true ) end end @@ -131,7 +124,7 @@ using LinearAlgebra: I @testset "assume" begin @model demo_assume() = x ~ Normal() model = demo_assume() - issuccess, trace = check_model_and_trace(model, VarInfo(model)) + issuccess, trace = check_model_and_trace(model) @test issuccess @test startswith(string(trace), r" assume: x ~ (Distributions\.)?Normal") end @@ -139,7 +132,7 @@ using LinearAlgebra: I @testset "observe" begin @model demo_observe(x) = x ~ Normal() model = demo_observe(1.0) - issuccess, trace = check_model_and_trace(model, VarInfo(model)) + issuccess, trace = check_model_and_trace(model) @test issuccess @test occursin( r"observe: x \(= \d+\.\d+\) ~ (Distributions\.)?Normal", string(trace) @@ -150,8 +143,8 @@ using LinearAlgebra: I @testset "comparing multiple traces" begin # Run the same model but with different VarInfos. model = DynamicPPL.TestUtils.demo_dynamic_constraint() - issuccess_1, trace_1 = check_model_and_trace(model, VarInfo(model)) - issuccess_2, trace_2 = check_model_and_trace(model, VarInfo(model)) + issuccess_1, trace_1 = check_model_and_trace(model) + issuccess_2, trace_2 = check_model_and_trace(model) @test issuccess_1 && issuccess_2 # Should have the same varnames present. @@ -176,7 +169,7 @@ using LinearAlgebra: I end for ns in [(2,), (2, 2), (2, 2, 2)] model = demo_undef(ns...) - @test check_model(model, VarInfo(model); error_on_failure=true) + @test check_model(model; error_on_failure=true) end end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 233ce3980..7e58522fa 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -6,11 +6,13 @@ using AbstractPPL: AbstractPPL using Random: Random function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) - vi = VarInfo(model) - vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.RawValueAccumulator(false),)) + vi = DynamicPPL.OnlyAccsVarInfo(( + DynamicPPL.default_accumulators()..., DynamicPPL.RawValueAccumulator(false) + )) ps = hcat([ - DynamicPPL.ParamsWithStats(last(DynamicPPL.init!!(rng, model, vi))) for - _ in 1:n_iters + DynamicPPL.ParamsWithStats( + last(DynamicPPL.init!!(rng, model, vi, InitFromPrior(), UnlinkAll())) + ) for _ in 1:n_iters ]) return AbstractMCMC.from_samples(MCMCChains.Chains, ps) end diff --git a/test/linking.jl b/test/linking.jl index 80e97c041..3a31b0ee4 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -94,8 +94,6 @@ end example_values_m_only = (m=example_values.m,) vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values_m_only) @testset "$(short_varinfo_name(vi))" for vi in vis - # Evaluate once to ensure we have `logp` value. - vi = last(DynamicPPL.evaluate!!(model, vi)) vi_linked = if mutable DynamicPPL.link!!(deepcopy(vi), model) else diff --git a/test/model.jl b/test/model.jl index b514e1df7..088dfdc45 100644 --- a/test/model.jl +++ b/test/model.jl @@ -18,11 +18,13 @@ short_varinfo_name(::DynamicPPL.ThreadSafeVarInfo) = "ThreadSafeVarInfo" short_varinfo_name(::DynamicPPL.VarInfo) = "VarInfo" function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) - vi = VarInfo(model) - vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.RawValueAccumulator(false),)) + vi = DynamicPPL.OnlyAccsVarInfo(( + DynamicPPL.default_accumulators()..., DynamicPPL.RawValueAccumulator(false) + )) ps = hcat([ - DynamicPPL.ParamsWithStats(last(DynamicPPL.init!!(rng, model, vi))) for - _ in 1:n_iters + DynamicPPL.ParamsWithStats( + last(DynamicPPL.init!!(rng, model, vi, InitFromPrior(), UnlinkAll())) + ) for _ in 1:n_iters ]) return AbstractMCMC.from_samples(MCMCChains.Chains, ps) end @@ -97,6 +99,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end N = 200 chain = make_chain_from_prior(model, N) + chain = MCMCChains.get_sections(chain, :parameters) logpriors = logprior(model, chain) loglikelihoods = loglikelihood(model, chain) logjoints = logjoint(model, chain) @@ -272,12 +275,8 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Internal methods" begin model = GDEMO_DEFAULT - # sample from model and extract variables - vi = VarInfo(model) - - # Second component of return-value of `evaluate!!` should - # be a `DynamicPPL.AbstractVarInfo`. - evaluate_retval = DynamicPPL.evaluate!!(model, vi) + # Second component of return-value of `init!!` should be a `DynamicPPL.AbstractVarInfo`. + evaluate_retval = DynamicPPL.init!!(model, VarInfo(), InitFromPrior(), UnlinkAll()) @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo # Should not return `AbstractVarInfo` when we call the model. @@ -403,33 +402,31 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end - if VERSION >= v"1.8" - @testset "Type stability of models" begin - models_to_test = [ - DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) - ] - @testset "$(model.f)" for model in models_to_test - if model.f === DynamicPPL.TestUtils.demo_nested_colons && VERSION < v"1.11" - # On v1.10, the demo_nested_colons model, which uses a lot of - # NamedTuples, is badly type unstable. Not worth doing much about - # it, since it's fixed on later Julia versions, so just skipping - # these tests. - @test false skip = true - continue + @testset "Type stability of models" begin + models_to_test = [ + DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) + ] + @testset "$(model.f)" for model in models_to_test + if model.f === DynamicPPL.TestUtils.demo_nested_colons && VERSION < v"1.11" + # On v1.10, the demo_nested_colons model, which uses a lot of + # NamedTuples, is badly type unstable. Not worth doing much about + # it, since it's fixed on later Julia versions, so just skipping + # these tests. + @test false skip = true + continue + end + example_values = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @test begin + @inferred(DynamicPPL.evaluate_nowarn!!(model, varinfo)) + true end - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo)) - true - end - - varinfo_linked = DynamicPPL.link(varinfo, model) - @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo_linked)) - true - end + + varinfo_linked = DynamicPPL.link(varinfo, model) + @test begin + @inferred(DynamicPPL.evaluate_nowarn!!(model, varinfo_linked)) + true end end end @@ -509,7 +506,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() varinfo = DynamicPPL.VarInfo(model) logjoint = getlogjoint(varinfo) # unlinked space varinfo_linked = DynamicPPL.link(varinfo, model) - varinfo_linked_result = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked))) + varinfo_linked_result = last( + DynamicPPL.init!!( + model, VarInfo(), InitFromParams(varinfo_linked.values, nothing), LinkAll() + ), + ) # getlogjoint should return the same result as before it was linked @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) @test getlogjoint(varinfo_linked) ≈ logjoint diff --git a/test/threadsafe.jl b/test/threadsafe.jl index c4be28ffc..c7e30a688 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -83,7 +83,7 @@ const gdemo_default = gdemo_d() # But init!! should return the original VarInfo @test vi isa DynamicPPL.VarInfo # Same with evaluate!! - _, vi = DynamicPPL.evaluate!!(model, vi) + _, vi = DynamicPPL.evaluate_nowarn!!(model, vi) @test vi_ isa DynamicPPL.ThreadSafeVarInfo @test vi isa DynamicPPL.VarInfo end diff --git a/test/utils.jl b/test/utils.jl index 3445a72da..a4a88990d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -107,19 +107,19 @@ end model = test() vi_unlinked = VarInfo(model) vi_linked = DynamicPPL.link!!(VarInfo(model), model) - @test (DynamicPPL.evaluate!!(model, vi_unlinked); true) - @test (DynamicPPL.evaluate!!(model, vi_linked); true) + @test (DynamicPPL.evaluate_nowarn!!(model, vi_unlinked); true) + @test (DynamicPPL.evaluate_nowarn!!(model, vi_linked); true) model_init = DynamicPPL.setleafcontext( model, DynamicPPL.InitContext(DynamicPPL.InitFromPrior(), DynamicPPL.UnlinkAll()), ) - @test (DynamicPPL.evaluate!!(model_init, vi_unlinked); true) + @test (DynamicPPL.evaluate_nowarn!!(model_init, vi_unlinked); true) model_init = DynamicPPL.setleafcontext( model, DynamicPPL.InitContext(DynamicPPL.InitFromPrior(), DynamicPPL.LinkAll()), ) - @test (DynamicPPL.evaluate!!(model_init, vi_linked); true) + @test (DynamicPPL.evaluate_nowarn!!(model_init, vi_linked); true) end # Unconstrained univariate diff --git a/test/varinfo.jl b/test/varinfo.jl index ecca93f34..18f25e27a 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -32,11 +32,13 @@ short_varinfo_name(::DynamicPPL.ThreadSafeVarInfo) = "ThreadSafeVarInfo" short_varinfo_name(::DynamicPPL.VarInfo) = "VarInfo" function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) - vi = VarInfo(model) - vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.RawValueAccumulator(false),)) + vi = DynamicPPL.OnlyAccsVarInfo(( + DynamicPPL.default_accumulators()..., DynamicPPL.RawValueAccumulator(false) + )) ps = hcat([ - DynamicPPL.ParamsWithStats(last(DynamicPPL.init!!(rng, model, vi))) for - _ in 1:n_iters + DynamicPPL.ParamsWithStats( + last(DynamicPPL.init!!(rng, model, vi, InitFromPrior(), UnlinkAll())) + ) for _ in 1:n_iters ]) return AbstractMCMC.from_samples(MCMCChains.Chains, ps) end @@ -114,7 +116,7 @@ end vi = DynamicPPL.unflatten!!(VarInfo(m), collect(values)) - vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) + vi = last(DynamicPPL.evaluate_nowarn!!(m, deepcopy(vi))) @test getlogprior(vi) == lp_a + lp_b @test getlogjac(vi) == 0.0 @test getloglikelihood(vi) == lp_c + lp_d @@ -151,7 +153,7 @@ end @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) vi = last( - DynamicPPL.evaluate!!( + DynamicPPL.evaluate_nowarn!!( m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorAccumulator(),)) ), ) @@ -171,7 +173,7 @@ end end # Test evaluating without any accumulators. - vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) + vi = last(DynamicPPL.evaluate_nowarn!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) # need regex because 1.11 and 1.12 throw different errors (in 1.12 the # missing field is surrounded by backticks) @test_throws r"has no field `?LogPrior" getlogprior(vi) @@ -182,7 +184,7 @@ end @testset "resetaccs" begin # Put in a bunch of accumulators, check that they're all reset either - # when we call resetaccs!!, empty!!, or evaluate!!. + # when we call resetaccs!!, empty!!, or evaluate_nowarn!!. @model function demo() a ~ Normal() return x ~ Normal(a) @@ -197,7 +199,7 @@ end vi_orig, DynamicPPL.PointwiseLogProbAccumulator{:both}() ) # And evaluate the model once so that they are populated. - _, vi_orig = DynamicPPL.evaluate!!(model, vi_orig) + _, vi_orig = DynamicPPL.evaluate_nowarn!!(model, vi_orig) function all_accs_empty(vi::AbstractVarInfo) for acc_key in keys(DynamicPPL.getaccs(vi)) @@ -241,7 +243,7 @@ end @test all_accs_same(vi_orig, deepcopy(vi_orig)) # If we re-evaluate, then we expect the accs to be reset prior to evaluation. # Thus after re-evaluation, the accs should be exactly the same as before. - _, vi = DynamicPPL.evaluate!!(model, deepcopy(vi_orig)) + _, vi = DynamicPPL.evaluate_nowarn!!(model, deepcopy(vi_orig)) @test all_accs_same(vi, vi_orig) end @@ -409,9 +411,6 @@ end model, value_true; include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Evaluate the model once to update the logp of the varinfo. - varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) - varinfo_linked = if mutating DynamicPPL.link!!(deepcopy(varinfo), model) else