
Commit ae6be50

test(LuxLib): migrate testing to v1.12 (#1633)

Parent: 97b16dd

This commit bumps LuxLib's CI matrices from Julia 1.11 to 1.12 and bumps the package patch version. Along the way it replaces @assert-based argument validation in groupnorm with explicit ArgumentError throws, and gates the Zygote-based gradient checks in the test suite behind LuxTestUtils.ZYGOTE_TESTING_ENABLED[] so they can be switched off where Zygote is not exercised.

File tree: 15 files changed (+97, -73 lines)


.buildkite/testing_luxlib.yml (1 addition, 1 deletion)

@@ -26,7 +26,7 @@ steps:
       matrix:
         setup:
           julia:
-            - "1.11"
+            - "1.12"
           group:
             - "common"
             - "normalization"

.github/workflows/CI_LuxLib.yml (1 addition, 1 deletion)

@@ -40,7 +40,7 @@ jobs:
           loopvec: "false"
     uses: ./.github/workflows/CommonCI.yml
     with:
-      julia_version: "1.11"
+      julia_version: "1.12"
       project: "lib/LuxLib"
       test_args: "BACKEND_GROUP=cpu LUXLIB_TEST_GROUP=${{ matrix.test_group }} LUXLIB_BLAS_BACKEND=${{ matrix.blas_backend }} LUXLIB_LOAD_LOOPVEC=${{ matrix.loopvec }}"
 

lib/LuxLib/Project.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
-version = "1.15.1"
+version = "1.15.2"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
 
 [deps]

lib/LuxLib/src/api/groupnorm.jl (16 additions, 6 deletions)

@@ -40,19 +40,29 @@ end
 function assert_valid_groupnorm_arguments(
     x::AbstractArray{T,N}, scale, bias, groups
 ) where {T,N}
-    @assert length(scale) == length(bias) == size(x, N - 1) "Length of `scale` and `bias` must \
-                                                             be equal to the number of \
-                                                             channels ((N - 1) dim of the \
-                                                             input array)."
+    if length(scale) !== length(bias) || length(bias) !== size(x, N - 1)
+        throw(
+            ArgumentError(
+                "Length of `scale` and `bias` must be equal to the number of channels \
+                ((N - 1) dim of the input array)."
+            ),
+        )
+    end
     assert_valid_groupnorm_arguments(x, nothing, nothing, groups)
     return nothing
 end
 
 function assert_valid_groupnorm_arguments(
     x::AbstractArray{T,N}, ::Nothing, ::Nothing, groups::Int
 ) where {T,N}
-    @assert size(x, N - 1) % groups == 0 "Number of channels $(size(x, N - 1)) must be \
-                                          divisible by the number of groups $groups."
+    if size(x, N - 1) % groups != 0
+        throw(
+            ArgumentError(
+                "Number of channels $(size(x, N - 1)) must be divisible by the number of \
+                groups $groups."
+            ),
+        )
+    end
     return nothing
 end
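
Note: swapping @assert for an explicit throw(ArgumentError(...)) is not just cosmetic. The @assert docs warn that asserts may be disabled at certain optimization levels, and an assert raises AssertionError rather than a documented argument error. A minimal, self-contained sketch of the new pattern, using a hypothetical check_groups helper rather than LuxLib's actual function:

using Test

# Explicit validation: the check always runs, and callers get a
# documented exception type (ArgumentError, not AssertionError).
function check_groups(x::AbstractArray{T,N}, groups::Int) where {T,N}
    if size(x, N - 1) % groups != 0
        throw(ArgumentError(
            "Number of channels $(size(x, N - 1)) must be divisible by the number of groups $groups."
        ))
    end
    return nothing
end

@test_throws ArgumentError check_groups(rand(4, 6, 2), 4)  # 6 % 4 != 0
@test check_groups(rand(4, 6, 2), 3) === nothing           # 6 % 3 == 0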

lib/LuxLib/test/common_ops/activation_tests.jl (7 additions, 5 deletions)

@@ -52,12 +52,14 @@
         @test_gradients(apply_act_fast, f, x; atol, rtol)
         @test_gradients(apply_act_fast2, f, x; atol, rtol)
 
-        ∂x1 = Zygote.gradient(apply_act, f, x)[2]
-        ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2]
-        ∂x3 = Zygote.gradient(apply_act_fast2, f, x)[2]
+        if LuxTestUtils.ZYGOTE_TESTING_ENABLED[]
+            ∂x1 = Zygote.gradient(apply_act, f, x)[2]
+            ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2]
+            ∂x3 = Zygote.gradient(apply_act_fast2, f, x)[2]
 
-        @test ∂x1 ≈ ∂x2 atol = atol rtol = rtol
-        @test ∂x1 ≈ ∂x3 atol = atol rtol = rtol
+            @test ∂x1 ≈ ∂x2 atol = atol rtol = rtol
+            @test ∂x1 ≈ ∂x3 atol = atol rtol = rtol
+        end
     end
 end
 end
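
Note: the [] in LuxTestUtils.ZYGOTE_TESTING_ENABLED[] is a Ref dereference; the flag itself is defined in LuxTestUtils, not in this diff. Under that assumption, a rough sketch of how a Ref{Bool} behind a const binding works as a process-wide switch for expensive gradient checks:

# A const binding to a Ref{Bool} gives a mutable cell that test files
# read with FLAG[] and a harness can flip at runtime, e.g. on a Julia
# version where Zygote does not yet work.
const ZYGOTE_TESTING_ENABLED = Ref(true)

function maybe_gradient_checks()
    if ZYGOTE_TESTING_ENABLED[]
        println("running Zygote gradient checks")
    else
        println("skipping Zygote gradient checks")
    end
end

ZYGOTE_TESTING_ENABLED[] = false  # e.g. driven by an environment variable
maybe_gradient_checks()           # prints "skipping Zygote gradient checks"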

lib/LuxLib/test/common_ops/attention_tests.jl (1 addition, 1 deletion)

@@ -1,5 +1,5 @@
 @testitem "Scaled Dot Product Attention" tags = [:misc] setup = [SharedTestSetup] begin
-    using LuxLib, Reactant, NNlib, Random, MLDataDevices, Zygote, Enzyme, Statistics
+    using LuxLib, Reactant, NNlib, Random, MLDataDevices, Enzyme, Statistics
 
     @testset "$mode" for (mode, aType, ongpu, fp64) in MODES
         @testset "Different Batch Sizes" begin

lib/LuxLib/test/common_ops/bias_act_tests.jl (10 additions, 8 deletions)

@@ -79,14 +79,16 @@
             soft_fail=fp16 ? [AutoFiniteDiff()] : []
         )
 
-        ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b)
-        ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b)
-        ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b)
-
-        @test ∂x1 ≈ ∂x2 atol = atol rtol = rtol
-        @test ∂x1 ≈ ∂x3 atol = atol rtol = rtol
-        @test ∂b1 ≈ ∂b2 atol = atol rtol = rtol
-        @test ∂b1 ≈ ∂b3 atol = atol rtol = rtol
+        if LuxTestUtils.ZYGOTE_TESTING_ENABLED[]
+            ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b)
+            ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b)
+            ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b)
+
+            @test ∂x1 ≈ ∂x2 atol = atol rtol = rtol
+            @test ∂x1 ≈ ∂x3 atol = atol rtol = rtol
+            @test ∂b1 ≈ ∂b2 atol = atol rtol = rtol
+            @test ∂b1 ≈ ∂b3 atol = atol rtol = rtol
+        end
     end
 end
 end
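
Note: the __Fix1(bias_act_loss1, act) wrappers partially apply the activation so that Zygote differentiates only with respect to x and b. __Fix1 is a test-local helper not shown in this diff; Base.Fix1 historically forwards just one trailing argument, so the sketch below uses a hand-rolled variadic closure with the same shape:

using Zygote

# A toy loss with the same call shape as the bias_act losses above:
# the activation comes first, then the differentiated arguments.
bias_act_loss(act, x, b) = sum(abs2, act.(x .+ b))

# Variadic first-argument fixing, the idea behind a __Fix1-style helper.
fix1(f, a) = (args...) -> f(a, args...)

x, b = randn(3, 4), randn(3)

# The activation is captured as a constant; gradients come back for x and b.
∂x, ∂b = Zygote.gradient(fix1(bias_act_loss, tanh), x, b)

size(∂x) == size(x) && size(∂b) == size(b)  # true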

lib/LuxLib/test/common_ops/conv_tests.jl (1 addition, 1 deletion)

@@ -1,5 +1,5 @@
 @testsetup module ConvSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
+using LuxLib, LuxTestUtils, Random, Test, NNlib
 
 expand(_, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)

lib/LuxLib/test/common_ops/dense_tests.jl (40 additions, 34 deletions)

@@ -39,16 +39,18 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu
     y_zyg = fused_dense_bias_activation(activation, w, x, bias)
     @test y_simple ≈ y_zyg atol = atol rtol = rtol
 
-    _, ∂w_true, ∂x_true, ∂b_true = Zygote.gradient(
-        sum ∘ dense_simple, activation, w, x, bias
-    )
-    _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(
-        sum ∘ fused_dense_bias_activation, activation, w, x, bias
-    )
-    @test ∂w_true ≈ ∂w_zyg atol = atol rtol = rtol
-    @test ∂x_true ≈ ∂x_zyg atol = atol rtol = rtol
-    if bias !== nothing
-        @test ∂b_true ≈ ∂b_zyg atol = atol rtol = rtol
+    if LuxTestUtils.ZYGOTE_TESTING_ENABLED[]
+        _, ∂w_true, ∂x_true, ∂b_true = Zygote.gradient(
+            sum ∘ dense_simple, activation, w, x, bias
+        )
+        _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(
+            sum ∘ fused_dense_bias_activation, activation, w, x, bias
+        )
+        @test ∂w_true ≈ ∂w_zyg atol = atol rtol = rtol
+        @test ∂x_true ≈ ∂x_zyg atol = atol rtol = rtol
+        if bias !== nothing
+            @test ∂b_true ≈ ∂b_zyg atol = atol rtol = rtol
+        end
     end
 end
 
@@ -201,35 +203,39 @@ end
         b_enz,
     )
 
-    _, pb_f = Zygote.pullback(fused_dense_bias_activation, act, weight, x, b)
-    _, dweight_zyg, dx_zyg, db_zyg = pb_f(dy)
+    if LuxTestUtils.ZYGOTE_TESTING_ENABLED[]
+        _, pb_f = Zygote.pullback(fused_dense_bias_activation, act, weight, x, b)
+        _, dweight_zyg, dx_zyg, db_zyg = pb_f(dy)
 
-    @test dweight ≈ dweight_zyg atol = 1.0e-3 rtol = 1.0e-3
-    @test dx ≈ dx_zyg atol = 1.0e-3 rtol = 1.0e-3
-    if hasbias
-        @test db ≈ db_zyg atol = 1.0e-3 rtol = 1.0e-3
+        @test dweight ≈ dweight_zyg atol = 1.0e-3 rtol = 1.0e-3
+        @test dx ≈ dx_zyg atol = 1.0e-3 rtol = 1.0e-3
+        if hasbias
+            @test db ≈ db_zyg atol = 1.0e-3 rtol = 1.0e-3
+        end
     end
 
     (act === identity && hasbias) || continue
 
-    dweight .= 0
-    dx .= 0
-    db .= 0
-    Enzyme.autodiff(
-        Reverse,
-        matmuladd!,
-        Duplicated(y, copy(dy)),
-        Duplicated(weight, dweight),
-        Duplicated(x, dx),
-        b_enz,
-    )
-
-    _, pb_f = Zygote.pullback(LuxLib.Impl.matmuladd, weight, x, b)
-    dweight_zyg, dx_zyg, db_zyg = pb_f(dy)
-
-    @test dweight ≈ dweight_zyg atol = 1.0e-3 rtol = 1.0e-3
-    @test dx ≈ dx_zyg atol = 1.0e-3 rtol = 1.0e-3
-    @test db ≈ db_zyg atol = 1.0e-3 rtol = 1.0e-3
+    if LuxTestUtils.ZYGOTE_TESTING_ENABLED[]
+        dweight .= 0
+        dx .= 0
+        db .= 0
+        Enzyme.autodiff(
+            Reverse,
+            matmuladd!,
+            Duplicated(y, copy(dy)),
+            Duplicated(weight, dweight),
+            Duplicated(x, dx),
+            b_enz,
+        )
+
+        _, pb_f = Zygote.pullback(LuxLib.Impl.matmuladd, weight, x, b)
+        dweight_zyg, dx_zyg, db_zyg = pb_f(dy)
+
+        @test dweight ≈ dweight_zyg atol = 1.0e-3 rtol = 1.0e-3
+        @test dx ≈ dx_zyg atol = 1.0e-3 rtol = 1.0e-3
+        @test db ≈ db_zyg atol = 1.0e-3 rtol = 1.0e-3
+    end
 end
 end
 end
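
Note: both restored blocks follow one pattern: compute cotangents with Enzyme's in-place reverse mode, then cross-check against a Zygote.pullback. The two APIs differ in shape: Zygote.pullback returns the primal plus a closure mapping output cotangents to input cotangents, while Enzyme.autodiff(Reverse, ...) accumulates cotangents into the shadow halves of Duplicated arguments, so the output shadow is seeded with dy and the input shadows start at zero. A sketch of that comparison, with a hypothetical muladd-style pair standing in for LuxLib.Impl.matmuladd:

using Enzyme, Zygote

muladd_vec(w, x, b) = w * x .+ b
function muladd_vec!(y, w, x, b)   # in-place form for Enzyme
    y .= w * x .+ b
    return nothing
end

w, x, b = randn(3, 4), randn(4, 2), randn(3)
y, dy = zeros(3, 2), ones(3, 2)
dw, dx, db = zero(w), zero(x), zero(b)

# Enzyme reverse mode: dy seeds the output cotangent (copied, since the
# shadow is consumed); dw/dx/db receive the input cotangents.
Enzyme.autodiff(
    Reverse, muladd_vec!,
    Duplicated(y, copy(dy)),
    Duplicated(w, dw), Duplicated(x, dx), Duplicated(b, db),
)

# Zygote: pullback gives the primal and a cotangent-propagating closure.
_, pb = Zygote.pullback(muladd_vec, w, x, b)
dw_zyg, dx_zyg, db_zyg = pb(dy)

dw ≈ dw_zyg && dx ≈ dx_zyg && db ≈ db_zyg  # expected to hold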

lib/LuxLib/test/normalization/batchnorm_tests.jl (1 addition, 1 deletion)

@@ -73,7 +73,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
     end
 
     # Check the rrules
-    if is_training(training)
+    if is_training(training) && LuxTestUtils.ZYGOTE_TESTING_ENABLED[]
         _f =
             (args...) ->
                 sum(first(batchnorm(args..., rm, rv, training, act, T(0.9), epsilon)))
