Skip to content

Commit 072234d

Browse files
authored
Bump DifferentiationInterface to 0.7 (#922)
* DI 0.7 * Fix strictness failure with DifferentiationInterface 0.7 * Bump patch * Use `LogDensityAt` callable struct instead of closure * Use type parameters
1 parent 27d4378 commit 072234d

File tree

4 files changed

+29
-7
lines changed

4 files changed

+29
-7
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.4
4+
5+
Added compatibility with DifferentiationInterface.jl 0.7.
6+
37
## 0.36.3
48

59
Moved the `bijector(model)`, where `model` is a `DynamicPPL.Model`, function from the Turing main repo.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.36.3"
3+
version = "0.36.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -54,7 +54,7 @@ ChainRulesCore = "1"
5454
Chairmarks = "1.3.1"
5555
Compat = "4"
5656
ConstructionBase = "1.5.4"
57-
DifferentiationInterface = "0.6.41"
57+
DifferentiationInterface = "0.6.41, 0.7"
5858
Distributions = "0.25"
5959
DocStringExtensions = "0.9"
6060
EnzymeCore = "0.6 - 0.8"

src/logdensityfunction.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,7 @@ struct LogDensityFunction{
124124
# Get a set of dummy params to use for prep
125125
x = map(identity, varinfo[:])
126126
if use_closure(adtype)
127-
prep = DI.prepare_gradient(
128-
x -> logdensity_at(x, model, varinfo, context), adtype, x
129-
)
127+
prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x)
130128
else
131129
prep = DI.prepare_gradient(
132130
logdensity_at,
@@ -184,6 +182,26 @@ function logdensity_at(
184182
return getlogp(last(evaluate!!(model, varinfo_new, context)))
185183
end
186184

185+
"""
186+
LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}(
187+
model::M
188+
varinfo::V
189+
context::C
190+
)
191+
192+
A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
193+
varinfo, context)`.
194+
"""
195+
struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
196+
model::M
197+
varinfo::V
198+
context::C
199+
end
200+
function (ld::LogDensityAt)(x::AbstractVector)
201+
varinfo_new = unflatten(ld.varinfo, x)
202+
return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context)))
203+
end
204+
187205
### LogDensityProblems interface
188206

189207
function LogDensityProblems.capabilities(
@@ -209,7 +227,7 @@ function LogDensityProblems.logdensity_and_gradient(
209227
# branches happen to return different types)
210228
return if use_closure(f.adtype)
211229
DI.value_and_gradient(
212-
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
230+
LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x
213231
)
214232
else
215233
DI.value_and_gradient(

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Aqua = "0.8"
3838
Bijectors = "0.15.1"
3939
Combinatorics = "1"
4040
Compat = "4.3.0"
41-
DifferentiationInterface = "0.6.41"
41+
DifferentiationInterface = "0.6.41, 0.7"
4242
Distributions = "0.25"
4343
DistributionsAD = "0.6.3"
4444
Documenter = "1"

0 commit comments

Comments
 (0)