You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Enable keyword arguments for particle methods (#2660)
> [!NOTE]
> ~~This PR requires some changes to AdvancedPS.
TuringLang/AdvancedPS.jl#118 This is merged
>
> ~~It also needs the following Libtask patch:
TuringLang/Libtask.jl#198 This is merged
>
> ~~This PR also lacks tests; some should be added.~~ Tests added.
This PR allows models with keyword arguments to be run with SMC / PG.
Example:
```julia
julia> using Turing
julia> @model function m(y; n=0)
x ~ Normal(n)
y ~ Normal(x)
end
m (generic function with 2 methods)
julia> mean(sample(m(5.0), PG(20), 1000))
[...]
ERROR: Models with keyword arguments need special treatment to be used with particle methods. Please run:
using Libtask; Libtask.@might_produce(m)
before sampling from this model with particle methods.
Stacktrace:
[...]
julia> using Libtask; Libtask.@might_produce(m)
julia> mean(sample(m(5.0), PG(20), 1000))
Sampling 100%|███████████████████████████████████████████████████████████████████| Time: 0:00:05
Mean
parameters mean
Symbol Float64
x 2.7182
julia> mean(sample(m(5.0; n=10.0), PG(20), 1000))
Sampling 100%|███████████████████████████████████████████████████████████████████| Time: 0:00:04
Mean
parameters mean
Symbol Float64
x 7.4854
```
Closes#2007.
Copy file name to clipboardExpand all lines: HISTORY.md
+17Lines changed: 17 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -1,3 +1,20 @@
1
+
# 0.42.5
2
+
3
+
SMC and PG can now be used for models with keyword arguments, albeit with one requirement: the user must mark the model function as being able to produce.
4
+
For example, if the model is
5
+
6
+
```julia
7
+
@modelfoo(x; y) = a ~Normal(x, y)
8
+
```
9
+
10
+
then before samping from this with SMC or PG, you will have to run
11
+
12
+
```julia
13
+
using Turing
14
+
15
+
@might_produce(foo)
16
+
```
17
+
1
18
# 0.42.4
2
19
3
20
Fixes a typo that caused NUTS to perform one less adaptation step than in versions prior to 0.41.
Copy file name to clipboardExpand all lines: docs/src/api.md
+15-13Lines changed: 15 additions & 13 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -44,6 +44,8 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
44
44
|`LogDensityFunction`|[`DynamicPPL.LogDensityFunction`](@extref)| A struct containing all information about how to evaluate a model. Mostly for advanced users |
45
45
|`@addlogprob!`|[`DynamicPPL.@addlogprob!`](@extref)| Add arbitrary log-probability terms during model evaluation |
46
46
|`setthreadsafe`|[`DynamicPPL.setthreadsafe`](@extref)| Mark a model as requiring threadsafe evaluation |
47
+
|`might_produce`|[`Libtask.might_produce`](@extref)| Mark a method signature as potentially calling `Libtask.produce`|
48
+
|`@might_produce`|[`Libtask.@might_produce`](@extref)| Mark a function name as potentially calling `Libtask.produce`|
47
49
48
50
### Inference
49
51
@@ -110,19 +112,19 @@ Turing.jl provides several strategies to initialise parameters for models.
110
112
111
113
See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough.
|`q_locationscale`|[`Turing.Variational.q_locationscale`](@ref)| Find a numerically non-degenerate initialization for a location-scale variational family |
117
-
|`q_meanfield_gaussian`|[`Turing.Variational.q_meanfield_gaussian`](@ref)| Find a numerically non-degenerate initialization for a mean-field Gaussian family |
118
-
|`q_fullrank_gaussian`|[`Turing.Variational.q_fullrank_gaussian`](@ref)| Find a numerically non-degenerate initialization for a full-rank Gaussian family |
119
-
|`KLMinRepGradDescent`|[`Turing.Variational.KLMinRepGradDescent`](@ref)| KL divergence minimization via stochastic gradient descent with the reparameterization gradient |
120
-
|`KLMinRepGradProxDescent`|[`Turing.Variational.KLMinRepGradProxDescent`](@ref)| KL divergence minimization via stochastic proximal gradient descent with the reparameterization gradient over location-scale variational families |
121
-
|`KLMinScoreGradDescent`|[`Turing.Variational.KLMinScoreGradDescent`](@ref)| KL divergence minimization via stochastic gradient descent with the score gradient |
122
-
|`KLMinWassFwdBwd`|[`Turing.Variational.KLMinWassFwdBwd`](@ref)| KL divergence minimization via Wasserstein proximal gradient descent |
123
-
|`KLMinNaturalGradDescent`|[`Turing.Variational.KLMinNaturalGradDescent`](@ref)| KL divergence minimization via natural gradient descent |
124
-
|`KLMinSqrtNaturalGradDescent`|[`Turing.Variational.KLMinSqrtNaturalGradDescent`](@ref)| KL divergence minimization via natural gradient descent in the square-root parameterization |
125
-
|`FisherMinBatchMatch`|[`Turing.Variational.FisherMinBatchMatch`](@ref)| Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm |
|`q_locationscale`|[`Turing.Variational.q_locationscale`](@ref)| Find a numerically non-degenerate initialization for a location-scale variational family |
119
+
|`q_meanfield_gaussian`|[`Turing.Variational.q_meanfield_gaussian`](@ref)| Find a numerically non-degenerate initialization for a mean-field Gaussian family |
120
+
|`q_fullrank_gaussian`|[`Turing.Variational.q_fullrank_gaussian`](@ref)| Find a numerically non-degenerate initialization for a full-rank Gaussian family |
121
+
|`KLMinRepGradDescent`|[`AdvancedVI.KLMinRepGradDescent`](@extref)| KL divergence minimization via stochastic gradient descent with the reparameterization gradient |
122
+
|`KLMinRepGradProxDescent`|[`AdvancedVI.KLMinRepGradProxDescent`](@extref)| KL divergence minimization via stochastic proximal gradient descent with the reparameterization gradient over location-scale variational families |
123
+
|`KLMinScoreGradDescent`|[`AdvancedVI.KLMinScoreGradDescent`](@extref)| KL divergence minimization via stochastic gradient descent with the score gradient |
124
+
|`KLMinWassFwdBwd`|[`AdvancedVI.KLMinWassFwdBwd`](@extref)| KL divergence minimization via Wasserstein proximal gradient descent |
125
+
|`KLMinNaturalGradDescent`|[`AdvancedVI.KLMinNaturalGradDescent`](@extref)| KL divergence minimization via natural gradient descent |
126
+
|`KLMinSqrtNaturalGradDescent`|[`AdvancedVI.KLMinSqrtNaturalGradDescent`](@extref)| KL divergence minimization via natural gradient descent in the square-root parameterization |
127
+
|`FisherMinBatchMatch`|[`AdvancedVI.FisherMinBatchMatch`](@extref)| Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm |
0 commit comments