Commit 9ebc339

test: split up a bit
1 parent 7109675 commit 9ebc339

15 files changed: +510 additions, -484 deletions

test/misc/doctests.jl

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
using Lux, Documenter

# Some of the tests are flaky on prereleases
@testset "doctests: Quality Assurance" begin
    doctestexpr = :(using Adapt, Lux, Random, Optimisers, Zygote, NNlib)

    DocMeta.setdocmeta!(Lux, :DocTestSetup, doctestexpr; recursive=true)
    doctest(Lux; manual=false)
end
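
How the split files get picked up is outside this diff; presumably the top-level test runner includes each of them. The fragment below is only a hypothetical sketch of such wiring (not part of this commit) to show the intent of the split: every new file is self-contained, so a runner only has to include it.

# Hypothetical runtests.jl fragment (assumed, not from this commit).
using Test
@testset "Lux misc tests" begin
    include("misc/doctests.jl")
    include("misc/helpers/adtypes_tests.jl")
end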

test/misc/helpers/adtypes_tests.jl

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
1+
using ADTypes, Optimisers, Tracker, ReverseDiff, Mooncake, ComponentArrays, Enzyme
2+
3+
include("../../shared_testsetup.jl")
4+
5+
@testset "AbstractADTypes" begin
6+
function _loss_function(model, ps, st, data)
7+
y, st = model(data, ps, st)
8+
return sum(y), st, ()
9+
end
10+
11+
rng = StableRNG(12345)
12+
13+
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
14+
model = Dense(3, 2)
15+
opt = Adam(0.01f0)
16+
ps, st = dev(Lux.setup(rng, model))
17+
18+
tstate = Training.TrainState(model, ps, st, opt)
19+
20+
x = aType(randn(Lux.replicate(rng), Float32, (3, 1)))
21+
22+
for ad in (AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme())
23+
ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue
24+
!LuxTestUtils.ENZYME_TESTING_ENABLED[] && ad isa AutoEnzyme && continue
25+
!LuxTestUtils.ZYGOTE_TESTING_ENABLED[] && ad isa AutoZygote && continue
26+
27+
grads, _, _, _ = Training.compute_gradients(ad, _loss_function, x, tstate)
28+
tstate_ = Training.apply_gradients(tstate, grads)
29+
@test tstate_.step == 1
30+
@test tstate != tstate_
31+
end
32+
end
33+
end
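
This file (and the training test files below) pulls in ../../shared_testsetup.jl, which is not part of this diff. The names the tests rely on (StableRNG, MODES, allow_unstable, plus the LuxTestUtils toggles) would have to come from there; the following is only a rough CPU-only approximation of what such a setup might provide, not the actual file.

# Assumed shape of shared_testsetup.jl (sketch; real setups also add GPU entries).
using Lux, LuxTestUtils, StableRNGs, Random, Test

# MODES pairs a label with an array type, a device-transfer function, and a GPU flag.
const MODES = [("cpu", Array, identity, false)]

# allow_unstable: the tests wrap calls whose return types are not inferable;
# in this sketch it simply runs the closure.
allow_unstable(f) = f()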
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
using ADTypes, Optimisers, Tracker, ReverseDiff, Mooncake, ComponentArrays, Enzyme

include("../../shared_testsetup.jl")

@testset "Training API" begin
    mse = MSELoss()

    rng = StableRNG(12345)

    x_data = randn(rng, Float32, 4, 32)
    y_data = evalpoly.(x_data, ((1, 2, 3),)) .- evalpoly.(x_data, ((5, 2),))
    y_data = (y_data .- minimum(y_data)) ./ (maximum(y_data) - minimum(y_data))
    dataset = [(x_data[:, i], y_data[:, i]) for i in Iterators.partition(1:32, 8)]

    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
        model = Chain(
            Dense(4, 32, tanh),
            BatchNorm(32),
            Dense(32, 32, tanh),
            BatchNorm(32),
            Dense(32, 4),
        )
        dataset_ = [dev((x, y)) for (x, y) in dataset]
        opt = Adam(0.001f0)

        @testset "$(ad)" for ad in (
            AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme(), AutoMooncake()
        )
            ongpu &&
                (ad isa AutoReverseDiff || ad isa AutoEnzyme || ad isa AutoMooncake) &&
                continue
            !LuxTestUtils.ENZYME_TESTING_ENABLED[] && ad isa AutoEnzyme && continue
            !LuxTestUtils.ZYGOTE_TESTING_ENABLED[] && ad isa AutoZygote && continue

            function get_total_loss(model, tstate)
                loss = 0.0f0
                for (x, y) in dataset_
                    loss += mse(model, tstate.parameters, tstate.states, (x, y))[1]
                end
                return loss
            end

            @testset "compute_gradients + apply_gradients!" begin
                ps, st = dev(Lux.setup(rng, model))
                tstate = Training.TrainState(model, ps, st, opt)

                initial_loss = get_total_loss(model, tstate)

                for epoch in 1:1000, (x, y) in dataset_
                    grads, loss, _, tstate = allow_unstable() do
                        Training.compute_gradients(ad, mse, (x, y), tstate)
                    end
                    tstate = Training.apply_gradients!(tstate, grads)
                end

                final_loss = get_total_loss(model, tstate)
                @test final_loss * 100 < initial_loss
            end

            @testset "single_train_step!" begin
                ps, st = dev(Lux.setup(rng, model))
                tstate = Training.TrainState(model, ps, st, opt)

                initial_loss = get_total_loss(model, tstate)

                for epoch in 1:1000, (x, y) in dataset_
                    grads, loss, _, tstate = allow_unstable() do
                        Training.single_train_step!(ad, mse, (x, y), tstate)
                    end
                end

                final_loss = get_total_loss(model, tstate)
                @test final_loss * 100 < initial_loss
            end

            @testset "single_train_step" begin
                ps, st = dev(Lux.setup(rng, model))
                tstate = Training.TrainState(model, ps, st, opt)

                initial_loss = get_total_loss(model, tstate)

                for epoch in 1:1000, (x, y) in dataset_
                    grads, loss, _, tstate = allow_unstable() do
                        Training.single_train_step(ad, mse, (x, y), tstate)
                    end
                end

                final_loss = get_total_loss(model, tstate)
                @test final_loss * 100 < initial_loss
            end

            # Test the adjust API
            tstate = Optimisers.adjust(tstate, 0.1f0)
            @test tstate.optimizer_state.layer_1.weight.rule.eta ≈ 0.1f0

            tstate = Optimisers.adjust(tstate; eta=0.5f0)
            @test tstate.optimizer_state.layer_1.weight.rule.eta ≈ 0.5f0

            Optimisers.adjust!(tstate, 0.01f0)
            @test tstate.optimizer_state.layer_1.weight.rule.eta ≈ 0.01f0

            Optimisers.adjust!(tstate; eta=0.11f0)
            @test tstate.optimizer_state.layer_1.weight.rule.eta ≈ 0.11f0
        end

        struct AutoCustomAD <: ADTypes.AbstractADType end

        ps, st = dev(Lux.setup(rng, model))
        tstate = Training.TrainState(model, ps, st, opt)

        @test_throws ArgumentError Training.compute_gradients(
            AutoCustomAD(), mse, dataset_[1], tstate
        )
    end
end
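
The adjust checks near the end of that testset exercise Optimisers.jl's learning-rate adjustment through the TrainState wrapper. The same pattern on a bare Optimisers state is sketched below, assuming only Optimisers.jl; the field names rule and eta follow its Adam/Leaf internals rather than anything in this commit.

# Minimal sketch of Optimisers.adjust vs. adjust! on a plain state tree.
using Optimisers

ps = randn(Float32, 4)                       # some flat parameters
state = Optimisers.setup(Adam(0.001f0), ps)  # Leaf holding the rule and momenta

state = Optimisers.adjust(state, 0.1f0)      # non-mutating: returns a new state
@assert state.rule.eta ≈ 0.1f0

Optimisers.adjust!(state; eta=0.5f0)         # mutating: updates the state in place
@assert state.rule.eta ≈ 0.5f0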
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
using ADTypes, Optimisers, Enzyme

include("../../shared_testsetup.jl")

@testset "Training API Enzyme Runtime Mode" begin
    if !LuxTestUtils.ENZYME_TESTING_ENABLED[]
        @test_broken false
        return nothing
    end

    function makemodel(n)
        @compact(dense = Dense(n => 1; use_bias=true), b = ones(Float32, n)) do x
            @return dense(x .+ b)
        end
    end

    n_samples = 20
    x_dim = 10
    y_dim = 1

    model = makemodel(x_dim)
    rng = Random.default_rng()
    ps, st = Lux.setup(rng, model)

    W = randn(rng, Float32, y_dim, x_dim)
    b = randn(rng, Float32, y_dim)

    x_samples = randn(rng, Float32, x_dim, n_samples)
    y_samples = W * x_samples .+ b .+ 0.01f0 .* randn(rng, Float32, y_dim, n_samples)

    lossfn = MSELoss()

    function train_model!(model, ps, st, opt, nepochs::Int)
        tstate = Training.TrainState(model, ps, st, opt)
        for i in 1:nepochs
            grads, loss, _, tstate = Training.single_train_step!(
                AutoEnzyme(; mode=set_runtime_activity(Reverse)),
                lossfn,
                (x_samples, y_samples),
                tstate,
            )
        end
        return tstate.model, tstate.parameters, tstate.states
    end

    initial_loss = lossfn(first(model(x_samples, ps, st)), y_samples)

    model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)

    final_loss = lossfn(first(model(x_samples, ps, st)), y_samples)

    @test final_loss * 100 < initial_loss
end

@testset "Enzyme: Invalidate Cache on State Update" begin
    if !LuxTestUtils.ENZYME_TESTING_ENABLED[]
        @test_broken false
        return nothing
    end

    mse = MSELoss()

    function mse2(model, ps, st, (x, y))
        z, st = model(x, ps, st)
        return sum(abs2, z .- y), st, ()
    end

    rng = StableRNG(12345)

    model = Chain(Dense(4 => 3), VariationalHiddenDropout(0.5f0), Dense(3 => 4))
    ps, st = Lux.setup(rng, model)
    x = randn(rng, Float32, 4, 32)
    opt = Adam(0.001f0)

    tstate = Training.TrainState(model, ps, st, opt)

    _, _, _, tstate_new = Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate)

    @test tstate_new.states !== tstate.states

    model = Chain(Dense(4 => 3), Dense(3 => 4))
    ps, st = Lux.setup(rng, model)

    tstate = Training.TrainState(model, ps, st, opt)

    _, _, _, tstate_new = Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate)

    @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa
        Any

    _, _, _, tstate_new2 = Training.compute_gradients(
        AutoEnzyme(), mse2, (x, x), tstate_new
    )
    @test hasfield(typeof(tstate_new2.cache.extras), :forward)
    @test hasfield(typeof(tstate_new2.cache.extras), :reverse)

    rng = StableRNG(12345)

    model = Chain(Dense(4 => 3), VariationalHiddenDropout(0.5f0), Dense(3 => 4))
    ps, st = Lux.setup(rng, model)
    x = randn(rng, Float32, 4, 32)
    opt = Adam(0.001f0)

    tstate = Training.TrainState(model, ps, st, opt)

    _, _, _, tstate_new = Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate)

    @test tstate_new.states !== tstate.states

    model = Chain(Dense(4 => 3), Dense(3 => 4))
    ps, st = Lux.setup(rng, model)

    tstate = Training.TrainState(model, ps, st, opt)

    _, _, _, tstate_new = Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate)

    @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa
        Any

    _, _, _, tstate_new2 = Training.compute_gradients(
        AutoEnzyme(), mse2, (x, x), tstate_new
    )
    @test hasfield(typeof(tstate_new2.cache.extras), :forward)
    @test hasfield(typeof(tstate_new2.cache.extras), :reverse)
end
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
using ADTypes, Optimisers, ComponentArrays, ForwardDiff

include("../../shared_testsetup.jl")

@testset "Training API ForwardDiff" begin
    mse = MSELoss()

    rng = StableRNG(12345)

    x_data = randn(rng, Float32, 4, 32)
    y_data = evalpoly.(x_data, ((1, 2, 3),)) .- evalpoly.(x_data, ((5, 2),))
    y_data = (y_data .- minimum(y_data)) ./ (maximum(y_data) - minimum(y_data))
    dataset = [(x_data[:, i], y_data[:, i]) for i in Iterators.partition(1:32, 8)]

    model = Chain(
        Dense(4, 32, tanh), BatchNorm(32), Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)
    )

    dataset_ = [(x, y) for (x, y) in dataset]
    opt = Adam(0.001f0)

    ps, st = Lux.setup(rng, model)
    tstate = Training.TrainState(model, ComponentVector(ps), st, opt)

    initial_loss = first(
        mse(model, tstate.parameters, Lux.testmode(tstate.states), dataset_[1])
    )

    for epoch in 1:100, (x, y) in dataset_
        grads, loss, _, tstate = allow_unstable() do
            Training.compute_gradients(AutoForwardDiff(), mse, (x, y), tstate)
        end
        tstate = Training.apply_gradients!(tstate, grads)
    end

    for epoch in 1:100, (x, y) in dataset_
        grads, loss, _, tstate = allow_unstable() do
            Training.single_train_step!(AutoForwardDiff(), mse, (x, y), tstate)
        end
    end

    for epoch in 1:100, (x, y) in dataset_
        grads, loss, _, tstate = allow_unstable() do
            Training.single_train_step(AutoForwardDiff(), mse, (x, y), tstate)
        end
    end

    final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1]))

    @test final_loss * 50 < initial_loss
end
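
This is the only file that wraps the parameters in ComponentVector(ps) before building the TrainState: ForwardDiff differentiates functions of a flat array, so the nested NamedTuple of parameters needs a flat view. A small illustration of that flattening, assuming only ComponentArrays.jl (the names nt and cv are just for this sketch):

# NamedTuple parameters flattened into one vector with named views.
using ComponentArrays

nt = (weight=ones(Float32, 2, 3), bias=zeros(Float32, 2))
cv = ComponentVector(nt)        # one flat Float32 vector backing both fields
@assert length(cv) == 8
@assert cv.weight == nt.weight  # named access into the flat storage still works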
