Making smf ForwardDiff-friendly #242

Open
@cscherrer

We should be able to tell ForwardDiff that the derivative of smf is the density (densityof, the exponential of logdensityof). Here are timings for smf and logdensityof separately for a StdNormal:

julia> using MeasureBase, ForwardDiff, BenchmarkTools

julia> using ForwardDiff: Dual

julia> @btime smf(StdNormal(), x) setup=(x=randn())
  5.660 ns (0 allocations: 0 bytes)
0.373722

julia> @btime logdensityof(StdNormal(), x) setup=(x=randn())
  2.294 ns (0 allocations: 0 bytes)
-0.935558

If we compute both, there's no extra overhead, and we even save a little:

julia> @btime (smf(StdNormal(), x), logdensityof(StdNormal(), x)) setup=(x=randn())
  7.151 ns (0 allocations: 0 bytes)
(0.206069, -1.25525)

It would be good to "teach" ForwardDiff to do this, because it's currently much slower:

julia> @btime ForwardDiff.derivative(x -> smf(StdNormal(), x), x)  setup=(x=randn())
  13.181 ns (0 allocations: 0 bytes)
0.122241

julia> @btime smf(StdNormal(), Dual{}(x, one(x)))  setup=(x=randn())
  14.506 ns (0 allocations: 0 bytes)
Dual{Nothing}(0.572796,0.392282)

I'd think this ought to do it:

julia> function MeasureBase.smf(μ::StdNormal, x::Dual{TAG}) where TAG
           val = ForwardDiff.value(x)
           Δ = ForwardDiff.partials(x)
           Dual{TAG}(smf(μ, val), Δ * densityof(μ, val))
       end

But it doesn't:

julia> @btime ForwardDiff.derivative(x -> smf(StdNormal(), x), x)  setup=(x=randn())
  13.081 ns (0 allocations: 0 bytes)
0.370831

julia> @btime smf(StdNormal(), Dual{}(x, one(x)))  setup=(x=randn())
  14.898 ns (0 allocations: 0 bytes)
Dual{Nothing}(0.724177,0.334163)

How can we get this working properly?

It should also work if we call ForwardDiff.derivative(smf(StdNormal()), x).
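
For reference, here is a minimal sketch of the same dual-number rule outside MeasureBase, which may help separate "is the override correct" from "is the override fast". The names std_cdf and std_pdf are hypothetical stand-ins for smf and densityof on a standard normal, not MeasureBase's implementation, and it assumes SpecialFunctions is installed for erfc.

```julia
using ForwardDiff
using ForwardDiff: Dual
using SpecialFunctions: erfc   # assumption: available in the environment

# Hypothetical stand-ins for densityof / smf on a standard normal.
std_pdf(x::Real) = exp(-x^2 / 2) / sqrt(2π)
std_cdf(x::Real) = erfc(-x / sqrt(2)) / 2

# Dual-number rule: the value goes through the primal function and the
# partials are scaled by the known derivative, since cdf' = pdf.
function std_cdf(x::Dual{TAG}) where {TAG}
    val = ForwardDiff.value(x)
    Δ = ForwardDiff.partials(x)
    return Dual{TAG}(std_cdf(val), Δ * std_pdf(val))
end

# The derivative computed through the Dual method should match the density.
x = 0.3
@assert ForwardDiff.derivative(std_cdf, x) ≈ std_pdf(x)
```

This only checks that the rule is picked up by dispatch and gives the right derivative; it says nothing about why the timings above don't improve.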
