add ForwardDiff@1 #378
@@ -145,12 +145,10 @@ end
@testset "Multivariate" begin
    vector_dists = [
        Dirichlet(2, 3),
        Dirichlet([1000 * one(Float64), eps(Float64)]),
        Dirichlet([eps(Float64), 1000 * one(Float64)]),
        Dirichlet([10.0, 0.1]),
I am not sure these new tests could replace the current ones. Tests like
Dirichlet([1000 * one(Float64), eps(Float64)]),
Dirichlet([eps(Float64), 1000 * one(Float64)]),
are aimed at the numerical stability of very extreme examples of Dirichlet distributions, i.e. one axis has a very tiny probability mass on average.

It's actually the sample that's the problem. For the sample [...] The issue comes from these lines: Bijectors.jl/src/bijectors/simplex.jl, lines 89 to 90 in d8d781b.
As [...] The difference between FD (ForwardDiff) 0.10 and FD 1.0 is that the new version sets the derivative to 0 if [...] As far as I can tell, the fact that it used to work with FD 0.10 might have been a happy accident. I wrote more about this in a comment above, but (to me) it makes sense for FD to set the derivative to 0 at the point [...]

I am still not fully sure how to resolve this though, which is why I haven't really come back to this PR. Obviously changing the sample fixes the tests (and the easiest way to change the sample was to change the distribution from which it was drawn), but I can't tell if there's a workaround in the code that makes it work again for [...]

Pinging @devmotion for your thoughts too :)

Hmm, I'm not sure. I always had the feeling that this stick-breaking transform (explained in e.g. the Stan docs) can be numerically problematic. I also always thought that these eps workarounds are unsatisfying. But I'm not sure what exactly would break if they were removed; it might be interesting to see.

Finally came back to this; I've been holding off on merging because I haven't had time to properly look into this. Removing eps actually makes the original tests pass! ... But for all the wrong reasons: Bijectors.jl/src/bijectors/simplex.jl, lines 28 to 44 in fbaf783.
For the sample [...] Right now it's just the ForwardDiff jacobian used to verify the manual logabsdetjac that returns [...] I tried 2*eps and unfortunately that made ForwardDiff even less happy. So I think I'm inclined to merge (and @yebai really wants this PR in).
        Dirichlet([0.1, 10.0]),
        MvNormal(randn(10), Diagonal(exp.(randn(10)))),
        MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
        Dirichlet([1000 * one(Float64), eps(Float64)]),
        Dirichlet([eps(Float64), 1000 * one(Float64)]),
penelopeysm marked this conversation as resolved.
        MvTDist(1, randn(10), Matrix(Diagonal(exp.(randn(10))))),
        transformed(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
        transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10)))))),
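The review thread on the extreme Dirichlet parameters is about samples that land on the simplex boundary, such as `[1.0, 0.0]`. The following is a minimal sketch of why such a sample is awkward for a stick-breaking style map, and what an eps guard does to it. It assumes a simplified two-component transform `y = logit(x[1])`; the names `unguarded` and `guarded`, and the particular shrink formula, are hypothetical and are not taken from `Bijectors.jl/src/bijectors/simplex.jl`.

```julia
# Simplified two-component stick-breaking map (an assumption for illustration;
# this is not the SimplexBijector implementation).
logit(p) = log(p) - log1p(-p)   # log(p / (1 - p))

# No guard: the map is applied to the first coordinate directly.
unguarded(x) = logit(x[1])

# Hypothetical eps guard: shrink x[1] from [0, 1] into [ϵ, 1 - ϵ] before logit.
function guarded(x; ϵ=eps(eltype(x)))
    z = x[1] * (1 - 2ϵ) + ϵ
    return logit(z)
end

x_boundary = [1.0, 0.0]   # a sample sitting on the boundary of the simplex
x_interior = [0.9, 0.1]   # a sample well inside the simplex

@show unguarded(x_boundary)           # infinite: logit(1.0) diverges
@show isfinite(guarded(x_boundary))   # the guard keeps the value finite
@show unguarded(x_interior) ≈ guarded(x_interior)
```

Away from the boundary the guard is numerically invisible, which is why it only shows up in tests that draw samples from very concentrated distributions.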
@@ -173,15 +171,7 @@ end
# similar to what we do in test/transform.jl for Dirichlet
if dist isa Dirichlet
    b = Bijectors.SimplexBijector()
    # HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]`
    # which in turn will lead to differences between `ForwardDiff.jacobian`
    # and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`.
    # We therefore test the realizations _on_ the boundary rather if we're near the boundary.
    x = if any(rand(dist) .> 0.9999)
        [0.0, 1.0][sortperm(rand(dist))]
    else
        rand(dist)
    end
    x = rand(dist)
    y = b(x)
    @test b(param(x)) isa TrackedArray
    @test logabsdet(ForwardDiff.jacobian(b, x)[:, 1:(end - 1)])[1] ≈
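The last test line in this hunk compares the log-abs-determinant of a differentiated Jacobian with a hand-written `logabsdetjac`. A minimal sketch of that kind of check is given below, under loud assumptions: it uses the simplified two-component map `y = logit(x1)` rather than `SimplexBijector`, and a central finite difference stands in for `ForwardDiff.jacobian` so the snippet needs no packages. The names `transform`, `manual_logabsdetjac` and `numeric_logabsdetjac` are hypothetical.

```julia
logit(p) = log(p) - log1p(-p)

# Simplified K = 2 transform: only the first coordinate is free.
transform(x1) = logit(x1)

# Hand-written log|dy/dx1|, using dy/dx1 = 1 / (x1 * (1 - x1)).
manual_logabsdetjac(x1) = -(log(x1) + log1p(-x1))

# Central finite difference, standing in for an AD Jacobian in this sketch.
numeric_logabsdetjac(x1; h=1e-6) =
    log(abs((transform(x1 + h) - transform(x1 - h)) / (2h)))

# At interior points the two computations agree closely.
for x1 in (0.3, 0.9, 0.99)
    @assert isapprox(manual_logabsdetjac(x1), numeric_logabsdetjac(x1); atol=1e-5)
end

# On the boundary the hand-written value is no longer finite.
@show manual_logabsdetjac(1.0)
```

On the boundary the hand-written expression diverges, so whether such a comparison passes there depends entirely on how each side treats the boundary value, which is the situation the thread describes.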