Skip to content

Commit 6f0d383

Browse files
acertaindevmotionpenelopeysmyebai
authored
Implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability (#325)
* implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability Example of previous badness: logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf (is now -79.30685281944005) * promote at start, try to fix test * fix? * fix test * back to abs formula * Bump minor version * Actually bump minor version * Add test for Uniform(-1, 1), y=80 * Tweak test to be more discerning * Test forward logabsdetjac too * Update Project.toml Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: David Widmann <[email protected]> Co-authored-by: Penelope Yong <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent 2b7498a commit 6f0d383

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

src/bijectors/truncated.jl

+22
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,28 @@ end
6868

6969
with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)
7070

71+
function truncated_inv_logabsdetjac(y, a, b)
72+
y, a, b = promote(y, a, b)
73+
lowerbounded, upperbounded = isfinite(a), isfinite(b)
74+
if lowerbounded && upperbounded
75+
abs_y = abs(y)
76+
return log(b - a) - abs_y - 2 * LogExpFunctions.log1pexp(-abs_y)
77+
elseif lowerbounded || upperbounded
78+
return y
79+
else
80+
return zero(y)
81+
end
82+
end
83+
84+
function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y)
85+
a, b = ib.orig.lb, ib.orig.ub
86+
return sum(truncated_inv_logabsdetjac.(y, a, b))
87+
end
88+
89+
function with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y)
90+
return transform(ib, y), logabsdetjac(ib, y)
91+
end
92+
7193
# It's only monotonically decreasing if it's only upper-bounded.
7294
# In the multivariate case, we can only say something reasonable if entries are monotonic.
7395
function is_monotonically_increasing(b::TruncatedBijector)

test/bijectors/ordered.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878

7979
@testset "correctness" begin
8080
num_samples = 10_000
81-
num_adapts = 1_000
81+
num_adapts = 5_000
8282
@testset "k = $k" for k in [2, 3, 5]
8383
@testset "$(typeof(dist))" for dist in [
8484
# Unconstrained

test/interface.jl

+15
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,21 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs))
9696
logabsdetjac(inverse(b), y) atol = 1e-6
9797
end
9898
end
99+
100+
@testset "logabsdetjac numerical stability: Bijectors.jl#325" begin
101+
d = Uniform(-1, 1)
102+
b = bijector(d)
103+
y = 80
104+
# x needs higher precision to be calculated correctly, otherwise
105+
# logpdf_with_trans returns -Inf
106+
d_big = Uniform(big(-1.0), big(1.0))
107+
b_big = bijector(d_big)
108+
x_big = inverse(b_big)(big(y))
109+
@test logpdf(d_big, x_big) + logabsdetjacinv(b, y)
110+
logpdf_with_trans(d_big, x_big, true) atol = 1e-14
111+
@test logpdf(d_big, x_big) - logabsdetjac(b, x_big)
112+
logpdf_with_trans(d_big, x_big, true) atol = 1e-14
113+
end
99114
end
100115

101116
@testset "Truncated" begin

0 commit comments

Comments
 (0)