For convenience, the old version of `expected_loglikelihood` (Gauss-Hermite quadrature method) looked like this (GPLikelihoods.jl/src/expectations.jl, lines 83 to 109 in e9b7da9):

```julia
# Compute the expected_loglikelihood over a collection of observations and marginal distributions
function expected_loglikelihood(
    gh::GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    # Compute the expectation via Gauss-Hermite quadrature
    # using a reparameterisation by change of variable
    # (see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)
    return sum(Broadcast.instantiate(
        Broadcast.broadcasted(y, q_f) do yᵢ, q_fᵢ  # Loop over every pair
            # of marginal distribution q(fᵢ) and observation yᵢ
            expected_loglikelihood(gh, lik, q_fᵢ, yᵢ)
        end,
    ))
end

# Compute the expected_loglikelihood for one observation and a marginal distribution
function expected_loglikelihood(gh::GaussHermiteExpectation, lik, q_f::Normal, y)
    μ = mean(q_f)
    σ̃ = sqrt2 * std(q_f)
    return invsqrtπ * sum(Broadcast.instantiate(
        Broadcast.broadcasted(gh.xs, gh.ws) do x, w  # Loop over every
            # pair of Gauss-Hermite point x with weight w
            f = σ̃ * x + μ
            loglikelihood(lik(f), y) * w
        end,
    ))
end
```
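The change of variable mentioned in the code comments is $f = \tilde\sigma x + \mu$ with $\tilde\sigma = \sqrt{2}\,\sigma$ (`σ̃` in the code), so that $df = \tilde\sigma\,dx$ and the Gaussian density reduces to the Gauss-Hermite weight $e^{-x^2}$ times the constant $1/\sqrt{\pi}$ (`invsqrtπ`). Here $x_j$ and $w_j$ denote the nodes and weights stored in `gh.xs` and `gh.ws`:

```math
\mathbb{E}_{\mathcal{N}(f;\,\mu,\sigma^2)}\big[\log p(y \mid f)\big]
= \int \log p(y \mid f)\, \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{(f-\mu)^2}{2\sigma^2}}\, df
= \frac{1}{\sqrt{\pi}} \int e^{-x^2} \log p(y \mid \tilde\sigma x + \mu)\, dx
\approx \frac{1}{\sqrt{\pi}} \sum_{j=1}^{n} w_j \log p(y \mid \tilde\sigma x_j + \mu),
\qquad \tilde\sigma = \sqrt{2}\,\sigma .
```

The vector method then just sums this per-observation estimate over all pairs $(q(f_i), y_i)$.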
#90 introduced a work-around/hack for two (possibly interrelated) issues of that implementation:
The type instability of the broadcast may be related to JuliaLang/julia#45748 (see also JuliaGaussianProcesses/KernelFunctions.jl#458, which documents a strange behavior in which inference depends on previous evaluations; I also observed this with the `Broadcast` construction but could not resolve it).
I therefore tried to get rid of the `Broadcast` entirely, hoping that type stability would improve performance. First I tried a custom pairwise-summation implementation for a function of two arguments (i.e. I implemented `sum(f, X, Y)`, which is equivalent to `mapreduce(f, +, X, Y)`, based on the implementation of `mapreduce(f, +, X)` in Base; see #77 for the reason behind this). That implementation can be found here.
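To illustrate the idea of a two-argument pairwise sum, here is a minimal sketch written from scratch in the spirit of Base's pairwise `mapreduce`; it is not the linked implementation, and the names `pairwise_sum` and `_pairwise_sum` as well as the default block size of 1024 are assumptions for illustration only:

```julia
# Hypothetical sketch: pairwise summation of f(X[i], Y[i]) over two vectors,
# written in the spirit of Base's pairwise mapreduce (not Base's actual code).
function pairwise_sum(f, X::AbstractVector, Y::AbstractVector; blocksize::Int=1024)
    axes(X) == axes(Y) || throw(DimensionMismatch("X and Y must have the same axes"))
    isempty(X) && throw(ArgumentError("pairwise_sum does not handle empty inputs"))
    return _pairwise_sum(f, X, Y, firstindex(X), lastindex(X), blocksize)
end

function _pairwise_sum(f, X, Y, lo::Int, hi::Int, blocksize::Int)
    if hi - lo < blocksize
        # Small block: plain sequential loop, seeded with the first term
        acc = f(X[lo], Y[lo])
        for i in (lo + 1):hi
            acc += f(X[i], Y[i])
        end
        return acc
    end
    # Large block: split the index range in half and add the two partial sums,
    # which reduces floating-point error growth compared with one sequential loop
    mid = (lo + hi) >>> 1
    return _pairwise_sum(f, X, Y, lo, mid, blocksize) +
           _pairwise_sum(f, X, Y, mid + 1, hi, blocksize)
end
```

For example, `pairwise_sum(*, xs, ys)` computes the same value as `mapreduce(*, +, xs, ys)` without first materialising the vector of products. Because the accumulator is seeded with a value returned by `f` itself, its type follows the return type of `f`, which tends to make such a loop straightforward for the compiler to infer.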
Although that implementation made the function type-stable, it was still slow. For this reason, in #90 I chose an implementation that allocates an explicit array, which I expected to be more Zygote-friendly. The large improvements over the old version in the benchmarks confirmed this intuition (see #90 (comment)).
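A minimal sketch of what an "explicit array" variant can look like for a single observation: instead of summing a lazy `Broadcasted` object, the evaluation points and the weighted terms are materialised as ordinary `Vector`s and then summed. This is not the code from #90; `expected_loglik_explicit`, `gauss_loglik` and the hard-coded 3-point rule `GH3_XS`/`GH3_WS` are hypothetical names used only for illustration, and `loglik(f, y)` stands in for `loglikelihood(lik(f), y)`:

```julia
# Sketch of an "explicit array" variant of the Gauss-Hermite expectation for one
# observation y with q(f) = N(μ, σ²). `loglik(f, y)` is a hypothetical stand-in
# for `loglikelihood(lik(f), y)`; xs/ws are Gauss-Hermite nodes and weights.
function expected_loglik_explicit(loglik, xs::AbstractVector, ws::AbstractVector, μ::Real, σ::Real, y)
    σ̃ = sqrt(2) * σ
    fs = σ̃ .* xs .+ μ                   # materialised array of evaluation points
    vals = loglik.(fs, Ref(y)) .* ws    # materialised array of weighted terms
    return sum(vals) / sqrt(π)          # plain sum over an ordinary Vector
end

# 3-point Gauss-Hermite rule (exact for polynomials in x up to degree 5)
const GH3_XS = [-sqrt(3 / 2), 0.0, sqrt(3 / 2)]
const GH3_WS = [sqrt(π) / 6, 2 * sqrt(π) / 3, sqrt(π) / 6]

# Gaussian log-likelihood log N(y | f, 1), used only for illustration
gauss_loglik(f, y) = -0.5 * log(2π) - 0.5 * (y - f)^2
```

For `gauss_loglik` the expectation has the closed form $-\tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}\big((y - \mu)^2 + \sigma^2\big)$, and since the integrand is quadratic in $x$ the 3-point rule reproduces it up to rounding. The trade-off is the one described here: two extra allocations per call in the forward pass, in exchange for array operations that reverse-mode AD tools such as Zygote generally handle well.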
However, this solution is not very clean: it pays for the performance improvement with additional allocations in the forward pass. It is also unclear whether this implementation is equally beneficial with other AD backends, or whether it generalizes to them at all. A potentially better approach would be to define an `rrule` for the broadcasted sum, as suggested in #90 (comment) (although even then it would be nice for the function to be type-stable).
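For reference, a hand-derived sketch (not taken from #90) of what such an `rrule` would have to compute for a single observation. Writing $\ell(f) = \log p(y \mid f)$, $\tilde\sigma = \sqrt{2}\,\sigma$ and $f_j = \tilde\sigma x_j + \mu$ for the Gauss-Hermite nodes $x_j$ with weights $w_j$, and assuming gradients are only needed with respect to $\mu$ and $\sigma$:

```math
E(\mu, \sigma) = \frac{1}{\sqrt{\pi}} \sum_{j=1}^{n} w_j\, \ell(f_j),
\qquad
\frac{\partial E}{\partial \mu} = \frac{1}{\sqrt{\pi}} \sum_{j=1}^{n} w_j\, \ell'(f_j),
\qquad
\frac{\partial E}{\partial \sigma} = \frac{\sqrt{2}}{\sqrt{\pi}} \sum_{j=1}^{n} w_j\, x_j\, \ell'(f_j).
```

Given an incoming cotangent $\bar E$, the pullback would return $\bar E\,\partial E/\partial\mu$ and $\bar E\,\partial E/\partial\sigma$. Both sums reuse the forward-pass points $f_j$, so in principle they could be accumulated in a single loop without materialising intermediate arrays; likelihoods with their own parameters would need additional terms.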