Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/EnzymeTestUtils/src/test_rewind.jl b/lib/EnzymeTestUtils/src/test_rewind.jl
index 97649cc7..3f8653a2 100644
--- a/lib/EnzymeTestUtils/src/test_rewind.jl
+++ b/lib/EnzymeTestUtils/src/test_rewind.jl
@@ -61,26 +61,26 @@ end
"""
function test_rewind(
- f,
- fwd_ret_activity,
- rvs_ret_activity,
- args...;
- rng::Random.AbstractRNG=Random.default_rng(),
- fdm=FiniteDifferences.central_fdm(5, 1),
- fkwargs::NamedTuple=NamedTuple(),
- rtol::Real=1e-9,
- atol::Real=1e-9,
- testset_name=nothing,
- runtime_activity::Bool=false,
- output_tangent=nothing,
-)
+ f,
+ fwd_ret_activity,
+ rvs_ret_activity,
+ args...;
+ rng::Random.AbstractRNG = Random.default_rng(),
+ fdm = FiniteDifferences.central_fdm(5, 1),
+ fkwargs::NamedTuple = NamedTuple(),
+ rtol::Real = 1.0e-9,
+ atol::Real = 1.0e-9,
+ testset_name = nothing,
+ runtime_activity::Bool = false,
+ output_tangent = nothing,
+ )
# first, test reverse as normal with finite differences
- test_reverse(f, rvs_ret_activity, args...; rng=rng, fdm=fdm, fkwargs=fkwargs, rtol=rtol, atol=atol, testset_name=testset_name, runtime_activity=runtime_activity, output_tangent=output_tangent)
- # now, use the reverse rule to compare with the forward result
+ test_reverse(f, rvs_ret_activity, args...; rng = rng, fdm = fdm, fkwargs = fkwargs, rtol = rtol, atol = atol, testset_name = testset_name, runtime_activity = runtime_activity, output_tangent = output_tangent)
+ # now, use the reverse rule to compare with the forward result
if testset_name === nothing
testset_name = "test_rewind: $f with return activity $fwd_ret_activity on $(_string_activity(args))"
end
- @testset "$testset_name" begin
+ return @testset "$testset_name" begin
# test reverse rule to make sure it works with FD
# run fwd mode first
@@ -109,6 +109,6 @@ function test_rewind(
dy_ad = y_and_dy_ad[1]
# now run this back through reverse mode, using dy_ad from forward mode
# as the output tangent
- test_reverse(f, rvs_ret_activity, args...; rng=rng, fdm=fdm, fkwargs=fkwargs, rtol=rtol, atol=atol, testset_name=testset_name, runtime_activity=runtime_activity, output_tangent=dy_ad)
+ test_reverse(f, rvs_ret_activity, args...; rng = rng, fdm = fdm, fkwargs = fkwargs, rtol = rtol, atol = atol, testset_name = testset_name, runtime_activity = runtime_activity, output_tangent = dy_ad)
end
end |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2645 +/- ##
==========================================
- Coverage 75.14% 75.06% -0.09%
==========================================
Files 57 58 +1
Lines 17951 17974 +23
==========================================
+ Hits 13490 13492 +2
- Misses 4461 4482 +21 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
vchuravy
left a comment
There was a problem hiding this comment.
Needs some tests/usage
Does it make sense to use fdm here? Or should we simply be testing for consistency between forward and reverse?
| gauge is important and the finite-differences approach generates tangents in an arbitrary | ||
| gauge. In effect, this plays the derivatives _forward_, then in _reverse_, "rewinding" the |
There was a problem hiding this comment.
I must admit this is the first time I encounter the term gauge.
There was a problem hiding this comment.
Ah sorry it's physicist brain
There was a problem hiding this comment.
https://en.wikipedia.org/wiki/Gauge_fixing if you're interested. I'll try to rephrase
|
I used |
|
In general I think forward-mode through an eigendecomposition will produce tangents "in the gauge" (I forget formally what the property is, something about the specific tangent space). But there's actually no guarantees in reverse-mode what kind of (co)-tangent you will have on the outputs. If downstream code interacts with the eigenvectors in a gauge-dependent manner, then your cotangents may not be "in the gauge" at all. So if you're implementing a rule, ideally you'd have one that works for all co-tangents; if that's not possible, then you'd need to require and document that downstream code interacts with the eigenvectors only in a gauge-invariant way. In that case, then I think this new testing method would work, but this seems highly specialized to this particular problem. If we think of the tangents as being constrained to a specific manifold, I can't actually remember off the top of my head if the manifold containing the co-tangent should always be a superset of this specific manifold. It would be nice to have a second motivating application for this before adding another function to the API. In terms of design, a few notes:
|
|
Thanks for the comment! I agree this is pretty specialized -- it's just nearly all the functions I need to do AD with have this annoying gauge property. The problem isn't forward mode itself, it's the FD check (as FD generates tangents in a totally arbitrary gauge). |
|
I think for now this can be closed, as after some thought our forward sensitivities are really bad to test against FD and I wasn't testing what I thought I was here... |
This allows you to:
This can be helpful when the tangents need a particular choice of gauge (e.g. the tangents for eigenvectors) -- FD will generate tangents with an arbitrary gauge.