Skip to content

Commit 44f30b4

Browse files
committed
test(Lux): migrate to ParallelTestRunner
1 parent fa3c732 commit 44f30b4

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

57 files changed

+1945
-2138
lines changed

.buildkite/testing.yml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ steps:
66
- JuliaCI/julia#v1:
77
version: "{{matrix.julia}}"
88
- JuliaCI/julia-test#v1:
9-
test_args: "BACKEND_GROUP=CUDA LUX_TEST_GROUP={{matrix.group}}"
9+
test_args: "--BACKEND_GROUP=CUDA {{matrix.group}}"
1010
- JuliaCI/julia-coverage#v1:
1111
codecov: true
1212
dirs:
@@ -31,13 +31,14 @@ steps:
3131
julia:
3232
- "1.12"
3333
group:
34+
- "autodiff"
35+
- "contrib"
36+
- "helpers"
3437
- "core_layers"
35-
- "normalize_layers"
3638
- "recurrent_layers"
37-
- "autodiff"
3839
- "misc"
3940
- "reactant"
40-
- "extras"
41+
- "others"
4142

4243
# - group: ":julia: (Lux) AMD GPU"
4344
# steps:
@@ -46,7 +47,7 @@ steps:
4647
# - JuliaCI/julia#v1:
4748
# version: "{{matrix.julia}}"
4849
# - JuliaCI/julia-test#v1:
49-
# test_args: "BACKEND_GROUP=AMDGPU"
50+
# test_args: "--BACKEND_GROUP=AMDGPU"
5051
# - JuliaCI/julia-coverage#v1:
5152
# codecov: true
5253
# dirs:

.github/workflows/CI.yml

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,17 @@ jobs:
3939
os:
4040
- ubuntu-latest
4141
test_group:
42+
- "autodiff"
4243
- "core_layers"
43-
- "normalize_layers"
4444
- "recurrent_layers"
45-
- "autodiff"
4645
- "misc"
4746
- "reactant"
48-
- "extras"
47+
- "others"
4948
uses: ./.github/workflows/CommonCI.yml
5049
with:
5150
julia_version: ${{ matrix.version }}
5251
project: "."
53-
test_args: "BACKEND_GROUP=CPU LUX_TEST_GROUP=${{ matrix.test_group }}"
52+
test_args: "--BACKEND_GROUP=CPU ${{ matrix.test_group }}"
5453
os: ${{ matrix.os }}
5554
local_dependencies: "lib/MLDataDevices,lib/WeightInitializers,lib/LuxLib,lib/LuxCore"
5655
local_test_dependencies: "lib/MLDataDevices,lib/LuxTestUtils,lib/LuxLib,lib/LuxCore"
@@ -60,18 +59,17 @@ jobs:
6059
fail-fast: false
6160
matrix:
6261
test_group:
62+
- "autodiff"
6363
- "core_layers"
64-
- "normalize_layers"
6564
- "recurrent_layers"
66-
- "autodiff"
6765
- "misc"
6866
- "reactant"
69-
- "extras"
67+
- "others"
7068
uses: ./.github/workflows/CommonCI.yml
7169
with:
7270
julia_version: "1.11"
7371
project: "."
7472
downgrade_testing: true
75-
test_args: "BACKEND_GROUP=CPU LUX_TEST_GROUP=${{ matrix.test_group }}"
73+
test_args: "--BACKEND_GROUP=CPU ${{ matrix.test_group }}"
7674
local_dependencies: "lib/MLDataDevices,lib/WeightInitializers,lib/LuxLib,lib/LuxCore"
7775
local_test_dependencies: "lib/MLDataDevices,lib/LuxTestUtils,lib/LuxLib,lib/LuxCore"

test/Project.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
[deps]
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
34
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
45
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
56
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
7+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
68
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
79
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
810
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
@@ -11,11 +13,11 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1113
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
1214
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1315
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
14-
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1718
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1819
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
20+
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
1921
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
2022
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
2123
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
@@ -25,11 +27,11 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2527
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
2628
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
2729
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
30+
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
2831
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2932
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
3033
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
3134
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
32-
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
3335
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
3436
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
3537
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -41,6 +43,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4143
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4244
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4345
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
46+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
4447

4548
[sources]
4649
Lux = {path = ".."}
@@ -62,7 +65,6 @@ Enzyme = "0.13.120"
6265
ExplicitImports = "1.14.2"
6366
ForwardDiff = "0.10.36, 1"
6467
Functors = "0.5"
65-
InteractiveUtils = "<0.0.1, 1"
6668
LinearAlgebra = "1.10"
6769
Logging = "1.10"
6870
LoopVectorization = "0.12.171"
@@ -75,11 +77,11 @@ NNlib = "0.9.27"
7577
Octavian = "0.3.28"
7678
OneHotArrays = "0.2.5"
7779
Optimisers = "0.4.6"
80+
ParallelTestRunner = "2.1"
7881
Pkg = "1.10"
7982
Preferences = "1.4.3"
8083
PythonCall = "0.9"
8184
Random = "1.10"
82-
ReTestItems = "1.24.0"
8385
Reactant = "0.2.205"
8486
Reexport = "1.2.2"
8587
ReverseDiff = "1.15.3"
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
using ComponentArrays, ForwardDiff, Zygote, Tracker, ReverseDiff
2+
3+
include("../shared_testsetup.jl")
4+
5+
@testset "Batched Jacobian" begin
6+
rng = StableRNG(12345)
7+
8+
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
9+
models = (
10+
Chain(
11+
Conv((3, 3), 2 => 4, gelu; pad=SamePad()),
12+
Conv((3, 3), 4 => 2, gelu; pad=SamePad()),
13+
FlattenLayer(),
14+
Dense(18 => 2),
15+
),
16+
Chain(Dense(2, 4, gelu), Dense(4, 2)),
17+
)
18+
Xs = (aType(randn(rng, Float32, 3, 3, 2, 4)), aType(randn(rng, Float32, 2, 4)))
19+
20+
for (model, X) in zip(models, Xs)
21+
ps, st = dev(Lux.setup(rng, model))
22+
smodel = StatefulLuxLayer(model, ps, st)
23+
24+
J1 = allow_unstable() do
25+
ForwardDiff.jacobian(smodel, X)
26+
end
27+
28+
@testset "$(backend)" for backend in (AutoZygote(), AutoForwardDiff())
29+
J2 = allow_unstable() do
30+
batched_jacobian(smodel, backend, X)
31+
end
32+
J2_mat =
33+
mapreduce(
34+
Base.Fix1(Lux.AutoDiffInternalImpl.batched_row, J2),
35+
hcat,
36+
1:(size(J2, 1) * size(J2, 3)),
37+
)'
38+
39+
@test Array(J1) ≈ Array(J2_mat) atol = 1.0e-3 rtol = 1.0e-3
40+
41+
ps = dev(ComponentArray(cpu_device()(ps)))
42+
43+
smodel = StatefulLuxLayer(model, ps, st)
44+
45+
J3 = allow_unstable() do
46+
batched_jacobian(smodel, backend, X)
47+
end
48+
49+
@test Array(J2) ≈ Array(J3) atol = 1.0e-3 rtol = 1.0e-3
50+
end
51+
end
52+
53+
@testset "Issue #636 Chunksize Specialization" begin
54+
for N in (2, 4, 8, 11, 12, 50, 51), backend in (AutoZygote(), AutoForwardDiff())
55+
model = @compact(; potential=Dense(N => N, gelu), backend=backend) do x
56+
@return allow_unstable() do
57+
batched_jacobian(potential, backend, x)
58+
end
59+
end
60+
61+
ps, st = dev(Lux.setup(Random.default_rng(), model))
62+
63+
x = dev(randn(Float32, N, 3))
64+
@test first(model(x, ps, st)) isa AbstractArray{<:Any,3}
65+
end
66+
end
67+
68+
@testset "Simple Batched Jacobian" begin
69+
# Without any Lux bs just plain old batched jacobian
70+
ftest(x) = x .^ 2
71+
x = dev(reshape(Float32.(1:6), 2, 3))
72+
73+
Jx_true = zeros(Float32, 2, 2, 3)
74+
Jx_true[1, 1, 1] = 2
75+
Jx_true[2, 2, 1] = 4
76+
Jx_true[1, 1, 2] = 6
77+
Jx_true[2, 2, 2] = 8
78+
Jx_true[1, 1, 3] = 10
79+
Jx_true[2, 2, 3] = 12
80+
Jx_true = dev(Jx_true)
81+
82+
Jx_fdiff = allow_unstable() do
83+
batched_jacobian(ftest, AutoForwardDiff(), x)
84+
end
85+
@test Jx_fdiff ≈ Jx_true
86+
87+
Jx_zygote = allow_unstable() do
88+
batched_jacobian(ftest, AutoZygote(), x)
89+
end
90+
@test Jx_zygote ≈ Jx_true
91+
92+
fincorrect(x) = x[:, 1]
93+
x = dev(reshape(Float32.(1:6), 2, 3))
94+
95+
@test_throws AssertionError batched_jacobian(fincorrect, AutoForwardDiff(), x)
96+
@test_throws AssertionError batched_jacobian(fincorrect, AutoZygote(), x)
97+
end
98+
end
99+
end
Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# With LuxTestUtils 1.1 we have inbuilt Enzyme.jl support in `test_gradients`. We should be
22
# able to remove this, but this file is still helpful to catch errors in a localized way.
3-
@testsetup module EnzymeTestSetup
3+
44
using LuxTestUtils, Enzyme, Zygote, Test
55
using Lux, NNlib
66
using LuxTestUtils: check_approx
7+
using ComponentArrays
8+
9+
include("../shared_testsetup.jl")
710

811
generic_loss_function(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))
912

@@ -108,49 +111,27 @@ if VERSION < v"1.11-"
108111
)
109112
end
110113

111-
export generic_loss_function,
112-
compute_enzyme_gradient, compute_zygote_gradient, test_enzyme_gradients, MODELS_LIST
113-
end
114-
115-
@testitem "Enzyme Integration" setup = [EnzymeTestSetup, SharedTestSetup] tags = [
116-
:autodiff, :enzyme
117-
] timeout = 3600 skip = :(using LuxTestUtils;
118-
!LuxTestUtils.ENZYME_TESTING_ENABLED[]) begin
119-
rng = StableRNG(12345)
120-
121-
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
122-
ongpu && continue
123-
124-
@testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in
125-
enumerate(MODELS_LIST)
126-
display(model)
127-
128-
ps, st = dev(Lux.setup(rng, model))
129-
x = aType(x)
130-
test_enzyme_gradients(model, x, ps, st)
131-
end
132-
end
133-
end
134-
135-
@testitem "Enzyme Integration ComponentArray" setup = [EnzymeTestSetup, SharedTestSetup] timeout =
136-
3600 tags = [:autodiff, :enzyme] skip = :(using LuxTestUtils;
137-
!LuxTestUtils.ENZYME_TESTING_ENABLED[]) begin
138-
using ComponentArrays
114+
@testset "Enzyme Integration" begin
115+
if LuxTestUtils.ENZYME_TESTING_ENABLED[]
116+
rng = StableRNG(12345)
139117

140-
rng = StableRNG(12345)
118+
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
119+
ongpu && continue
141120

142-
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
143-
ongpu && continue
121+
@testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in
122+
enumerate(MODELS_LIST)
123+
display(model)
144124

145-
@testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in
146-
enumerate(MODELS_LIST)
147-
display(model)
125+
ps, st = dev(Lux.setup(rng, model))
126+
x = aType(x)
127+
test_enzyme_gradients(model, x, ps, st)
148128

149-
ps, st = Lux.setup(rng, model)
150-
ps = ComponentArray(ps)
151-
st = dev(st)
152-
x = aType(x)
153-
test_enzyme_gradients(model, x, ps, st)
129+
@testset "ComponentArray" begin
130+
ps_ca = ComponentArray(ps)
131+
st_ca = dev(st)
132+
test_enzyme_gradients(model, x, ps_ca, st_ca)
133+
end
134+
end
154135
end
155136
end
156137
end

0 commit comments

Comments
 (0)