Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ Please see the docstring for details.
Each initialisation strategy can decide what kind of `AbstractTransformedValue` to return.
This has no impact on whether the log-Jacobian is calculated or not, as that is determined by the *transform strategy* (see below).

### Transform strategies
### `init!!` and transform strategies

The initialisation strategy argument to `init!!` used to default to `InitFromPrior()`.
It is now mandatory to specify this explicitly.

When using `InitContext`, you can (and indeed sometimes must) now specify a *transform strategy* which controls whether values are interpreted as being in transformed space or not.
This in turn controls whether:
Expand Down Expand Up @@ -107,7 +110,7 @@ In its place, you should directly use the accumulator API to:
To do so, we now export a convenience function `get_raw_values(::AbstractVarInfo)` that will get the stored `VarNamedTuple` of raw values.
This is exactly analogous to how `getlogprior(::AbstractVarInfo)` extracts the log-prior from a `LogPriorAccumulator`.

### Function signature changes
### Function signature changes in tilde-pipeline

`tilde_assume!!` and `accumulate_assume!!` now take extra arguments.

Expand All @@ -122,6 +125,12 @@ In particular
`tval` is either the `AbstractTransformedValue` that `DynamicPPL.init` provided (for InitContext), or the `AbstractTransformedValue` found inside the VarInfo (for DefaultContext).
- `accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, dist)` is now `accumulate_assume!!(vi, val, tval, logjac, vn, dist, template)`.

### `DynamicPPL.DebugUtils`

The signature of `DynamicPPL.DebugUtils.check_model` and `DynamicPPL.DebugUtils.check_model_and_trace` are now changed.
Instead of taking a `VarInfo` as the second argument, they now do not need a `VarInfo` at all; they simply sample from the prior of the model.
To make this reproducible you can optionally pass `rng` as a first argument (before the model).

### Overhaul of `VarInfo`

DynamicPPL tracks variable values during model execution using one of the `AbstractVarInfo` types.
Expand Down Expand Up @@ -214,6 +223,19 @@ For example, carrying on from the above, `conditioned(f() | vnt)` will return `v
The underlying code for `ConditionContext` and `FixedContext` is almost completely the same.
In this release, to reduce code duplication, they have been merged into a single implementation, `CondFixContext{Condition}` and `CondFixContext{Fix}`, where the type parameter controls whether conditioning or fixing is performed.

### `DynamicPPL.evaluate!!(model, varinfo)` now warns

This method has very complicated semantics; it's difficult to use properly.
In DynamicPPL we are moving away from trying to encode all the different ways of evaluating a model in the `varinfo` object, and in a future release of DynamicPPL this method will be removed entirely.

For now, the method still exists, but we would like to strongly encourage users to avoid using this method.
In place you should use `init!!([rng,] model, oavi::OnlyAccsVarInfo, init_strategy, transform_strategy)` instead, which is much more explicit, and more closely matches what DynamicPPL.jl will use exclusively in the future.

If you are using this function and are unsure how to adapt your code, please:

1. Read the documentation! There is a *lot* more documentation at https://turinglang.org/DynamicPPL.jl/v0.40/.
2. If you can't figure it out, please open an issue. We are happy to help.

### Accumulator interface exports more functions

To define your own accumulator, you have to overload a number of functions.
Expand Down
11 changes: 6 additions & 5 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ export Models, benchmark, model_dimension
Return the dimension of `model`, accounting for linking, if any.
"""
function model_dimension(model, islinked)
vi = VarInfo()
vi = last(DynamicPPL.init!!(StableRNG(23), model, vi))
if islinked
vi = DynamicPPL.link(vi, model)
end
tfm_strategy = islinked ? DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll()
vi = last(
DynamicPPL.init!!(
StableRNG(23), model, VarInfo(), DynamicPPL.InitFromPrior(), tfm_strategy
),
)
return length(vi[:])
end

Expand Down
2 changes: 1 addition & 1 deletion docs/src/accumulators.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ Because the accumulation process is not always commutative, you may in general e
However, for many accumulators such as log-probability accumulators, this is not an issue.

We can see this in action if we step through the internal DynamicPPL calls.
(Note that calling `DynamicPPL.evaluate!!` on a model where thread-safe mode has been enabled will automatically perform these steps for you.)
(Note that calling `DynamicPPL.init!!` on a model where thread-safe mode has been enabled will automatically perform these steps for you.)

```@example 1
Threads.nthreads()
Expand Down
18 changes: 15 additions & 3 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,21 @@ VarInfo used in the marginalisation.
!!! note

The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the
model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model,
vi))`).
updated. If you wish to obtain updated log-probabilities, you should re-evaluate the
model with the values inside the returned VarInfo, for example using:

```julia
init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing)
oavi = DynamicPPL.OnlyAccsVarInfo((
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.RawValueAccumulator(false),
# ... whatever else you need
))
_, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll())
```

You can then extract all the updated data from `oavi`.

## Example

Expand Down
15 changes: 0 additions & 15 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,21 +289,6 @@ if isdefined(Base.Experimental, :register_error_hint)
)
end
end

Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
is_evaluate_three_arg =
exc.f === AbstractPPL.evaluate!! &&
length(argtypes) == 3 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractVarInfo &&
argtypes[3] <: AbstractContext
if is_evaluate_three_arg
print(
io,
"\n\nThe method `evaluate!!(model, varinfo, new_ctx)` has been removed. Instead, you should store the `new_ctx` in the `model.context` field using `new_model = contextualize(model, new_ctx)`, and then call `evaluate!!(new_model, varinfo)` on the new model. (Note that, if the model already contained a non-default context, you will need to wrap the existing context.)",
)
end
end
end
end

Expand Down
9 changes: 9 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) whe
return setaccs!!(vi, AccumulatorTuple(accs))
end

"""
get_values(vi::AbstractVarInfo)

Return the `VarNamedTuple` in `vi` that stores the variables' values.

This should be implemented by each subtype of `AbstractVarInfo`.
"""
function get_values end

"""
getaccs(vi::AbstractVarInfo)

Expand Down
7 changes: 4 additions & 3 deletions src/accumulators/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ function pointwise_logdensities(
model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both)
) where {whichlogprob}
AccType = PointwiseLogProbAccumulator{whichlogprob}
varinfo = setaccs!!(varinfo, (AccType(),))
varinfo = last(evaluate!!(model, varinfo))
return getacc(varinfo, Val(accumulator_name(AccType))).logps
oavi = OnlyAccsVarInfo((AccType(),))
init_strategy = InitFromParams(varinfo.values, nothing)
oavi = last(init!!(model, oavi, init_strategy, UnlinkAll()))
return getacc(oavi, Val(accumulator_name(AccType))).logps
end

function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
Expand Down
5 changes: 3 additions & 2 deletions src/accumulators/priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ This is done by evaluating the model at the values present in `varinfo`
and recording the distributions that are present at each tilde statement.
"""
function extract_priors(model::Model, varinfo::AbstractVarInfo)
varinfo = setaccs!!(deepcopy(varinfo), (PriorDistributionAccumulator(),))
varinfo = last(evaluate!!(model, varinfo))
oavi = OnlyAccsVarInfo((PriorDistributionAccumulator(),))
init_strategy = InitFromParams(varinfo.values, nothing)
varinfo = last(init!!(model, oavi, init_strategy, UnlinkAll()))
return getacc(varinfo, Val(PRIOR_ACCNAME)).values
end
13 changes: 7 additions & 6 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,17 @@ function ParamsWithStats(
else
(DynamicPPL.RawValueAccumulator(include_colon_eq),)
end
varinfo = DynamicPPL.setaccs!!(varinfo, accs)
varinfo = last(DynamicPPL.evaluate!!(model, varinfo))
params = get_raw_values(varinfo)
oavi = OnlyAccsVarInfo(accs)
init = InitFromParams(varinfo.values, nothing)
oavi = last(DynamicPPL.init!!(model, oavi, init, UnlinkAll()))
params = get_raw_values(oavi)
if include_log_probs
stats = merge(
stats,
(
logprior=DynamicPPL.getlogprior(varinfo),
loglikelihood=DynamicPPL.getloglikelihood(varinfo),
logjoint=DynamicPPL.getlogjoint(varinfo),
logprior=DynamicPPL.getlogprior(oavi),
loglikelihood=DynamicPPL.getloglikelihood(oavi),
logjoint=DynamicPPL.getlogjoint(oavi),
),
)
end
Expand Down
47 changes: 24 additions & 23 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,9 @@ function check_model_post_evaluation(acc::DebugAccumulator)
end

"""
check_model_and_trace(model::Model, varinfo::AbstractVarInfo; error_on_failure=false)
check_model_and_trace([rng::Random.AbstractRNG,] model::Model; error_on_failure=false)

Check that evaluating `model` with the given `varinfo` is valid, warning about any potential
issues.
Check that sampling from the prior of `model`, warning about any potential issues.

This will check the model for the following issues:

Expand All @@ -360,16 +359,14 @@ This will check the model for the following issues:
## Correct model

```jldoctest check-model-and-tracecheck-model-and-trace; setup=:(using Distributions)
julia> using StableRNGs

julia> rng = StableRNG(42);
julia> using StableRNGs; rng = StableRNG(42);

julia> @model demo_correct() = x ~ Normal()
demo_correct (generic function with 2 methods)

julia> model = demo_correct(); varinfo = VarInfo(rng, model);
julia> model = demo_correct();

julia> issuccess, trace = check_model_and_trace(model, varinfo);
julia> issuccess, trace = check_model_and_trace(rng, model);

julia> issuccess
true
Expand All @@ -379,7 +376,7 @@ julia> print(trace)

julia> cond_model = model | (x = 1.0,);

julia> issuccess, trace = check_model_and_trace(cond_model, VarInfo(cond_model));
julia> issuccess, trace = check_model_and_trace(cond_model);
┌ Warning: The model does not contain any parameters.
└ @ DynamicPPL.DebugUtils DynamicPPL.jl/src/debug_utils.jl:342

Expand All @@ -404,26 +401,25 @@ julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't a
# alert us to the issue of `x` being sampled twice.
model = demo_incorrect(); varinfo = VarInfo(model);

julia> issuccess, trace = check_model_and_trace(model, varinfo; error_on_failure=true);
julia> issuccess, trace = check_model_and_trace(model; error_on_failure=true);
ERROR: varname x used multiple times in model
```
"""
function check_model_and_trace(
model::Model, varinfo::AbstractVarInfo; error_on_failure=false
rng::Random.AbstractRNG, model::Model; error_on_failure=false
)
# Add debug accumulator to the VarInfo.
varinfo = DynamicPPL.setaccs!!(deepcopy(varinfo), (DebugAccumulator(error_on_failure),))

# Perform checks before evaluating the model.
issuccess = check_model_pre_evaluation(model)

# TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a
# check on the merged accumulator, rather than checking it in the accumulate_assume
# calls. That way we can also correctly support multi-threaded evaluation.
_, varinfo = DynamicPPL.evaluate!!(model, varinfo)
oavi = DynamicPPL.OnlyAccsVarInfo((DebugAccumulator(error_on_failure),))
init_strategy = InitFromPrior()
_, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, UnlinkAll())

# Perform checks after evaluating the model.
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))
debug_acc = DynamicPPL.getacc(oavi, Val(_DEBUG_ACC_NAME))
issuccess = issuccess && check_model_post_evaluation(debug_acc)

if !issuccess && error_on_failure
Expand All @@ -433,18 +429,26 @@ function check_model_and_trace(
trace = debug_acc.statements
return issuccess, trace
end
function check_model_and_trace(model::Model; error_on_failure=false)
return check_model_and_trace(
Random.default_rng(), model; error_on_failure=error_on_failure
)
end

"""
check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false)
check_model(model::Model; error_on_failure=false)

Check that `model` is valid, warning about any potential issues (or erroring if
`error_on_failure` is `true`).

# Returns
- `issuccess::Bool`: Whether the model check succeeded.
"""
check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) =
first(check_model_and_trace(model, varinfo; error_on_failure=error_on_failure))
check_model(rng::Random.AbstractRNG, model::Model; error_on_failure=false) =
first(check_model_and_trace(rng, model; error_on_failure=error_on_failure))
function check_model(model::Model; error_on_failure=false)
return check_model(Random.default_rng(), model; error_on_failure=error_on_failure)
end

# Convenience method used to check if all elements in a list are the same.
function all_the_same(xs)
Expand Down Expand Up @@ -479,11 +483,8 @@ and checking if the model is consistent across runs.
function has_static_constraints(
rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false
)
new_model = DynamicPPL.contextualize(
model, InitContext(rng, InitFromPrior(), UnlinkAll())
)
results = map(1:num_evals) do _
check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure)
check_model_and_trace(rng, model; error_on_failure=error_on_failure)
end

# Extract the distributions and the corresponding bijectors for each run.
Expand Down
Loading