Skip to content

Commit 29d9a57

Browse files
committed
feat: update LuxTestUtils to support 1.12
ci: test on 1.12 fix: ref negate fix: use finite differences to test ground truth
1 parent 3a63012 commit 29d9a57

File tree

19 files changed

+158
-59
lines changed

19 files changed

+158
-59
lines changed

.buildkite/testing_luxtestutils.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ steps:
2222
matrix:
2323
setup:
2424
julia:
25-
- "1.11"
25+
- "1.12"
2626

2727
env:
2828
SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w=="

.github/workflows/CI_LuxTestUtils.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ jobs:
1919
test:
2020
uses: ./.github/workflows/CommonCI.yml
2121
with:
22-
julia_version: "1.11"
22+
julia_version: "1.12"
2323
project: "lib/LuxTestUtils"

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ GPUArraysCore = "0.2"
109109
LinearAlgebra = "1.10"
110110
LossFunctions = "0.11.1, 1"
111111
LuxCore = "1.5.1"
112-
LuxLib = "1.15"
112+
LuxLib = "1.15.1"
113113
MLDataDevices = "1.17"
114114
MLUtils = "0.4.4"
115115
MPI = "0.20.19"
@@ -138,5 +138,5 @@ WeightInitializers = "1.3"
138138
Zygote = "0.7"
139139
julia = "1.10"
140140

141-
[workspaces]
141+
[workspace]
142142
projects = ["test", "docs"]

lib/LuxCUDA/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ Reexport = "1"
1414
cuDNN = "1.4.3"
1515
julia = "1.10"
1616

17-
[workspaces]
17+
[workspace]
1818
projects = ["test"]

lib/LuxCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,5 @@ Setfield = "1"
4949
Tracker = "0.2.36"
5050
julia = "1.10"
5151

52-
[workspaces]
52+
[workspace]
5353
projects = ["test"]

lib/LuxLib/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,5 +105,5 @@ UUIDs = "1.10"
105105
cuDNN = "1.3"
106106
julia = "1.10"
107107

108-
[workspaces]
108+
[workspace]
109109
projects = ["test"]

lib/LuxLib/test/common_ops/dense_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ end
138138
end
139139

140140
@testitem "Enzyme.Forward patch: dense" tags = [:common] setup = [SharedTestSetup] skip =
141-
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
141+
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED[]) begin
142142
using LuxLib, Random, ForwardDiff, Enzyme
143143

144144
x = rand(Float32, 2, 2)
@@ -149,7 +149,7 @@ end
149149
end
150150

151151
@testitem "Enzyme rules for fused dense" tags = [:common] setup = [SharedTestSetup] skip =
152-
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
152+
:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED[]) begin
153153
using LuxLib, NNlib, Zygote, Enzyme
154154

155155
# These are mostly for testing the CUDA rules since we don't enable the CUDA tests

lib/LuxLib/test/common_ops/dropout_tests.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
T(2),
3434
dims;
3535
atol=1.0f-3,
36-
rtol=1.0f-3
36+
rtol=1.0f-3,
37+
ground_truth_eltype=Nothing
3738
)
3839

3940
y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), dims)
@@ -86,7 +87,8 @@ end
8687
T(2),
8788
:;
8889
atol=1.0f-3,
89-
rtol=1.0f-3
90+
rtol=1.0f-3,
91+
ground_truth_eltype=Nothing
9092
)
9193

9294
@jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))
@@ -117,7 +119,8 @@ end
117119
T(2),
118120
:;
119121
atol=1.0f-3,
120-
rtol=1.0f-3
122+
rtol=1.0f-3,
123+
ground_truth_eltype=Nothing
121124
)
122125

123126
@jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)))
@@ -170,7 +173,8 @@ end
170173
T(0.5),
171174
Val(true);
172175
atol=1.0f-3,
173-
rtol=1.0f-3
176+
rtol=1.0f-3,
177+
ground_truth_eltype=Nothing
174178
)
175179

176180
@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))

lib/LuxTestUtils/Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "LuxTestUtils"
22
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
3-
version = "2.1.0"
3+
version = "2.2.0"
44
authors = ["Avik Pal <avikpal@mit.edu>"]
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
910
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1011
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
@@ -25,6 +26,7 @@ MLDataDevices = {path = "../MLDataDevices"}
2526

2627
[compat]
2728
ADTypes = "1.10"
29+
Adapt = "4.1"
2830
ArrayInterface = "7.17.1"
2931
ChainRulesCore = "1.25.1"
3032
ComponentArrays = "0.15.22"
@@ -33,13 +35,13 @@ Enzyme = "0.13.81"
3335
FiniteDiff = "2.23.1"
3436
ForwardDiff = "0.10.36, 1"
3537
Functors = "0.5"
36-
JET = "0.9.6, 0.10"
38+
JET = "0.9.6, 0.10, 0.11"
3739
MLDataDevices = "1.17"
3840
Optimisers = "0.3.4, 0.4"
3941
Pkg = "1.10"
4042
Test = "1.10"
4143
Zygote = "0.7"
4244
julia = "1.10"
4345

44-
[workspaces]
46+
[workspace]
4547
projects = ["test"]

lib/LuxTestUtils/src/LuxTestUtils.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module LuxTestUtils
22

3+
using Adapt: Adapt, adapt
34
using ArrayInterface: ArrayInterface
45
using ComponentArrays: ComponentArray, getdata, getaxes
56
using DispatchDoctor: allow_unstable
6-
using Functors: Functors
7+
using Functors: Functors, fmap
78
using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice
89
using Optimisers: Optimisers
910
using Pkg: PackageSpec
@@ -33,25 +34,41 @@ using Zygote: Zygote
3334
const CRC = ChainRulesCore
3435
const FD = FiniteDiff
3536

37+
const JET_TESTING_ENABLED = Ref{Bool}(false)
38+
const ENZYME_TESTING_ENABLED = Ref{Bool}(false)
39+
const ZYGOTE_TESTING_ENABLED = Ref{Bool}(false)
40+
3641
# Check if JET will work
3742
try
3843
using JET: JET, JETTestFailure, get_reports, report_call, report_opt
39-
# XXX: In 1.11, JET leads to stack overflows
40-
global JET_TESTING_ENABLED = v"1.10-" VERSION < v"1.11-"
44+
JET_TESTING_ENABLED[] = true
4145
catch err
4246
@error "`JET.jl` did not successfully precompile on $(VERSION). All `@jet` tests will \
4347
be skipped." maxlog = 1 err = err
44-
global JET_TESTING_ENABLED = false
48+
JET_TESTING_ENABLED[] = false
4549
end
4650

47-
# Check if Enzyme will work
48-
try
49-
using Enzyme: Enzyme
50-
__ftest(x) = x
51-
Enzyme.autodiff(Enzyme.Reverse, __ftest, Enzyme.Active, Enzyme.Active(2.0))
52-
global ENZYME_TESTING_ENABLED = Sys.islinux()
53-
catch err
54-
global ENZYME_TESTING_ENABLED = false
51+
# Check if Enzyme will work (only on non-prerelease versions)
52+
@static if isempty(VERSION.prerelease)
53+
try
54+
using Enzyme: Enzyme
55+
Enzyme.gradient(Enzyme.Reverse, Base.Fix1(sum, abs2), ones(Float32, 10))
56+
ENZYME_TESTING_ENABLED[] = Sys.islinux()
57+
catch err
58+
@error "`Enzyme.jl` did not successfully differentiate a simple function or \
59+
failed to load on $(VERSION). All Enzyme tests will be \
60+
skipped." maxlog = 1 err = err
61+
ENZYME_TESTING_ENABLED[] = false
62+
end
63+
end
64+
65+
function __init__()
66+
ZYGOTE_TESTING_ENABLED[] = VERSION < v"1.12-"
67+
68+
if JET_TESTING_ENABLED[]
69+
# JET doesn't work nicely on 1.11
70+
JET_TESTING_ENABLED[] = VERSION < v"1.11-" || VERSION v"1.12-"
71+
end
5572
end
5673

5774
include("package_install.jl")

0 commit comments

Comments
 (0)