Add basic Flux integration tests using Enzyme#2968
Add basic Flux integration tests using Enzyme#2968gamila-wisam wants to merge 2 commits intoEnzymeAD:mainfrom
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/test/integration/Flux/runtests.jl b/test/integration/Flux/runtests.jl
index 6158379d..fab13466 100644
--- a/test/integration/Flux/runtests.jl
+++ b/test/integration/Flux/runtests.jl
@@ -33,7 +33,7 @@ function test_enzyme_gradients(model, x, ps, st)
dx_zygote, dps_zygote = compute_zygote_gradient(model, x, ps, st)
@test check_approx(dx, dx_zygote; atol = 1.0f-3, rtol = 1.0f-3)
- @test check_approx(dps, dps_zygote; atol = 1.0f-3, rtol = 1.0f-3)
+ return @test check_approx(dps, dps_zygote; atol = 1.0f-3, rtol = 1.0f-3)
end
# small list of models to test |
|
The CI failure seems to be in SciML tests and is unrelated to these Flux integration tests. The new tests run successfully locally using : |
| (Dense(2, 3), randn(Float32, 2, 4)), | ||
|
|
||
| # small Chain | ||
| (Chain(Dense(2, 4, relu), Dense(4, 2)), randn(Float32, 2, 3)), |
There was a problem hiding this comment.
cc @CarloLucibello were there more models you wanted to test here, I know the flux ones have a bigger list iirc
| # compare Enzyme gradients with Zygote gradients | ||
| function test_enzyme_gradients(model, x, ps, st) | ||
| dx, dps = compute_enzyme_gradient(model, x, ps, st) | ||
| dx_zygote, dps_zygote = compute_zygote_gradient(model, x, ps, st) |
There was a problem hiding this comment.
@gamila-wisam with zygote broken on 1.12, can you have this test against something other than zygote [otherwise we can't compare on 1.12+]
There was a problem hiding this comment.
Sure, what about finite-differences gradients?
There was a problem hiding this comment.
sure [as long as the models aren't so large that the time would be reasonable]
There was a problem hiding this comment.
Yes I've considered that, I will try to ensure that runtime stays reasonable
Summary
This PR adds a small set of integration tests for Flux models using Enzyme.jl, comparing
Enzyme gradients against Zygote gradients.
Details
check_approxto compare Enzyme vs Zygote gradients.Flux.trainable(replacing deprecatedFlux.params).Testing
include("test/integration/Flux/runtests.jl").Motivation
Adds coverage for Flux models, ensuring Enzyme works correctly with common Flux layers.
Related issue
References FluxML/Flux.jl#2644