We should be able to tell ForwardDiff that the derivative of smf is densityof (the exponential of logdensityof). Here are timings for each of these separately for a StdNormal:
julia> using MeasureBase, ForwardDiff, BenchmarkTools
julia> using ForwardDiff: Dual
julia> @btime smf(StdNormal(), x) setup=(x=randn())
5.660 ns (0 allocations: 0 bytes)
0.373722
julia> @btime logdensityof(StdNormal(), x) setup=(x=randn())
2.294 ns (0 allocations: 0 bytes)
-0.935558
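Written out, the rule being asked for is the fundamental theorem of calculus applied to a CDF, together with the way a dual number carries it through. As a sketch, write F for smf(μ, ·) and f for densityof(μ, ·); for StdNormal these are the standard normal CDF Φ and density φ:

```latex
\frac{d}{dx}F(x) = f(x) = \exp\bigl(\log f(x)\bigr),
\qquad
F(x + v\,\varepsilon) = F(x) + f(x)\,v\,\varepsilon \quad (\varepsilon^2 = 0),

\text{StdNormal:}\quad
F(x) = \Phi(x) = \tfrac{1}{2}\operatorname{erfc}\!\left(-\tfrac{x}{\sqrt{2}}\right),
\qquad
f(x) = \varphi(x) = \frac{e^{-x^2/2}}{\sqrt{2\pi}}.
```

So a Dual with value x and partials v should map to a Dual with value F(x) and partials f(x)·v; nothing else about F needs to be differentiated.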
If we compute both, there's no extra overhead; we even save a little compared with the sum of the separate timings:
julia> @btime (smf(StdNormal(), x), logdensityof(StdNormal(), x)) setup=(x=randn())
7.151 ns (0 allocations: 0 bytes)
(0.206069, -1.25525)
It would be good to "teach" ForwardDiff to do this, because differentiating smf is currently much slower:
julia> @btime ForwardDiff.derivative(x -> smf(StdNormal(), x), x) setup=(x=randn())
13.181 ns (0 allocations: 0 bytes)
0.122241
julia> @btime smf(StdNormal(), Dual{}(x, one(x))) setup=(x=randn())
14.506 ns (0 allocations: 0 bytes)
Dual{Nothing}(0.572796,0.392282)
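The two benchmarks above take nearly the same time because, for a scalar input, ForwardDiff.derivative essentially evaluates the function at a dual number seeded with one(x) and reads off the partial. A minimal sketch of that mechanism follows; the name seeded_derivative is hypothetical, and unlike ForwardDiff.derivative it uses the Nothing tag, so it is not safe for nested differentiation:

```julia
using ForwardDiff
using ForwardDiff: Dual

# Hypothetical helper: differentiate a scalar function by hand-seeding a Dual.
# ForwardDiff.derivative does the same, but with a function-specific tag
# that guards against perturbation confusion in nested derivatives.
function seeded_derivative(f, x::Real)
    y = f(Dual(x, one(x)))             # value and derivative in one pass
    return ForwardDiff.partials(y, 1)  # the first (and only) partial
end

seeded_derivative(sin, 0.5) ≈ cos(0.5)
```

Any speedup therefore has to come from what f does when it receives a Dual, which is exactly where a custom method like the one below hooks in.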
I'd think this ought to do it:
julia> function MeasureBase.smf(μ::StdNormal, x::Dual{TAG}) where TAG
           val = ForwardDiff.value(x)     # primal value
           Δ = ForwardDiff.partials(x)    # incoming partials
           # chain rule: d/dx smf(μ, x) = densityof(μ, x)
           Dual{TAG}(smf(μ, val), Δ * densityof(μ, val))
       end
But it doesn't; the timings are essentially unchanged:
julia> @btime ForwardDiff.derivative(x -> smf(StdNormal(), x), x) setup=(x=randn())
13.081 ns (0 allocations: 0 bytes)
0.370831
julia> @btime smf(StdNormal(), Dual{}(x, one(x))) setup=(x=randn())
14.898 ns (0 allocations: 0 bytes)
Dual{Nothing}(0.724177,0.334163)
How can we get this working properly?
It should also work if we call ForwardDiff.derivative(smf(StdNormal()), x).
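A self-contained sketch of the same pattern on a toy problem may help separate "the overload is not being dispatched" from "it is dispatched but not faster". Everything here is hypothetical and not MeasureBase API: StdLogistic stands in for StdNormal, mysmf for smf, mydensity for densityof, and the logistic CDF is used only because it and its density have closed forms in Base Julia. The curried method uses Base.Fix1 so that ForwardDiff.derivative(mysmf(StdLogistic()), x) still reaches the two-argument Dual method:

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical stand-in for StdNormal (not a MeasureBase type).
struct StdLogistic end

# Generic fallback: the logistic CDF, F(x) = 1 / (1 + exp(-x)).
mysmf(::StdLogistic, x::Real) = 1 / (1 + exp(-x))

# Its density, f(x) = F(x) * (1 - F(x)).
function mydensity(μ::StdLogistic, x::Real)
    F = mysmf(μ, x)
    return F * (1 - F)
end

# Dual-specific rule: evaluate F once on the primal value and reuse it for
# the derivative, instead of letting ForwardDiff trace the generic method.
function mysmf(μ::StdLogistic, x::Dual{T}) where {T}
    F = mysmf(μ, value(x))
    return Dual{T}(F, (F * (1 - F)) * partials(x))
end

# Curried form: a callable that forwards to the two-argument method,
# so Dual inputs still hit the rule above.
mysmf(μ::StdLogistic) = Base.Fix1(mysmf, μ)
```

In the REPL, @which mysmf(StdLogistic(), Dual(0.3, 1.0)) should point at the Dual method, and ForwardDiff.derivative(mysmf(StdLogistic()), 0.3) should match mydensity(StdLogistic(), 0.3). Running the analogous @which check on MeasureBase.smf with a Dual argument is one way to confirm whether the overload above is actually being picked up.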