Skip to content

Commit f0c31f0

Browse files
ADTypes + ADgradient Performance (#727)
* ADTypes interop * Improve comment * Bump patch version * Formatting * Formatting * Improve documentation * Testing infrastructure * Remove extras from main Project toml * Apply some basic tests * Locate tests better * Internal _make_ad_gradient * Mark failing tests as broken * Formatting * Update Project.toml * Updates * Bump patch version * Bump patch again --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent 5a58571 commit f0c31f0

File tree

6 files changed

+37
-44
lines changed

6 files changed

+37
-44
lines changed

Project.toml

+1-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.31.2"
3+
version = "0.31.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -30,15 +30,13 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3030
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3232
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3433
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3534

3635
[extensions]
3736
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
3837
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3938
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4039
DynamicPPLMCMCChainsExt = ["MCMCChains"]
41-
DynamicPPLReverseDiffExt = ["ReverseDiff"]
4240
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4341

4442
[compat]
@@ -63,15 +61,6 @@ MacroTools = "0.5.6"
6361
OrderedCollections = "1"
6462
Random = "1.6"
6563
Requires = "1"
66-
ReverseDiff = "1"
6764
Test = "1.6"
6865
ZygoteRules = "0.2"
6966
julia = "1.10"
70-
71-
[extras]
72-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
73-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
74-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
75-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
76-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
77-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

ext/DynamicPPLReverseDiffExt.jl

-26
This file was deleted.

src/logdensityfunction.jl

+16
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,19 @@ function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
144144
end
145145
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
146146
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
147+
148+
# This is important for performance -- one needs to provide `ADGradient` with a vector of
149+
# parameters, or DifferentiationInterface will not have sufficient information to e.g.
150+
# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
151+
# a tape when using ReverseDiff.jl.
152+
function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
153+
x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
154+
return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
155+
end
156+
157+
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
158+
return _make_ad_gradient(ad, f)
159+
end
160+
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
161+
return _make_ad_gradient(ad, f)
162+
end

test/Project.toml

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
66
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
77
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
9+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
910
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1011
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1112
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -17,6 +18,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1718
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
1819
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1920
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
21+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2022
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2123
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2224
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -43,6 +45,7 @@ LogDensityProblems = "2"
4345
LogDensityProblemsAD = "1.7.0"
4446
MCMCChains = "6.0.4"
4547
MacroTools = "0.5.6"
48+
Mooncake = "0.4.50"
4649
ReverseDiff = "1"
4750
StableRNGs = "1"
4851
Tracker = "0.2.23"

test/ad.jl

+15-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "AD: ForwardDiff and ReverseDiff" begin
1+
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
22
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
33
f = DynamicPPL.LogDensityFunction(m)
44
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
@@ -17,11 +17,20 @@
1717
θ = convert(Vector{Float64}, varinfo[:])
1818
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
1919

20-
@testset "ReverseDiff with compile=$compile" for compile in (false, true)
21-
adtype = ADTypes.AutoReverseDiff(; compile=compile)
22-
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
23-
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
24-
@test grad ref_grad
20+
@testset "$adtype" for adtype in [
21+
ADTypes.AutoReverseDiff(; compile=false),
22+
ADTypes.AutoReverseDiff(; compile=true),
23+
ADTypes.AutoMooncake(; config=nothing),
24+
]
25+
# Mooncake can't currently handle something that is going on in
26+
# SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
27+
if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
28+
@test_broken 1 == 0
29+
else
30+
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
31+
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
32+
@test grad ref_grad
33+
end
2534
end
2635
end
2736
end

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ using DynamicPPL
44
using AbstractMCMC
55
using AbstractPPL
66
using Bijectors
7+
using DifferentiationInterface
78
using Distributions
89
using DistributionsAD
910
using Documenter
1011
using ForwardDiff
1112
using LogDensityProblems, LogDensityProblemsAD
1213
using MacroTools
1314
using MCMCChains
15+
using Mooncake: Mooncake
1416
using Tracker
1517
using ReverseDiff
1618
using Zygote

0 commit comments

Comments
 (0)