Skip to content

Commit 32a06f5

Browse files
committed
Make PrefixContext contain a varname rather than symbol
1 parent b545a93 commit 32a06f5

File tree

8 files changed

+152
-107
lines changed

8 files changed

+152
-107
lines changed

Diff for: HISTORY.md

+56-34
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
**Breaking changes**
66

7-
### Submodels
7+
### Submodels: conditioning
88

99
Variables in a submodel can now be conditioned and fixed in a correct way.
1010
See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this:
@@ -22,38 +22,7 @@ end
2222
and the `inner.x` variable will be correctly conditioned.
2323
(Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.)
2424

25-
### AD testing utilities
26-
27-
`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
28-
To disable this, pass the `linked=false` keyword argument.
29-
If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure.
30-
This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information.
31-
From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`.
32-
33-
### SimpleVarInfo linking / invlinking
34-
35-
Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error.
36-
37-
### VarInfo constructors
38-
39-
`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.
40-
41-
The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed.
42-
If you were not using this argument (most likely), then there is no change needed.
43-
If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below).
44-
45-
The `UntypedVarInfo` constructor and type is no longer exported.
46-
If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead.
47-
48-
The `TypedVarInfo` constructor and type is no longer exported.
49-
The _type_ has been replaced with `DynamicPPL.NTVarInfo`.
50-
The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`.
51-
52-
Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail.
53-
Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs.
54-
Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface.
55-
56-
### VarName prefixing behaviour
25+
### Submodel prefixing
5726

5827
The way in which VarNames in submodels are prefixed has been changed.
5928
This is best explained through an example.
@@ -95,9 +64,62 @@ outer() | (@varname(var"a.x") => 1.0,)
9564
outer() | (a.x=1.0,)
9665
```
9766

98-
If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
67+
In a similar way, if the variable on the left-hand side of your tilde statement is not just a single identifier, any fields or indices it accesses are now properly respected.
68+
Consider the following setup:
69+
70+
```julia
71+
using DynamicPPL, Distributions
72+
@model inner() = x ~ Normal()
73+
@model function outer()
74+
a = Vector{Float64}(undef, 1)
75+
a[1] ~ to_submodel(inner())
76+
return a
77+
end
78+
```
79+
80+
In this case, the variable sampled is actually the `x` field of the first element of `a`:
81+
82+
```julia
83+
julia> only(keys(VarInfo(outer()))) == @varname(a[1].x)
84+
true
85+
```
86+
87+
Before this version, it used to be the a single variable called `var"a[1].x"`.
88+
89+
Note that if you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
9990
(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.)
10091

92+
### AD testing utilities
93+
94+
`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
95+
To disable this, pass the `linked=false` keyword argument.
96+
If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure.
97+
This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information.
98+
From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`.
99+
100+
### SimpleVarInfo linking / invlinking
101+
102+
Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error.
103+
104+
### VarInfo constructors
105+
106+
`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.
107+
108+
The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed.
109+
If you were not using this argument (most likely), then there is no change needed.
110+
If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below).
111+
112+
The `UntypedVarInfo` constructor and type is no longer exported.
113+
If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead.
114+
115+
The `TypedVarInfo` constructor and type is no longer exported.
116+
The _type_ has been replaced with `DynamicPPL.NTVarInfo`.
117+
The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`.
118+
119+
Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail.
120+
Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs.
121+
Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface.
122+
101123
**Other changes**
102124

103125
While these are technically breaking, they are only internal changes and do not affect the public API.

Diff for: docs/src/internals/submodel_condition.md

+10-9
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ Putting all of the information so far together, what it means is that if we have
181181
using DynamicPPL: PrefixContext, ConditionContext, DefaultContext
182182
183183
inner_ctx_with_outer_cond = ConditionContext(
184-
Dict(@varname(a.x) => 1.0), PrefixContext{:a}(DefaultContext())
184+
Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a))
185185
)
186-
inner_ctx_with_inner_cond = PrefixContext{:a}(
187-
ConditionContext(Dict(@varname(x) => 1.0), DefaultContext())
186+
inner_ctx_with_inner_cond = PrefixContext(
187+
@varname(a), ConditionContext(Dict(@varname(x) => 1.0))
188188
)
189189
```
190190

@@ -252,10 +252,11 @@ The general strategy that we adopt is similar to above.
252252
Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be:
253253

254254
```@example
255-
big_ctx = PrefixContext{:a}(
255+
big_ctx = PrefixContext(
256+
@varname(a),
256257
ConditionContext(
257258
Dict(@varname(b.y) => 1.0),
258-
PrefixContext{:b}(ConditionContext(Dict(@varname(x) => 1.0))),
259+
PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))),
259260
),
260261
)
261262
```
@@ -280,9 +281,9 @@ end
280281
function myprefix(::IsParent, ctx::AbstractContext, vn::VarName)
281282
return myprefix(childcontext(ctx), vn)
282283
end
283-
function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix}
284+
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
284285
# The functionality to actually manipulate the VarNames is in AbstractPPL
285-
new_vn = AbstractPPL.prefix(vn, VarName{Prefix}())
286+
new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix)
286287
# Then pass to the child context
287288
return myprefix(childcontext(ctx), new_vn)
288289
end
@@ -295,11 +296,11 @@ This implementation clearly is not correct, because it applies the _inner_ `Pref
295296
The right way to implement `myprefix` is to, essentially, reverse the order of two lines above:
296297

297298
```@example
298-
function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix}
299+
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
299300
# Pass to the child context first
300301
new_vn = myprefix(childcontext(ctx), vn)
301302
# Then apply this context's prefix
302-
return AbstractPPL.prefix(new_vn, VarName{Prefix}())
303+
return AbstractPPL.prefix(new_vn, ctx.vn_prefix)
303304
end
304305
305306
myprefix(big_ctx, @varname(x))

Diff for: src/context_implementations.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function tilde_assume!!(context, right, vn, vi)
131131
# change in the future.
132132
if should_auto_prefix(right)
133133
dppl_model = right.model.model # This isa DynamicPPL.Model
134-
prefixed_submodel_context = PrefixContext{Symbol(vn)}(dppl_model.context)
134+
prefixed_submodel_context = PrefixContext(vn, dppl_model.context)
135135
new_dppl_model = contextualize(dppl_model, prefixed_submodel_context)
136136
right = to_submodel(new_dppl_model, true)
137137
end

Diff for: src/contexts.jl

+41-28
Original file line numberDiff line numberDiff line change
@@ -237,36 +237,43 @@ function setchildcontext(parent::MiniBatchContext, child)
237237
end
238238

239239
"""
240-
PrefixContext{Prefix}(context)
240+
PrefixContext(vn::VarName[, context::AbstractContext])
241+
PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym}
241242
242243
Create a context that allows you to use the wrapped `context` when running the model and
243-
adds the `Prefix` to all parameters.
244+
prefixes all parameters with the VarName `vn`.
245+
246+
`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`.
247+
If `context` is not provided, it defaults to `DefaultContext()`.
244248
245249
This context is useful in nested models to ensure that the names of the parameters are
246250
unique.
247251
248252
See also: [`to_submodel`](@ref)
249253
"""
250-
struct PrefixContext{Prefix,C} <: AbstractContext
254+
struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext
255+
vn_prefix::Tvn
251256
context::C
252257
end
253-
function PrefixContext{Prefix}(context::AbstractContext) where {Prefix}
254-
return PrefixContext{Prefix,typeof(context)}(context)
258+
PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext())
259+
function PrefixContext(::Val{sym}, context::AbstractContext) where {sym}
260+
return PrefixContext(VarName{sym}(), context)
255261
end
262+
PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}())
256263

257264
NodeTrait(::PrefixContext) = IsParent()
258265
childcontext(context::PrefixContext) = context.context
259-
function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix}
260-
return PrefixContext{Prefix}(child)
266+
function setchildcontext(ctx::PrefixContext, child::AbstractContext)
267+
return PrefixContext(ctx.vn_prefix, child)
261268
end
262269

263270
"""
264271
prefix(ctx::AbstractContext, vn::VarName)
265272
266273
Apply the prefixes in the context `ctx` to the variable name `vn`.
267274
"""
268-
function prefix(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix}
269-
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Prefix}())
275+
function prefix(ctx::PrefixContext, vn::VarName)
276+
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix)
270277
end
271278
function prefix(ctx::AbstractContext, vn::VarName)
272279
return prefix(NodeTrait(ctx), ctx, vn)
@@ -295,14 +302,13 @@ not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you
295302
_do_ need to modify them, then you may need to use
296303
`prefix_cond_and_fixed_variables` instead.
297304
"""
298-
function prefix_and_strip_contexts(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix}
305+
function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName)
299306
child_context = childcontext(ctx)
300307
# vn_prefixed contains the prefixes from all lower levels
301308
vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts(
302309
child_context, vn
303310
)
304-
return AbstractPPL.prefix(vn_prefixed, VarName{Prefix}()),
305-
child_context_without_prefixes
311+
return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes
306312
end
307313
function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName)
308314
return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn)
@@ -314,11 +320,16 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName
314320
end
315321

316322
"""
317-
prefix(model::Model, x)
318-
319-
Return `model` but with all random variables prefixed by `x`.
323+
prefix(model::Model, x::VarName)
324+
prefix(model::Model, x::Val{sym})
325+
prefix(model::Model, x::Any)
320326
321-
If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.
327+
Return `model` but with all random variables prefixed by `x`, where `x` is either:
328+
- a `VarName` (e.g. `@varname(a)`),
329+
- a `Val{sym}` (e.g. `Val(:a)`), or
330+
- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that
331+
this will introduce runtime overheads so is not recommended unless absolutely
332+
necessary.
322333
323334
# Examples
324335
@@ -328,17 +339,19 @@ julia> using DynamicPPL: prefix
328339
julia> @model demo() = x ~ Dirac(1)
329340
demo (generic function with 2 methods)
330341
331-
julia> rand(prefix(demo(), :my_prefix))
342+
julia> rand(prefix(demo(), @varname(my_prefix)))
332343
(var"my_prefix.x" = 1,)
333344
334-
julia> # One can also use `Val` to avoid runtime overheads.
335-
rand(prefix(demo(), Val(:my_prefix)))
345+
julia> rand(prefix(demo(), Val(:my_prefix)))
336346
(var"my_prefix.x" = 1,)
337347
```
338348
"""
339-
prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context))
340-
function prefix(model::Model, ::Val{x}) where {x}
341-
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
349+
prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context))
350+
function prefix(model::Model, x::Val{sym}) where {sym}
351+
return contextualize(model, PrefixContext(VarName{sym}(), model.context))
352+
end
353+
function prefix(model::Model, x)
354+
return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context))
342355
end
343356

344357
"""
@@ -426,7 +439,7 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn)
426439
function hasconditioned_nested(::IsParent, context, vn)
427440
return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
428441
end
429-
function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
442+
function hasconditioned_nested(context::PrefixContext, vn)
430443
return hasconditioned_nested(collapse_prefix_stack(context), vn)
431444
end
432445

@@ -444,7 +457,7 @@ end
444457
function getconditioned_nested(::IsLeaf, context, vn)
445458
return error("context $(context) does not contain value for $vn")
446459
end
447-
function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
460+
function getconditioned_nested(context::PrefixContext, vn)
448461
return getconditioned_nested(collapse_prefix_stack(context), vn)
449462
end
450463
function getconditioned_nested(::IsParent, context, vn)
@@ -715,13 +728,13 @@ which explains this in much more detail.
715728
```jldoctest
716729
julia> using DynamicPPL: collapse_prefix_stack
717730
718-
julia> c1 = PrefixContext{:a}(ConditionContext((x=1, )));
731+
julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, )));
719732
720733
julia> collapse_prefix_stack(c1)
721734
ConditionContext(Dict(a.x => 1), DefaultContext())
722735
723736
julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both.
724-
c2 = PrefixContext{:a}(ConditionContext((x=1, ), PrefixContext{:b}(ConditionContext((y=2,)))));
737+
c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,)))));
725738
726739
julia> collapsed = collapse_prefix_stack(c2);
727740
@@ -733,14 +746,14 @@ julia> # `collapsed` really looks something like this:
733746
(1, 2)
734747
```
735748
"""
736-
function collapse_prefix_stack(context::PrefixContext{Prefix}) where {Prefix}
749+
function collapse_prefix_stack(context::PrefixContext)
737750
# Collapse the child context (thus applying any inner prefixes first)
738751
collapsed = collapse_prefix_stack(childcontext(context))
739752
# Prefix any conditioned variables with the current prefix
740753
# Note: prefix_conditioned_variables is O(N) in the depth of the context stack.
741754
# So is this function. In the worst case scenario, this is O(N^2) in the
742755
# depth of the context stack.
743-
return prefix_cond_and_fixed_variables(collapsed, VarName{Prefix}())
756+
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix)
744757
end
745758
function collapse_prefix_stack(context::AbstractContext)
746759
return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context)

Diff for: src/model.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ julia> # Nested ones also work.
429429
# (Note that `PrefixContext` also prefixes the variables of any
430430
# ConditionContext that is _inside_ it; because of this, the type of the
431431
# container has to be broadened to a `Dict`.)
432-
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
432+
cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0);
433433
434434
julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)])
435435
true
@@ -441,7 +441,7 @@ julia> # Since we conditioned on `a.m`, it is not treated as a random variable.
441441
a.x
442442
443443
julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
444-
cm = condition(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0));
444+
cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
445445
446446
julia> conditioned(cm)
447447
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:
@@ -769,7 +769,7 @@ julia> # Returns all the variables we have fixed on + their values.
769769
(x = 100.0, m = 1.0)
770770
771771
julia> # The rest of this is the same as the `condition` example above.
772-
cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0);
772+
cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0);
773773
774774
julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)])
775775
true
@@ -779,7 +779,7 @@ julia> keys(VarInfo(cm))
779779
a.x
780780
781781
julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
782-
cm = fix(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0));
782+
cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
783783
784784
julia> fixed(cm)
785785
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:

Diff for: src/submodel_macro.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,12 @@ end
223223
prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx)
224224
function prefix_submodel_context(prefix, ctx)
225225
# E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated.
226-
return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx))
226+
return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx))
227227
end
228228

229229
function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx)
230230
# E.g. `prefix="asd"`.
231-
return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx))
231+
return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx))
232232
end
233233

234234
function prefix_submodel_context(prefix::Bool, ctx)

0 commit comments

Comments
 (0)