Skip to content

Add basic Flux integration tests using Enzyme#2968

Open
gamila-wisam wants to merge 2 commits intoEnzymeAD:mainfrom
gamila-wisam:flux-integration-test
Open

Add basic Flux integration tests using Enzyme#2968
gamila-wisam wants to merge 2 commits intoEnzymeAD:mainfrom
gamila-wisam:flux-integration-test

Conversation

@gamila-wisam
Copy link

@gamila-wisam gamila-wisam commented Feb 8, 2026

Summary

This PR adds a small set of integration tests for Flux models using Enzyme.jl, comparing
Enzyme gradients against Zygote gradients.

Details

  • Includes 3 simple models for initial testing:
    1. Dense layer
    2. Small Chain
    3. Small Conv layer
  • Uses check_approx to compare Enzyme vs Zygote gradients.
  • Flux trainable parameters are collected with Flux.trainable (replacing deprecated Flux.params).
  • Designed to be lightweight for fast iteration; additional models can be added later.

Testing

  • The tests run successfully locally with include("test/integration/Flux/runtests.jl").
  • Full CI will run the standard Enzyme test suite, including these new tests.

Motivation

Adds coverage for Flux models, ensuring Enzyme works correctly with common Flux layers.

Related issue

References FluxML/Flux.jl#2644

@github-actions
Copy link
Contributor

github-actions bot commented Feb 8, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/test/integration/Flux/runtests.jl b/test/integration/Flux/runtests.jl
index 6158379d..fab13466 100644
--- a/test/integration/Flux/runtests.jl
+++ b/test/integration/Flux/runtests.jl
@@ -33,7 +33,7 @@ function test_enzyme_gradients(model, x, ps, st)
     dx_zygote, dps_zygote = compute_zygote_gradient(model, x, ps, st)
 
     @test check_approx(dx, dx_zygote; atol = 1.0f-3, rtol = 1.0f-3)
-    @test check_approx(dps, dps_zygote; atol = 1.0f-3, rtol = 1.0f-3)
+    return @test check_approx(dps, dps_zygote; atol = 1.0f-3, rtol = 1.0f-3)
 end
 
 # small list of models to test

@gamila-wisam gamila-wisam reopened this Feb 8, 2026
@gamila-wisam
Copy link
Author

The CI failure seems to be in SciML tests and is unrelated to these Flux integration tests. The new tests run successfully locally using :
include("test/integration/Flux/runtests.jl")

(Dense(2, 3), randn(Float32, 2, 4)),

# small Chain
(Chain(Dense(2, 4, relu), Dense(4, 2)), randn(Float32, 2, 3)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @CarloLucibello were there more models you wanted to test here, I know the flux ones have a bigger list iirc

# compare Enzyme gradients with Zygote gradients
function test_enzyme_gradients(model, x, ps, st)
dx, dps = compute_enzyme_gradient(model, x, ps, st)
dx_zygote, dps_zygote = compute_zygote_gradient(model, x, ps, st)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gamila-wisam with zygote broken on 1.12, can you have this test against something other than zygote [otherwise we can't compare on 1.12+]

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, what about finite-differences gradients?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure [as long as the models aren't so large that the time would be reasonable]

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I've considered that, I will try to ensure that runtime stays reasonable

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants