Make PrefixContext contain a varname rather than symbol #896

Open
wants to merge 1 commit into
base: py/submodel-cond
90 changes: 56 additions & 34 deletions HISTORY.md
@@ -4,7 +4,7 @@

**Breaking changes**

### Submodels
### Submodels: conditioning

Variables in a submodel can now be conditioned and fixed in a correct way.
See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this:
@@ -22,38 +22,7 @@ end
and the `inner.x` variable will be correctly conditioned.
(Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.)

### AD testing utilities

`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
To disable this, pass the `linked=false` keyword argument.
If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure.
This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information.
From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`.

### SimpleVarInfo linking / invlinking

Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error.

### VarInfo constructors

`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.

The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed.
If you were not using this argument (most likely), then there is no change needed.
If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below).

The `UntypedVarInfo` constructor and type is no longer exported.
If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead.

The `TypedVarInfo` constructor and type is no longer exported.
The _type_ has been replaced with `DynamicPPL.NTVarInfo`.
The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`.

Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail.
Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs.
Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface.

### VarName prefixing behaviour
### Submodel prefixing

The way in which VarNames in submodels are prefixed has been changed.
This is best explained through an example.
@@ -95,9 +64,62 @@ outer() | (@varname(var"a.x") => 1.0,)
outer() | (a.x=1.0,)
```

If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
In a similar way, if the variable on the left-hand side of your tilde statement is not just a single identifier, any fields or indices it accesses are now properly respected.
Consider the following setup:

```julia
using DynamicPPL, Distributions
@model inner() = x ~ Normal()
@model function outer()
a = Vector{Float64}(undef, 1)
a[1] ~ to_submodel(inner())
return a
end
```

In this case, the variable sampled is actually the `x` field of the first element of `a`:

```julia
julia> only(keys(VarInfo(outer()))) == @varname(a[1].x)
true
```

Before this version, it used to be a single variable called `var"a[1].x"`.

Note that if you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.)

### AD testing utilities

`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
To disable this, pass the `linked=false` keyword argument.
If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure.
This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information.
From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`.
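
As a hedged illustration of the pattern described above (this is not part of the upstream changelog): the sketch below assumes a toy model named `demo` and the ForwardDiff backend selected through `ADTypes.AutoForwardDiff`, both chosen purely for illustration. Because `run_ad` throws when the result is incorrect, wrapping the call in `@test ... isa Any` records a pass when it returns and an error when it throws.

```julia
using DynamicPPL, Distributions, Test
using ADTypes: AutoForwardDiff
import ForwardDiff  # make sure the backend package is loaded
using DynamicPPL.TestUtils.AD: run_ad

# Hypothetical model; `y` has constrained support, so linking changes the
# space in which the log density is differentiated.
@model function demo()
    x ~ Normal()
    y ~ LogNormal()
end

@testset "AD correctness" begin
    # Default behaviour: the VarInfo is linked before differentiating.
    @test run_ad(demo(), AutoForwardDiff()) isa Any
    # Opt out of linking with the keyword argument named in the text above.
    @test run_ad(demo(), AutoForwardDiff(); linked=false) isa Any
end
```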

### SimpleVarInfo linking / invlinking

Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error.

### VarInfo constructors

`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.
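
A minimal sketch of this replacement, assuming a toy model named `demo` (hypothetical, used only for illustration) and using `vi[:]` to read the current values as a flat vector:

```julia
using DynamicPPL, Distributions

# Hypothetical two-parameter model.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

vi = VarInfo(demo())
vals = vi[:]  # current parameter values as a flat vector

# Previously: VarInfo(vi, 2 .* vals)
new_vi = DynamicPPL.unflatten(vi, 2 .* vals)

# The new VarInfo has the same structure as `vi`, with the supplied values.
new_vi[:] == 2 .* vals
```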

The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed.
If you were not using this argument (most likely), then there is no change needed.
If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below).

The `UntypedVarInfo` constructor and type is no longer exported.
If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead.

The `TypedVarInfo` constructor and type is no longer exported.
The _type_ has been replaced with `DynamicPPL.NTVarInfo`.
The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`.

Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail.
Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs.
Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface.
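
The following sketch illustrates the constructors named above under stated assumptions: the model `demo` is hypothetical, and the single-argument calls `DynamicPPL.untyped_varinfo(model)` and `DynamicPPL.typed_varinfo(model)` are assumed to be available in the version this changelog describes. The loop relies only on `keys`, in keeping with the advice to depend on the `AbstractVarInfo` interface rather than on a concrete type.

```julia
using DynamicPPL, Distributions

# Hypothetical two-parameter model.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

vi_default = VarInfo(demo())                      # exact concrete type is unspecified
vi_untyped = DynamicPPL.untyped_varinfo(demo())   # replaces the UntypedVarInfo constructor
vi_typed = DynamicPPL.typed_varinfo(demo())       # replaces the TypedVarInfo constructor

for vi in (vi_default, vi_untyped, vi_typed)
    # Only rely on the AbstractVarInfo interface, not on the metadata layout.
    @assert vi isa DynamicPPL.AbstractVarInfo
    @assert Set(keys(vi)) == Set([@varname(x), @varname(y)])
end
```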

**Other changes**

While these are technically breaking, they are only internal changes and do not affect the public API.
19 changes: 10 additions & 9 deletions docs/src/internals/submodel_condition.md
@@ -181,10 +181,10 @@ Putting all of the information so far together, what it means is that if we have
using DynamicPPL: PrefixContext, ConditionContext, DefaultContext

inner_ctx_with_outer_cond = ConditionContext(
Dict(@varname(a.x) => 1.0), PrefixContext{:a}(DefaultContext())
Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a))
)
inner_ctx_with_inner_cond = PrefixContext{:a}(
ConditionContext(Dict(@varname(x) => 1.0), DefaultContext())
inner_ctx_with_inner_cond = PrefixContext(
@varname(a), ConditionContext(Dict(@varname(x) => 1.0))
)
```

@@ -252,10 +252,11 @@ The general strategy that we adopt is similar to above.
Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be:

```@example
big_ctx = PrefixContext{:a}(
big_ctx = PrefixContext(
@varname(a),
ConditionContext(
Dict(@varname(b.y) => 1.0),
PrefixContext{:b}(ConditionContext(Dict(@varname(x) => 1.0))),
PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))),
),
)
```
@@ -280,9 +281,9 @@ end
function myprefix(::IsParent, ctx::AbstractContext, vn::VarName)
return myprefix(childcontext(ctx), vn)
end
function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix}
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
# The functionality to actually manipulate the VarNames is in AbstractPPL
new_vn = AbstractPPL.prefix(vn, VarName{Prefix}())
new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix)
# Then pass to the child context
return myprefix(childcontext(ctx), new_vn)
end
@@ -295,11 +296,11 @@ This implementation clearly is not correct, because it applies the _inner_ `Pref
The right way to implement `myprefix` is to, essentially, reverse the order of two lines above:

```@example
function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix}
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
# Pass to the child context first
new_vn = myprefix(childcontext(ctx), vn)
# Then apply this context's prefix
return AbstractPPL.prefix(new_vn, VarName{Prefix}())
return AbstractPPL.prefix(new_vn, ctx.vn_prefix)
end

myprefix(big_ctx, @varname(x))
2 changes: 1 addition & 1 deletion src/context_implementations.jl
@@ -131,7 +131,7 @@ function tilde_assume!!(context, right, vn, vi)
# change in the future.
if should_auto_prefix(right)
dppl_model = right.model.model # This isa DynamicPPL.Model
prefixed_submodel_context = PrefixContext{Symbol(vn)}(dppl_model.context)
prefixed_submodel_context = PrefixContext(vn, dppl_model.context)
new_dppl_model = contextualize(dppl_model, prefixed_submodel_context)
right = to_submodel(new_dppl_model, true)
end
69 changes: 41 additions & 28 deletions src/contexts.jl
@@ -237,36 +237,43 @@
end

"""
PrefixContext{Prefix}(context)
PrefixContext(vn::VarName[, context::AbstractContext])
PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym}

Create a context that allows you to use the wrapped `context` when running the model and
adds the `Prefix` to all parameters.
prefixes all parameters with the VarName `vn`.

`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`.
If `context` is not provided, it defaults to `DefaultContext()`.

This context is useful in nested models to ensure that the names of the parameters are
unique.

See also: [`to_submodel`](@ref)
"""
struct PrefixContext{Prefix,C} <: AbstractContext
struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext
vn_prefix::Tvn
context::C
end
function PrefixContext{Prefix}(context::AbstractContext) where {Prefix}
return PrefixContext{Prefix,typeof(context)}(context)
PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext())
function PrefixContext(::Val{sym}, context::AbstractContext) where {sym}
return PrefixContext(VarName{sym}(), context)
end
PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}())

NodeTrait(::PrefixContext) = IsParent()
childcontext(context::PrefixContext) = context.context
function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix}
return PrefixContext{Prefix}(child)
function setchildcontext(ctx::PrefixContext, child::AbstractContext)
return PrefixContext(ctx.vn_prefix, child)
end

"""
prefix(ctx::AbstractContext, vn::VarName)

Apply the prefixes in the context `ctx` to the variable name `vn`.
"""
function prefix(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix}
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Prefix}())
function prefix(ctx::PrefixContext, vn::VarName)
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix)
end
function prefix(ctx::AbstractContext, vn::VarName)
return prefix(NodeTrait(ctx), ctx, vn)
@@ -295,14 +302,13 @@
_do_ need to modify them, then you may need to use
`prefix_cond_and_fixed_variables` instead.
"""
function prefix_and_strip_contexts(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix}
function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName)
child_context = childcontext(ctx)
# vn_prefixed contains the prefixes from all lower levels
vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts(
child_context, vn
)
return AbstractPPL.prefix(vn_prefixed, VarName{Prefix}()),
child_context_without_prefixes
return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes
end
function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName)
return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn)
@@ -314,11 +320,16 @@
end

"""
prefix(model::Model, x)

Return `model` but with all random variables prefixed by `x`.
prefix(model::Model, x::VarName)
prefix(model::Model, x::Val{sym})
prefix(model::Model, x::Any)

If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.
Return `model` but with all random variables prefixed by `x`, where `x` is either:
- a `VarName` (e.g. `@varname(a)`),
- a `Val{sym}` (e.g. `Val(:a)`), or
- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that
this will introduce runtime overheads so is not recommended unless absolutely
necessary.

# Examples

@@ -328,17 +339,19 @@
julia> @model demo() = x ~ Dirac(1)
demo (generic function with 2 methods)

julia> rand(prefix(demo(), :my_prefix))
julia> rand(prefix(demo(), @varname(my_prefix)))
(var"my_prefix.x" = 1,)

julia> # One can also use `Val` to avoid runtime overheads.
rand(prefix(demo(), Val(:my_prefix)))
julia> rand(prefix(demo(), Val(:my_prefix)))
(var"my_prefix.x" = 1,)
```
"""
prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context))
function prefix(model::Model, ::Val{x}) where {x}
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context))
function prefix(model::Model, x::Val{sym}) where {sym}
return contextualize(model, PrefixContext(VarName{sym}(), model.context))
end
function prefix(model::Model, x)
return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context))
end

"""
@@ -426,7 +439,7 @@
function hasconditioned_nested(::IsParent, context, vn)
return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
end
function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
function hasconditioned_nested(context::PrefixContext, vn)
return hasconditioned_nested(collapse_prefix_stack(context), vn)
end

@@ -444,7 +457,7 @@
function getconditioned_nested(::IsLeaf, context, vn)
return error("context $(context) does not contain value for $vn")
end
function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
function getconditioned_nested(context::PrefixContext, vn)
return getconditioned_nested(collapse_prefix_stack(context), vn)
end
function getconditioned_nested(::IsParent, context, vn)
@@ -715,13 +728,13 @@
```jldoctest
julia> using DynamicPPL: collapse_prefix_stack

julia> c1 = PrefixContext{:a}(ConditionContext((x=1, )));
julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, )));

julia> collapse_prefix_stack(c1)
ConditionContext(Dict(a.x => 1), DefaultContext())

julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both.
c2 = PrefixContext{:a}(ConditionContext((x=1, ), PrefixContext{:b}(ConditionContext((y=2,)))));
c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,)))));

julia> collapsed = collapse_prefix_stack(c2);

@@ -733,14 +746,14 @@
(1, 2)
```
"""
function collapse_prefix_stack(context::PrefixContext{Prefix}) where {Prefix}
function collapse_prefix_stack(context::PrefixContext)
# Collapse the child context (thus applying any inner prefixes first)
collapsed = collapse_prefix_stack(childcontext(context))
# Prefix any conditioned variables with the current prefix
# Note: prefix_conditioned_variables is O(N) in the depth of the context stack.
# So is this function. In the worst case scenario, this is O(N^2) in the
# depth of the context stack.
return prefix_cond_and_fixed_variables(collapsed, VarName{Prefix}())
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix)
end
function collapse_prefix_stack(context::AbstractContext)
return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context)
8 changes: 4 additions & 4 deletions src/model.jl
@@ -429,7 +429,7 @@ julia> # Nested ones also work.
# (Note that `PrefixContext` also prefixes the variables of any
# ConditionContext that is _inside_ it; because of this, the type of the
# container has to be broadened to a `Dict`.)
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0);
julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)])
true
@@ -441,7 +441,7 @@ julia> # Since we conditioned on `a.m`, it is not treated as a random variable.
a.x
julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
cm = condition(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0));
cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
julia> conditioned(cm)
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:
@@ -769,7 +769,7 @@ julia> # Returns all the variables we have fixed on + their values.
(x = 100.0, m = 1.0)
julia> # The rest of this is the same as the `condition` example above.
cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0);
cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0);
julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)])
true
@@ -779,7 +779,7 @@ julia> keys(VarInfo(cm))
a.x
julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
cm = fix(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0));
cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
julia> fixed(cm)
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:
4 changes: 2 additions & 2 deletions src/submodel_macro.jl
@@ -223,12 +223,12 @@ end
prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx)
function prefix_submodel_context(prefix, ctx)
# E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated.
return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx))
return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx))
end

function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx)
# E.g. `prefix="asd"`.
return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx))
return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx))
end

function prefix_submodel_context(prefix::Bool, ctx)