add ForwardDiff@1 #378

Open
wants to merge 9 commits into
base: main
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
 name = "Bijectors"
 uuid = "76274a88-744f-5084-9051-94815aaf08c4"
-version = "0.15.6"
+version = "0.15.7"

 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -51,7 +51,7 @@ Distributions = "0.25.33"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.9"
 EnzymeCore = "0.8.4"
-ForwardDiff = "0.10"
+ForwardDiff = "0.10, 1.0.1"
 Functors = "0.1, 0.2, 0.3, 0.4, 0.5"
 InverseFunctions = "0.1"
 IrrationalConstants = "0.1, 0.2"
1 change: 0 additions & 1 deletion src/Bijectors.jl
@@ -86,7 +86,6 @@ _eps(::Type{<:Integer}) = eps(Float64)

 function _clamp(x, a, b)
     T = promote_type(typeof(x), typeof(a), typeof(b))
-    ϵ = _eps(T)
     clamped_x = ifelse(x < a, convert(T, a), ifelse(x > b, convert(T, b), x))
     DEBUG && _debug("x = $x, bounds = $((a, b)), clamped_x = $clamped_x")
     return clamped_x
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -38,7 +38,7 @@ Enzyme = "0.13.12"
 EnzymeTestUtils = "0.2.1"
 FillArrays = "1"
 FiniteDifferences = "0.11, 0.12"
-ForwardDiff = "0.10.12"
+ForwardDiff = "0.10, 1.0.1"
 Functors = "0.1, 0.2, 0.3, 0.4, 0.5"
 InverseFunctions = "0.1"
 LazyArrays = "1, 2"
16 changes: 3 additions & 13 deletions test/interface.jl
@@ -145,12 +145,10 @@ end
 @testset "Multivariate" begin
     vector_dists = [
         Dirichlet(2, 3),
-        Dirichlet([1000 * one(Float64), eps(Float64)]),
-        Dirichlet([eps(Float64), 1000 * one(Float64)]),
+        Dirichlet([10.0, 0.1]),
Member

I am not sure these new tests can replace the current ones. Tests like

Dirichlet([1000 * one(Float64), eps(Float64)]),
Dirichlet([eps(Float64), 1000 * one(Float64)]),

are aimed at the numerical stability of very extreme Dirichlet distributions, i.e. ones where one axis has a very tiny probability mass on average.

Member Author @penelopeysm, Apr 10, 2025

It's actually the sample that's the problem. For the sample x = [1.0, 0.0], the transformed variable is y = [36.0436] which is outside of the range for which Float64 is numerically stable.

The issue comes from these lines:

@inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1)))
@inbounds x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), 0, 1)

As y[1] tends to +Inf, z tends to 1, and the expression (z - ϵ) / (one(T) - 2ϵ) tends towards 1.0000000000000002. If that expression is greater than 1, then it gets _clamped to 1, and the derivative is set to 0.
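The floating-point behaviour described above can be checked in isolation. The following is a minimal sketch for the `K = 2`, `Float64` case, not the Bijectors implementation: `logistic` and `logit` are defined locally as stand-ins for the LogExpFunctions versions, `inverse_first_unclamped` and `forward_first` are hypothetical helper names, and the forward map is an assumption obtained by algebraically inverting the two quoted lines.

```julia
# Local stand-ins for LogExpFunctions.logistic / logit, so that the snippet
# needs no packages (an assumption, not the library code).
logistic(t) = 1 / (1 + exp(-t))
logit(p) = log(p / (1 - p))

ϵ = eps(Float64)  # 2^-52

# For K = 2, log(T(K - 1)) == 0, so the two quoted lines reduce to this
# expression (before the final _clamp is applied).
inverse_first_unclamped(y1) = (logistic(y1) - ϵ) / (1 - 2ϵ)

# Hypothetical forward map: the expression above solved for y1.
forward_first(x1) = logit(x1 * (1 - 2ϵ) + ϵ)

y_boundary = forward_first(1.0)            # roughly -log(ϵ), about 36.04
overshoot = inverse_first_unclamped(40.0)  # logistic(40.0) rounds to 1.0

println("y for x[1] = 1.0:        ", y_boundary)
println("unclamped value at y=40: ", overshoot)
println("exceeds 1:               ", overshoot > 1)
```

In this sketch the overshoot is exactly one unit in the last place above 1, which is the value `_clamp` then pulls back to the upper bound.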

The difference between FD 0.10 and FD 1.0 is that the new version sets the derivative to 0 if (z - ϵ) / (one(T) - 2ϵ) is greater than, or equal to, 1. And that in turn means that there is a larger range of y[1] for which the derivative gets clamped. Unfortunately, Float64 36.0436 falls into that category (35.8 would have been fine, or alternatively, BigFloat is ok up until around 175).
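To get a feel for how the strict and non-strict comparisons differ, one can evaluate the unclamped expression at a few values of `y[1]`. This sketch only looks at primal `Float64` values for `K = 2`; it does not call ForwardDiff and makes no claim about how either ForwardDiff version implements the comparison. `logistic` is a local stand-in for `LogExpFunctions.logistic`, and `unclamped` is a hypothetical helper name.

```julia
# Local stand-in for LogExpFunctions.logistic (an assumption, not the library code).
logistic(t) = 1 / (1 + exp(-t))

ϵ = eps(Float64)

# The expression from the quoted lines for K = 2, without the final _clamp.
unclamped(y1) = (logistic(y1) - ϵ) / (1 - 2ϵ)

for y1 in (30.0, 35.8, 36.0436, 40.0)
    v = unclamped(y1)
    # `v > 1` is the strict test; `v >= 1` additionally catches values that
    # land exactly on the boundary.
    println("y1 = ", y1, "  v > 1: ", v > 1, "  v >= 1: ", v >= 1)
end
```

Any `y1` for which the second column is `true` while the first is `false` sits exactly on the boundary, which is the case the comment identifies as newly affected.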

As far as I can tell the fact that it used to work with FD 0.10 might have been a happy accident – I wrote more about this in a comment above, but (to me) it makes sense for FD to set the derivative to 0 at the point (z - ϵ) / (one(T) - 2ϵ) == 1.

Member Author @penelopeysm, Apr 10, 2025

I am still not fully sure how to resolve this though, which is why I haven't really come back to this PR. Obviously changing the sample fixes the tests (and the easiest way to change the sample was to change the distribution from which it was drawn), but I can't tell if there's a workaround in the code that makes it work again for (z - ϵ) / (one(T) - 2ϵ) == 1.0, or more generally for large y.

Member Author

Pinging @devmotion for your thoughts too :)

Member

Hmm, I'm not sure. I always had the feeling that this stick-breaking transform (explained in e.g. the Stan docs) can be numerically problematic. I also always thought that these eps workarounds are unsatisfying. But I'm not sure what exactly would break if they were removed; it might be interesting to see.

+        Dirichlet([0.1, 10.0]),
         MvNormal(randn(10), Diagonal(exp.(randn(10)))),
         MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
-        Dirichlet([1000 * one(Float64), eps(Float64)]),
-        Dirichlet([eps(Float64), 1000 * one(Float64)]),
         MvTDist(1, randn(10), Matrix(Diagonal(exp.(randn(10))))),
         transformed(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
         transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10)))))),
@@ -173,15 +171,7 @@ end
 # similar to what we do in test/transform.jl for Dirichlet
 if dist isa Dirichlet
     b = Bijectors.SimplexBijector()
-    # HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]`
-    # which in turn will lead to differences between `ForwardDiff.jacobian`
-    # and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`.
-    # We therefore test the realizations _on_ the boundary rather if we're near the boundary.
-    x = if any(rand(dist) .> 0.9999)
-        [0.0, 1.0][sortperm(rand(dist))]
-    else
-        rand(dist)
-    end
+    x = rand(dist)
     y = b(x)
     @test b(param(x)) isa TrackedArray
     @test logabsdet(ForwardDiff.jacobian(b, x)[:, 1:(end - 1)])[1] ≈
35 changes: 7 additions & 28 deletions test/transform.jl
@@ -153,32 +153,10 @@ end
         Dirichlet([eps(Float64), 1000 * one(Float64)]),
         MvNormal(randn(10), Diagonal(exp.(randn(10)))),
         MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
-        Dirichlet([1000 * one(Float64), eps(Float64)]),
-        Dirichlet([eps(Float64), 1000 * one(Float64)]),
     ]
     for dist in vector_dists
         if dist isa Dirichlet
             single_sample_tests(dist)
-
-            # This should fail at the minute. Not sure what the correct way to test this is.
-
-            # Workaround for intermittent test failures, result of `logpdf_with_trans(dist, x, true)`
-            # is incorrect for `x == [0.9999999999999998, 0.0]`:
-            x =
-                if params(dist) ==
-                    params(Dirichlet([1000 * one(Float64), eps(Float64)]))
-                    [1.0, 0.0]
-                else
-                    rand(dist)
-                end
-            # `Dirichlet` is no longer mapping between spaces of the same dimensionality,
-            # so the block below no longer works.
-            if !(dist isa Dirichlet)
-                logpdf_turing = logpdf_with_trans(dist, x, true)
-                J = ForwardDiff.jacobian(x -> link(dist, x), x)
-                @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing
-            end
-
             # Issue #12
             stepsize = 1e10
             dim = Bijectors.output_length(bijector(dist), length(dist))
@@ -240,14 +218,15 @@ end
 @testset "uplo: $uplo" for uplo in [:L, :U]
     dist = LKJCholesky(3, 1, uplo)
     single_sample_tests(dist)
-
     x = rand(dist)
-
-    inds = [
-        LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
-        (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
-    ]
     J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
+    # Remove columns of Jacobian that are all zero (i.e. those
+    # corresponding to entries above the diagonal for uplo = :U, or below
+    # the diagonal for uplo = :L). This slightly unscientific approach
+    # based on filter() is needed to handle both ForwardDiff 0.10 and 1 as
+    # the exact indices will differ for the two versions; see
+    # https://github.com/JuliaDiff/ForwardDiff.jl/issues/738.
+    inds = filter(i -> !all(iszero, J[:, i]), 1:size(J, 2))
     J = J[:, inds]
     logpdf_turing = logpdf_with_trans(dist, x, true)
     @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
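The column-filtering step in the last test/transform.jl hunk can be tried on a small hand-written matrix. This is only an illustration with made-up numbers, not output from `ForwardDiff.jacobian`; the names `J`, `inds`, and `J_reduced` are chosen here for the example.

```julia
# Toy 2×3 "Jacobian" whose middle column is identically zero.
J = [1.0 0.0 2.0;
     3.0 0.0 4.0]

# Keep only the columns that contain at least one nonzero entry.
inds = filter(i -> !all(iszero, J[:, i]), 1:size(J, 2))
J_reduced = J[:, inds]

println(inds)             # [1, 3]
println(size(J_reduced))  # (2, 2)
```

A column that is not structurally zero but happens to evaluate to all zeros would be dropped as well, which is presumably why the code comment calls the approach slightly unscientific.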