Skip to content

Commit 751ae1d

Browse files
authored
Enable keyword arguments for particle methods (#2660)
> [!NOTE] > ~~This PR requires some changes to AdvancedPS. TuringLang/AdvancedPS.jl#118 This is merged > > ~~It also needs the following Libtask patch: TuringLang/Libtask.jl#198 This is merged > > ~~This PR also lacks tests; some should be added.~~ Tests added. This PR allows models with keyword arguments to be run with SMC / PG. Example: ```julia julia> using Turing julia> @model function m(y; n=0) x ~ Normal(n) y ~ Normal(x) end m (generic function with 2 methods) julia> mean(sample(m(5.0), PG(20), 1000)) [...] ERROR: Models with keyword arguments need special treatment to be used with particle methods. Please run: using Libtask; Libtask.@might_produce(m) before sampling from this model with particle methods. Stacktrace: [...] julia> using Libtask; Libtask.@might_produce(m) julia> mean(sample(m(5.0), PG(20), 1000)) Sampling 100%|███████████████████████████████████████████████████████████████████| Time: 0:00:05 Mean parameters mean Symbol Float64 x 2.7182 julia> mean(sample(m(5.0; n=10.0), PG(20), 1000)) Sampling 100%|███████████████████████████████████████████████████████████████████| Time: 0:00:04 Mean parameters mean Symbol Float64 x 7.4854 ``` Closes #2007.
1 parent 8762fa3 commit 751ae1d

8 files changed

Lines changed: 100 additions & 42 deletions

File tree

HISTORY.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
# 0.42.5
2+
3+
SMC and PG can now be used for models with keyword arguments, albeit with one requirement: the user must mark the model function as being able to produce.
4+
For example, if the model is
5+
6+
```julia
7+
@model foo(x; y) = a ~ Normal(x, y)
8+
```
9+
10+
then before samping from this with SMC or PG, you will have to run
11+
12+
```julia
13+
using Turing
14+
15+
@might_produce(foo)
16+
```
17+
118
# 0.42.4
219

320
Fixes a typo that caused NUTS to perform one less adaptation step than in versions prior to 0.41.

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.42.4"
3+
version = "0.42.5"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -52,7 +52,7 @@ AbstractPPL = "0.11, 0.12, 0.13"
5252
Accessors = "0.1"
5353
AdvancedHMC = "0.8.3"
5454
AdvancedMH = "0.8.9"
55-
AdvancedPS = "0.7"
55+
AdvancedPS = "0.7.2"
5656
AdvancedVI = "0.6"
5757
BangBang = "0.4.2"
5858
Bijectors = "0.14, 0.15"
@@ -65,7 +65,7 @@ DynamicHMC = "3.4"
6565
DynamicPPL = "0.39.1"
6666
EllipticalSliceSampling = "0.5, 1, 2"
6767
ForwardDiff = "0.10.3, 1"
68-
Libtask = "0.9.3"
68+
Libtask = "0.9.5"
6969
LinearAlgebra = "1"
7070
LogDensityProblems = "2"
7171
MCMCChains = "5, 6, 7"

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using DocumenterInterLinks
66
links = InterLinks(
77
"DynamicPPL" => "https://turinglang.org/DynamicPPL.jl/stable/",
88
"AbstractPPL" => "https://turinglang.org/AbstractPPL.jl/stable/",
9+
"Libtask" => "https://turinglang.org/Libtask.jl/stable/",
910
"LinearAlgebra" => "https://docs.julialang.org/en/v1/",
1011
"AbstractMCMC" => "https://turinglang.org/AbstractMCMC.jl/stable/",
1112
"ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/",

docs/src/api.md

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
4444
| `LogDensityFunction` | [`DynamicPPL.LogDensityFunction`](@extref) | A struct containing all information about how to evaluate a model. Mostly for advanced users |
4545
| `@addlogprob!` | [`DynamicPPL.@addlogprob!`](@extref) | Add arbitrary log-probability terms during model evaluation |
4646
| `setthreadsafe` | [`DynamicPPL.setthreadsafe`](@extref) | Mark a model as requiring threadsafe evaluation |
47+
| `might_produce` | [`Libtask.might_produce`](@extref) | Mark a method signature as potentially calling `Libtask.produce` |
48+
| `@might_produce` | [`Libtask.@might_produce`](@extref) | Mark a function name as potentially calling `Libtask.produce` |
4749

4850
### Inference
4951

@@ -110,19 +112,19 @@ Turing.jl provides several strategies to initialise parameters for models.
110112

111113
See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough.
112114

113-
| Exported symbol | Documentation | Description |
114-
|:----------------------------- |:-------------------------------------------------------- |:------------------------------------------------------------------------------------------------------------------------------------------------- |
115-
| `vi` | [`Turing.vi`](@ref) | Perform variational inference |
116-
| `q_locationscale` | [`Turing.Variational.q_locationscale`](@ref) | Find a numerically non-degenerate initialization for a location-scale variational family |
117-
| `q_meanfield_gaussian` | [`Turing.Variational.q_meanfield_gaussian`](@ref) | Find a numerically non-degenerate initialization for a mean-field Gaussian family |
118-
| `q_fullrank_gaussian` | [`Turing.Variational.q_fullrank_gaussian`](@ref) | Find a numerically non-degenerate initialization for a full-rank Gaussian family |
119-
| `KLMinRepGradDescent` | [`Turing.Variational.KLMinRepGradDescent`](@ref) | KL divergence minimization via stochastic gradient descent with the reparameterization gradient |
120-
| `KLMinRepGradProxDescent` | [`Turing.Variational.KLMinRepGradProxDescent`](@ref) | KL divergence minimization via stochastic proximal gradient descent with the reparameterization gradient over location-scale variational families |
121-
| `KLMinScoreGradDescent` | [`Turing.Variational.KLMinScoreGradDescent`](@ref) | KL divergence minimization via stochastic gradient descent with the score gradient |
122-
| `KLMinWassFwdBwd` | [`Turing.Variational.KLMinWassFwdBwd`](@ref) | KL divergence minimization via Wasserstein proximal gradient descent |
123-
| `KLMinNaturalGradDescent` | [`Turing.Variational.KLMinNaturalGradDescent`](@ref) | KL divergence minimization via natural gradient descent |
124-
| `KLMinSqrtNaturalGradDescent` | [`Turing.Variational.KLMinSqrtNaturalGradDescent`](@ref) | KL divergence minimization via natural gradient descent in the square-root parameterization |
125-
| `FisherMinBatchMatch` | [`Turing.Variational.FisherMinBatchMatch`](@ref) | Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm |
115+
| Exported symbol | Documentation | Description |
116+
|:----------------------------- |:--------------------------------------------------- |:------------------------------------------------------------------------------------------------------------------------------------------------- |
117+
| `vi` | [`Turing.vi`](@ref) | Perform variational inference |
118+
| `q_locationscale` | [`Turing.Variational.q_locationscale`](@ref) | Find a numerically non-degenerate initialization for a location-scale variational family |
119+
| `q_meanfield_gaussian` | [`Turing.Variational.q_meanfield_gaussian`](@ref) | Find a numerically non-degenerate initialization for a mean-field Gaussian family |
120+
| `q_fullrank_gaussian` | [`Turing.Variational.q_fullrank_gaussian`](@ref) | Find a numerically non-degenerate initialization for a full-rank Gaussian family |
121+
| `KLMinRepGradDescent` | [`AdvancedVI.KLMinRepGradDescent`](@extref) | KL divergence minimization via stochastic gradient descent with the reparameterization gradient |
122+
| `KLMinRepGradProxDescent` | [`AdvancedVI.KLMinRepGradProxDescent`](@extref) | KL divergence minimization via stochastic proximal gradient descent with the reparameterization gradient over location-scale variational families |
123+
| `KLMinScoreGradDescent` | [`AdvancedVI.KLMinScoreGradDescent`](@extref) | KL divergence minimization via stochastic gradient descent with the score gradient |
124+
| `KLMinWassFwdBwd` | [`AdvancedVI.KLMinWassFwdBwd`](@extref) | KL divergence minimization via Wasserstein proximal gradient descent |
125+
| `KLMinNaturalGradDescent` | [`AdvancedVI.KLMinNaturalGradDescent`](@extref) | KL divergence minimization via natural gradient descent |
126+
| `KLMinSqrtNaturalGradDescent` | [`AdvancedVI.KLMinSqrtNaturalGradDescent`](@extref) | KL divergence minimization via natural gradient descent in the square-root parameterization |
127+
| `FisherMinBatchMatch` | [`AdvancedVI.FisherMinBatchMatch`](@extref) | Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm |
126128

127129
### Automatic differentiation types
128130

src/Turing.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ using DynamicPPL:
8080
setthreadsafe
8181
using StatsBase: predict
8282
using OrderedCollections: OrderedDict
83+
using Libtask: might_produce, @might_produce
8384

8485
# Turing essentials - modelling macros and inference algorithms
8586
export
@@ -172,6 +173,9 @@ export
172173
MAP,
173174
MLE,
174175
# Chain save/resume
175-
loadstate
176+
loadstate,
177+
# kwargs in SMC
178+
might_produce,
179+
@might_produce
176180

177181
end

src/mcmc/particle_mcmc.jl

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,22 @@ struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
1919
rng::R
2020
end
2121

22-
struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel
22+
struct TracedModel{V<:AbstractVarInfo,M<:Model,T<:Tuple,NT<:NamedTuple} <:
23+
AdvancedPS.AbstractTuringLibtaskModel
2324
model::M
2425
varinfo::V
25-
evaluator::E
2626
resample::Bool
27+
fargs::T
28+
kwargs::NT
2729
end
2830

2931
function TracedModel(
3032
model::Model, varinfo::AbstractVarInfo, rng::Random.AbstractRNG, resample::Bool
3133
)
3234
model = DynamicPPL.setleafcontext(model, ParticleMCMCContext(rng))
3335
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo)
34-
isempty(kwargs) || error(
35-
"Particle sampling methods do not currently support models with keyword arguments.",
36-
)
37-
evaluator = (model.f, args...)
38-
return TracedModel(model, varinfo, evaluator, resample)
36+
fargs = (model.f, args...)
37+
return TracedModel(model, varinfo, resample, fargs, kwargs)
3938
end
4039

4140
function AdvancedPS.advance!(
@@ -53,16 +52,16 @@ function AdvancedPS.delete_retained!(trace::TracedModel)
5352
# In such a case, we need to ensure that when we continue sampling (i.e.
5453
# the next time we hit tilde_assume!!), we don't use the values in the
5554
# reference particle but rather sample new values.
56-
return TracedModel(trace.model, trace.varinfo, trace.evaluator, true)
55+
return TracedModel(trace.model, trace.varinfo, true, trace.fargs, trace.kwargs)
5756
end
5857

5958
function AdvancedPS.reset_model(trace::TracedModel)
6059
return trace
6160
end
6261

63-
function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
62+
function Libtask.TapedTask(taped_globals, model::TracedModel)
6463
return Libtask.TapedTask(
65-
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
64+
taped_globals, model.fargs[1], model.fargs[2:end]...; model.kwargs...
6665
)
6766
end
6867

@@ -124,6 +123,7 @@ function AbstractMCMC.sample(
124123
)
125124
check_model && _check_model(model, sampler)
126125
error_if_threadsafe_eval(model)
126+
check_model_kwargs(model)
127127
# need to add on the `nparticles` keyword argument for `initialstep` to make use of
128128
return AbstractMCMC.mcmcsample(
129129
rng,
@@ -138,6 +138,28 @@ function AbstractMCMC.sample(
138138
)
139139
end
140140

141+
function check_model_kwargs(model::DynamicPPL.Model)
142+
if !isempty(model.defaults)
143+
# If there are keyword arguments, we need to check that the user has
144+
# accounted for this by overloading `might_produce`.
145+
might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f)))
146+
if !might_produce
147+
io = IOBuffer()
148+
ctx = IOContext(io, :color => true)
149+
print(
150+
ctx,
151+
"Models with keyword arguments need special treatment to be used" *
152+
" with particle methods. Please run:\n\n",
153+
)
154+
printstyled(
155+
ctx, " Turing.@might_produce($(model.f))"; bold=true, color=:blue
156+
)
157+
print(ctx, "\n\nbefore sampling from this model with particle methods.\n")
158+
error(String(take!(io)))
159+
end
160+
end
161+
end
162+
141163
function Turing.Inference.initialstep(
142164
rng::AbstractRNG,
143165
model::DynamicPPL.Model,
@@ -146,6 +168,7 @@ function Turing.Inference.initialstep(
146168
nparticles::Int,
147169
kwargs...,
148170
)
171+
check_model_kwargs(model)
149172
# Reset the VarInfo.
150173
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
151174
vi = DynamicPPL.empty!!(vi)
@@ -254,6 +277,7 @@ function Turing.Inference.initialstep(
254277
rng::AbstractRNG, model::DynamicPPL.Model, spl::PG, vi::AbstractVarInfo; kwargs...
255278
)
256279
error_if_threadsafe_eval(model)
280+
check_model_kwargs(model)
257281
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
258282

259283
# Create a new set of particles
@@ -495,7 +519,7 @@ end
495519
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
496520
# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
497521
# call stack.
498-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
522+
Libtask.@might_produce(DynamicPPL.accloglikelihood!!)
499523
function Libtask.might_produce(
500524
::Type{
501525
<:Tuple{
@@ -507,15 +531,11 @@ function Libtask.might_produce(
507531
)
508532
return true
509533
end
510-
function Libtask.might_produce(
511-
::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
512-
)
513-
return true
514-
end
515-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
516-
# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext?
534+
Libtask.@might_produce(DynamicPPL.accumulate_observe!!)
535+
Libtask.@might_produce(DynamicPPL.tilde_observe!!)
536+
# Could tilde_assume!! have tighter type bounds on the arguments, namely a GibbsContext?
517537
# That's the only thing that makes tilde_assume calls result in tilde_observe calls.
518-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true
519-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
520-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.init!!),Vararg}}) = true
538+
Libtask.@might_produce(DynamicPPL.tilde_assume!!)
539+
Libtask.@might_produce(DynamicPPL.evaluate!!)
540+
Libtask.@might_produce(DynamicPPL.init!!)
521541
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ ADTypes = "1"
4343
AbstractMCMC = "5.9"
4444
AbstractPPL = "0.11, 0.12, 0.13"
4545
AdvancedMH = "0.8.9"
46-
AdvancedPS = "0.7"
46+
AdvancedPS = "0.7.2"
4747
AdvancedVI = "0.6"
4848
Aqua = "0.8"
4949
BangBang = "0.4"

test/mcmc/particle_mcmc.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,23 @@ end
162162
end
163163

164164
# https://github.com/TuringLang/Turing.jl/issues/2007
165-
@testset "keyword arguments not supported" begin
166-
@model kwarg_demo(; x=2) = return x
167-
@test_throws ErrorException sample(kwarg_demo(), PG(1), 10)
165+
@testset "keyword argument handling" begin
166+
@model function kwarg_demo(y; n=0.0)
167+
x ~ Normal(n)
168+
return y ~ Normal(x)
169+
end
170+
@test_throws "Models with keyword arguments" sample(kwarg_demo(5.0), PG(20), 10)
171+
172+
# Check that enabling `might_produce` does allow sampling
173+
@might_produce kwarg_demo
174+
chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000)
175+
@test chain isa MCMCChains.Chains
176+
@test mean(chain[:x]) 2.5 atol = 0.2
177+
178+
# Check that the keyword argument's value is respected
179+
chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000)
180+
@test chain2 isa MCMCChains.Chains
181+
@test mean(chain2[:x]) 7.5 atol = 0.2
168182
end
169183

170184
@testset "refuses to run threadsafe eval" begin

0 commit comments

Comments
 (0)