Skip to content

Commit ecb5e41

Browse files
committed
feat: update LuxTestUtils to support 1.12
ci: test on 1.12 fix: ref negate fix: use finite differences to test ground truth
1 parent 5073e55 commit ecb5e41

File tree

11 files changed

+76
-36
lines changed

11 files changed

+76
-36
lines changed

.buildkite/testing_luxtestutils.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ steps:
2222
matrix:
2323
setup:
2424
julia:
25-
- "1.11"
25+
- "1.12"
2626

2727
env:
2828
SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w=="

.github/workflows/CI_LuxTestUtils.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ jobs:
1919
test:
2020
uses: ./.github/workflows/CommonCI.yml
2121
with:
22-
julia_version: "1.11"
22+
julia_version: "1.12"
2323
project: "lib/LuxTestUtils"

lib/LuxLib/test/common_ops/dense_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ end
138138
end
139139

140140
@testitem "Enzyme.Forward patch: dense" tags = [:common] setup = [SharedTestSetup] skip =
141-
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
141+
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED[]) begin
142142
using LuxLib, Random, ForwardDiff, Enzyme
143143

144144
x = rand(Float32, 2, 2)
@@ -149,7 +149,7 @@ end
149149
end
150150

151151
@testitem "Enzyme rules for fused dense" tags = [:common] setup = [SharedTestSetup] skip =
152-
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
152+
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED[]) begin
153153
using LuxLib, NNlib, Zygote, Enzyme
154154

155155
# These are mostly for testing the CUDA rules since we don't enable the CUDA tests

lib/LuxTestUtils/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "LuxTestUtils"
22
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
3-
version = "2.1.0"
3+
version = "2.2.0"
44
authors = ["Avik Pal <avikpal@mit.edu>"]
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
910
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1011
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
@@ -25,6 +26,7 @@ MLDataDevices = {path = "../MLDataDevices"}
2526

2627
[compat]
2728
ADTypes = "1.10"
29+
Adapt = "4.1"
2830
ArrayInterface = "7.17.1"
2931
ChainRulesCore = "1.25.1"
3032
ComponentArrays = "0.15.22"
@@ -33,7 +35,7 @@ Enzyme = "0.13.81"
3335
FiniteDiff = "2.23.1"
3436
ForwardDiff = "0.10.36, 1"
3537
Functors = "0.5"
36-
JET = "0.9.6, 0.10"
38+
JET = "0.9.6, 0.10, 0.11"
3739
MLDataDevices = "1.17"
3840
Optimisers = "0.3.4, 0.4"
3941
Pkg = "1.10"

lib/LuxTestUtils/src/LuxTestUtils.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module LuxTestUtils
22

3+
using Adapt: Adapt, adapt
34
using ArrayInterface: ArrayInterface
45
using ComponentArrays: ComponentArray, getdata, getaxes
56
using DispatchDoctor: allow_unstable
6-
using Functors: Functors
7+
using Functors: Functors, fmap
78
using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice
89
using Optimisers: Optimisers
910
using Pkg: PackageSpec
@@ -33,25 +34,41 @@ using Zygote: Zygote
3334
const CRC = ChainRulesCore
3435
const FD = FiniteDiff
3536

37+
const JET_TESTING_ENABLED = Ref{Bool}(false)
38+
const ENZYME_TESTING_ENABLED = Ref{Bool}(false)
39+
const ZYGOTE_TESTING_ENABLED = Ref{Bool}(false)
40+
3641
# Check if JET will work
3742
try
3843
using JET: JET, JETTestFailure, get_reports, report_call, report_opt
39-
# XXX: In 1.11, JET leads to stack overflows
40-
global JET_TESTING_ENABLED = v"1.10-" VERSION < v"1.11-"
44+
JET_TESTING_ENABLED[] = true
4145
catch err
4246
@error "`JET.jl` did not successfully precompile on $(VERSION). All `@jet` tests will \
4347
be skipped." maxlog = 1 err = err
44-
global JET_TESTING_ENABLED = false
48+
JET_TESTING_ENABLED[] = false
4549
end
4650

47-
# Check if Enzyme will work
48-
try
49-
using Enzyme: Enzyme
50-
__ftest(x) = x
51-
Enzyme.autodiff(Enzyme.Reverse, __ftest, Enzyme.Active, Enzyme.Active(2.0))
52-
global ENZYME_TESTING_ENABLED = Sys.islinux()
53-
catch err
54-
global ENZYME_TESTING_ENABLED = false
51+
# Check if Enzyme will work (only on non-prerelease versions)
52+
@static if isempty(VERSION.prerelease)
53+
try
54+
using Enzyme: Enzyme
55+
Enzyme.gradient(Enzyme.Reverse, Base.Fix1(sum, abs2), ones(Float32, 10))
56+
ENZYME_TESTING_ENABLED[] = Sys.islinux()
57+
catch err
58+
@error "`Enzyme.jl` did not successfully differentiate a simple function or \
59+
failed to load on $(VERSION). All Enzyme tests will be \
60+
skipped." maxlog = 1 err = err
61+
ENZYME_TESTING_ENABLED[] = false
62+
end
63+
end
64+
65+
function __init__()
66+
ZYGOTE_TESTING_ENABLED[] = VERSION < v"1.12-"
67+
68+
if JET_TESTING_ENABLED[]
69+
# JET doesn't work nicely on 1.11
70+
JET_TESTING_ENABLED[] = VERSION < v"1.11-" || VERSION v"1.12-"
71+
end
5572
end
5673

5774
include("package_install.jl")

lib/LuxTestUtils/src/autodiff.jl

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,39 @@
1+
# taken from implementation in Lux.jl
2+
struct LuxEltypeAdaptor{T} end
3+
4+
(l::LuxEltypeAdaptor)(x) = fmap(adapt(l), x)
5+
function (l::LuxEltypeAdaptor)(x::AbstractArray{T}) where {T}
6+
return isbitstype(T) ? adapt(l, x) : map(adapt(l), x)
7+
end
8+
9+
function Adapt.adapt_storage(
10+
::LuxEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}
11+
) where {T<:AbstractFloat}
12+
return convert(AbstractArray{T}, x)
13+
end
14+
15+
function Adapt.adapt_storage(
16+
::LuxEltypeAdaptor{T}, x::AbstractArray{<:Complex{<:AbstractFloat}}
17+
) where {T<:AbstractFloat}
18+
return convert(AbstractArray{Complex{T}}, x)
19+
end
20+
121
struct Constant{T}
222
val::T
323
end
424

5-
# Zygote.jl on CPU
25+
# FiniteDiff.jl on CPU
626
function ground_truth_gradient(f, args...)
727
cdev = cpu_device()
28+
f64 = LuxEltypeAdaptor{Float64}()
829
f_cpu = try
9-
cdev(f)
30+
f64(cdev(f))
1031
catch err
1132
@error "Encountered error while moving $(f) to CPU. Skipping movement... This can \
1233
be fixed by defining overloads using ConstructionBase.jl" err
1334
f
1435
end
15-
return gradient(f_cpu, AutoZygote(), map(cdev, args)...)
36+
return gradient(f_cpu, AutoFiniteDiff(), map(f64, map(cdev, args))...)
1637
end
1738

1839
# Zygote.jl
@@ -31,8 +52,9 @@ function gradient(f::F, ::AutoEnzyme{Nothing}, args...) where {F}
3152
end
3253

3354
function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F}
34-
!ENZYME_TESTING_ENABLED &&
55+
if !ENZYME_TESTING_ENABLED[]
3556
return ntuple(Returns(GradientComputationSkipped()), length(args))
57+
end
3658

3759
args_activity = map(args) do x
3860
needs_gradient(x) && return Enzyme.Duplicated(x, Enzyme.make_zero(x))
@@ -79,13 +101,13 @@ end
79101
"""
80102
test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)
81103
82-
Test the gradients of `f` with respect to `args` using the specified backends.
104+
Test the gradients of `f` with respect to `args` using the specified backends. The ground
105+
truth gradients are computed using FiniteDiff.jl on CPU.
83106
84107
| Backend | ADType | CPU | GPU | Notes |
85108
|:-------------- |:------------------- |:--- |:--- |:----------------- |
86109
| Zygote.jl | `AutoZygote()` | ✔ | ✔ | |
87110
| ForwardDiff.jl | `AutoForwardDiff()` | ✔ | ✖ | `len ≤ 32` |
88-
| FiniteDiff.jl | `AutoFiniteDiff()` | ✔ | ✖ | `len ≤ 32` |
89111
| Enzyme.jl | `AutoEnzyme()` | ✔ | ✖ | Only Reverse Mode |
90112
91113
## Arguments
@@ -156,9 +178,8 @@ function test_gradients(
156178
push!(backends, AutoZygote())
157179
if !on_gpu
158180
total_length 32 && push!(backends, AutoForwardDiff())
159-
total_length 32 && push!(backends, AutoFiniteDiff())
160181
# TODO: Move Enzyme out of here once it supports GPUs
161-
if enable_enzyme_reverse_mode || ENZYME_TESTING_ENABLED
182+
if enable_enzyme_reverse_mode || ENZYME_TESTING_ENABLED[]
162183
mode = if enzyme_set_runtime_activity
163184
Enzyme.set_runtime_activity(Enzyme.Reverse)
164185
else

lib/LuxTestUtils/src/jet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Test Broken
5353
```
5454
"""
5555
macro jet(expr, args...)
56-
!JET_TESTING_ENABLED && return :()
56+
!JET_TESTING_ENABLED[] && return :()
5757

5858
all_args, call_extras, opt_extras = [], [], []
5959
target_modules_set = false

test/enzyme_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ end
127127
ps, st = dev(Lux.setup(rng, model))
128128
x = aType(x)
129129

130-
if LuxTestUtils.ENZYME_TESTING_ENABLED
130+
if LuxTestUtils.ENZYME_TESTING_ENABLED[]
131131
test_enzyme_gradients(model, x, ps, st)
132132
else
133133
@test_broken false
@@ -154,7 +154,7 @@ end
154154
st = dev(st)
155155
x = aType(x)
156156

157-
if LuxTestUtils.ENZYME_TESTING_ENABLED
157+
if LuxTestUtils.ENZYME_TESTING_ENABLED[]
158158
test_enzyme_gradients(model, x, ps, st)
159159
else
160160
@test_broken false

test/helpers/loss_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
∂x1 = ForwardDiff.derivative(Base.Fix2(LuxOps.xlogy, 3.0), 2.0)
2323
∂y1 = ForwardDiff.derivative(Base.Fix1(LuxOps.xlogy, 2.0), 3.0)
2424
∂x2, ∂y2 = Zygote.gradient(LuxOps.xlogy, 2.0, 3.0)
25-
if LuxTestUtils.ENZYME_TESTING_ENABLED
25+
if LuxTestUtils.ENZYME_TESTING_ENABLED[]
2626
((∂x3, ∂y3),) = Enzyme.autodiff(
2727
Enzyme.Reverse, LuxOps.xlogy, Active, Active(2.0), Active(3.0)
2828
)
@@ -38,7 +38,7 @@
3838
@test @inferred(LuxOps.xlogy(0, 1)) isa Number
3939
@jet LuxOps.xlogy(2, 3)
4040

41-
if LuxTestUtils.ENZYME_TESTING_ENABLED
41+
if LuxTestUtils.ENZYME_TESTING_ENABLED[]
4242
@test @inferred(
4343
Enzyme.autodiff(Enzyme.Reverse, LuxOps.xlogy, Active, Active(2.0), Active(3.0))
4444
) isa Any

test/helpers/training_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ end
4040

4141
for ad in (AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme())
4242
ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue
43-
!LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue
43+
!LuxTestUtils.ENZYME_TESTING_ENABLED[] && ad isa AutoEnzyme && continue
4444

4545
grads, _, _, _ = Training.compute_gradients(ad, _loss_function, x, tstate)
4646
tstate_ = Training.apply_gradients(tstate, grads)
@@ -80,7 +80,7 @@ end
8080
ongpu &&
8181
(ad isa AutoReverseDiff || ad isa AutoEnzyme || ad isa AutoMooncake) &&
8282
continue
83-
!LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue
83+
!LuxTestUtils.ENZYME_TESTING_ENABLED[] && ad isa AutoEnzyme && continue
8484

8585
function get_total_loss(model, tstate)
8686
loss = 0.0f0
@@ -215,7 +215,7 @@ end
215215
end
216216

217217
@testitem "Training API Enzyme Runtime Mode" setup = [SharedTestSetup] tags = [:misc] skip =
218-
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
218+
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED[]) begin
219219
using Lux, Random, Enzyme, Optimisers
220220

221221
function makemodel(n)
@@ -264,7 +264,7 @@ end
264264

265265
@testitem "Enzyme: Invalidate Cache on State Update" setup = [SharedTestSetup] tags = [
266266
:misc
267-
] skip = :(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
267+
] skip = :(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED[]) begin
268268
using ADTypes, Optimisers
269269

270270
mse = MSELoss()

0 commit comments

Comments
 (0)