Skip to content

Commit 36008f9

Browse files
Add getmodel and setmodel from/to LogDensityFunction (#626)
* initial copy and paste * add some test * Update ext/DynamicPPLReverseDiffExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update ext/DynamicPPLReverseDiffExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * update to the new implementation according to Turing * Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * error fixes * Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add `HypothesisTests` to turing test dep * version bump --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent dfdc155 commit 36008f9

File tree

4 files changed

+63
-2
lines changed

4 files changed

+63
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.28.1"
3+
version = "0.28.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/logdensityfunction.jl

+40
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,46 @@ function getcontext(f::LogDensityFunction)
7676
return f.context === nothing ? leafcontext(f.model.context) : f.context
7777
end
7878

79+
"""
80+
getmodel(f)
81+
82+
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
83+
"""
84+
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
85+
getmodel(LogDensityProblemsAD.parent(f))
86+
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
87+
88+
"""
89+
setmodel(f, model[, adtype])
90+
91+
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
92+
93+
!!! warning
94+
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
95+
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
96+
might require recompilation of the gradient tape, depending on the AD backend.
97+
"""
98+
function setmodel(
99+
f::LogDensityProblemsAD.ADGradientWrapper,
100+
model::DynamicPPL.Model,
101+
adtype::ADTypes.AbstractADType,
102+
)
103+
# TODO: Should we handle `SciMLBase.NoAD`?
104+
# For an `ADGradientWrapper` we do the following:
105+
# 1. Update the `Model` in the underlying `LogDensityFunction`.
106+
# 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
107+
# to ensure that the recompilation of gradient tapes, etc. also occur. For example,
108+
# ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
109+
# replacing the corresponding field with the new model won't be sufficient to obtain
110+
# the correct gradients.
111+
return LogDensityProblemsAD.ADgradient(
112+
adtype, setmodel(LogDensityProblemsAD.parent(f), model)
113+
)
114+
end
115+
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
116+
return Accessors.@set f.model = model
117+
end
118+
79119
# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
80120
# we need to define these annoying methods to ensure that we stay compatible with everything.
81121
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))

test/logdensityfunction.jl

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,24 @@
1-
using Test, DynamicPPL, LogDensityProblems
1+
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff
2+
3+
@testset "`getmodel` and `setmodel`" begin
4+
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
5+
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
6+
= DynamicPPL.LogDensityFunction(model)
7+
@test DynamicPPL.getmodel(ℓ) == model
8+
@test DynamicPPL.setmodel(ℓ, model).model == model
9+
10+
# ReverseDiff related
11+
∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false))
12+
@test DynamicPPL.getmodel(∇ℓ) == model
13+
@test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) ==
14+
model
15+
∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true))
16+
new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())
17+
@test DynamicPPL.getmodel(new_∇ℓ) == model
18+
# HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape`
19+
@test new_∇ℓ.compiledtape != ∇ℓ.compiledtape
20+
end
21+
end
222

323
@testset "LogDensityFunction" begin
424
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS

test/turing/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
4+
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
45
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
56
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
67
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

0 commit comments

Comments
 (0)