Skip to content

Commit 3db909e

Browse files
committed
test(LuxCore): use ParallelTestRunner
1 parent 7900af5 commit 3db909e

File tree

8 files changed

+444
-432
lines changed

8 files changed

+444
-432
lines changed

lib/LuxCore/test/Project.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,23 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
66
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
77
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
88
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
9+
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
910
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1011
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1112
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1213

14+
[sources]
15+
LuxCore = {path = ".."}
16+
MLDataDevices = {path = "../../MLDataDevices"}
17+
1318
[compat]
1419
Aqua = "0.8.7"
1520
EnzymeCore = "0.8.14"
1621
ExplicitImports = "1.9.0"
1722
Functors = "0.5"
1823
MLDataDevices = "1.17"
1924
Optimisers = "0.3.4, 0.4"
25+
ParallelTestRunner = "2.1"
2026
Random = "1.10"
2127
Setfield = "1.1"
2228
Test = "1.10"
23-
24-
[sources]
25-
LuxCore = {path = ".."}
26-
MLDataDevices = {path = "../../MLDataDevices"}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using LuxCore, Test, Random
2+
3+
rng = LuxCore.Internal.default_rng()
4+
5+
include("common.jl")
6+
7+
@testset "AbstractLuxContainerLayer Interface" begin
8+
model = Chain((; layer_1=Dense(5, 5), layer_2=Dense(5, 6)))
9+
x = randn(rng, Float32, 5)
10+
ps, st = LuxCore.setup(rng, model)
11+
12+
@test fieldnames(typeof(ps)) == (:layers,)
13+
@test fieldnames(typeof(st)) == (:layers,)
14+
15+
@test LuxCore.parameterlength(ps) ==
16+
LuxCore.parameterlength(model) ==
17+
LuxCore.parameterlength(model.layers[1]) + LuxCore.parameterlength(model.layers[2])
18+
@test LuxCore.statelength(st) ==
19+
LuxCore.statelength(model) ==
20+
LuxCore.statelength(model.layers[1]) + LuxCore.statelength(model.layers[2])
21+
22+
@test LuxCore.apply(model, x, ps, st) == model(x, ps, st)
23+
24+
@test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, st))
25+
26+
@test_nowarn println(model)
27+
28+
model = Chain2(Dense(5, 5), Dense(5, 6))
29+
x = randn(rng, Float32, 5)
30+
ps, st = LuxCore.setup(rng, model)
31+
32+
@test LuxCore.parameterlength(ps) ==
33+
LuxCore.parameterlength(model) ==
34+
LuxCore.parameterlength(model.layer1) + LuxCore.parameterlength(model.layer2)
35+
@test LuxCore.statelength(st) ==
36+
LuxCore.statelength(model) ==
37+
LuxCore.statelength(model.layer1) + LuxCore.statelength(model.layer2)
38+
39+
@test LuxCore.apply(model, x, ps, st) == model(x, ps, st)
40+
41+
@test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, st))
42+
43+
@test_nowarn println(model)
44+
end
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using LuxCore, Test, Random, Functors
2+
3+
rng = LuxCore.Internal.default_rng()
4+
5+
include("common.jl")
6+
7+
@testset "AbstractLuxLayer Interface" begin
8+
@testset "Custom Layer" begin
9+
model = Dense(5, 6)
10+
x = randn(rng, Float32, 5)
11+
ps, st = LuxCore.setup(rng, model)
12+
13+
@test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model)
14+
@test LuxCore.statelength(st) == LuxCore.statelength(model)
15+
16+
@test LuxCore.apply(model, x, ps, st) == model(x, ps, st)
17+
18+
@test LuxCore.stateless_apply(model, x, ps) ==
19+
first(LuxCore.apply(model, x, ps, NamedTuple()))
20+
21+
@test_nowarn println(model)
22+
23+
@testset for wrapper in (DenseWrapper, DenseWrapper2)
24+
model2 = wrapper(model)
25+
ps, st = LuxCore.setup(rng, model2)
26+
27+
@test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model2)
28+
@test LuxCore.statelength(st) == LuxCore.statelength(model2)
29+
30+
@test model2(x, ps, st)[1] == model(x, ps, st)[1]
31+
32+
@test_nowarn println(model2)
33+
end
34+
end
35+
36+
@testset "Default Fallbacks" begin
37+
struct NoParamStateLayer <: AbstractLuxLayer end
38+
39+
layer = NoParamStateLayer()
40+
@test LuxCore.initialparameters(rng, layer) == NamedTuple()
41+
@test LuxCore.initialstates(rng, layer) == NamedTuple()
42+
43+
@test LuxCore.parameterlength(zeros(10, 2)) == 20
44+
@test LuxCore.statelength(zeros(10, 2)) == 20
45+
@test LuxCore.statelength(Val(true)) == 1
46+
@test LuxCore.statelength((zeros(10), zeros(5, 2))) == 20
47+
@test LuxCore.statelength((layer_1=zeros(10), layer_2=zeros(5, 2))) == 20
48+
49+
@test LuxCore.initialparameters(rng, NamedTuple()) == NamedTuple()
50+
@test_throws MethodError LuxCore.initialparameters(rng, ())
51+
@test LuxCore.initialparameters(rng, nothing) == NamedTuple()
52+
@test LuxCore.initialparameters(rng, (nothing, layer)) ==
53+
(NamedTuple(), NamedTuple())
54+
55+
@test LuxCore.initialstates(rng, NamedTuple()) == NamedTuple()
56+
@test_throws MethodError LuxCore.initialstates(rng, ())
57+
@test LuxCore.initialstates(rng, nothing) == NamedTuple()
58+
@test LuxCore.initialparameters(rng, (nothing, layer)) ==
59+
(NamedTuple(), NamedTuple())
60+
end
61+
end
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using LuxCore, Test, Random
2+
3+
rng = LuxCore.Internal.default_rng()
4+
5+
include("common.jl")
6+
7+
@testset "AbstractLuxWrapperLayer Interface" begin
8+
model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5)))
9+
x = randn(rng, Float32, 5)
10+
ps, st = LuxCore.setup(rng, model)
11+
12+
@test fieldnames(typeof(ps)) == (:layer_1, :layer_2)
13+
@test fieldnames(typeof(st)) == (:layer_1, :layer_2)
14+
15+
@test LuxCore.parameterlength(ps) ==
16+
LuxCore.parameterlength(model) ==
17+
LuxCore.parameterlength(model.layers.layer_1) +
18+
LuxCore.parameterlength(model.layers.layer_2)
19+
@test LuxCore.statelength(st) ==
20+
LuxCore.statelength(model) ==
21+
LuxCore.statelength(model.layers.layer_1) +
22+
LuxCore.statelength(model.layers.layer_2)
23+
24+
@test LuxCore.apply(model, x, ps, st) == model(x, ps, st)
25+
26+
@test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, st))
27+
28+
@test_nowarn println(model)
29+
end

lib/LuxCore/test/common.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using LuxCore, Random
2+
3+
# Define some custom layers
4+
struct Dense <: AbstractLuxLayer
5+
in::Int
6+
out::Int
7+
end
8+
9+
function LuxCore.initialparameters(rng::AbstractRNG, l::Dense)
10+
return (w=randn(rng, l.out, l.in), b=randn(rng, l.out))
11+
end
12+
13+
(::Dense)(x, _, st) = x, st # Dummy Forward Pass
14+
15+
struct DenseWrapper{L} <: AbstractLuxWrapperLayer{:layer}
16+
layer::L
17+
end
18+
19+
# For checking ambiguities in the dispatch
20+
struct DenseWrapper2{L} <: AbstractLuxWrapperLayer{:layer}
21+
layer::L
22+
end
23+
24+
(d::DenseWrapper2)(x::AbstractArray, ps, st) = d.layer(x, ps, st)
25+
26+
struct Chain{L} <: AbstractLuxContainerLayer{(:layers,)}
27+
layers::L
28+
end
29+
30+
function (c::Chain)(x, ps, st)
31+
y, st1 = c.layers[1](x, ps.layers.layer_1, st.layers.layer_1)
32+
y, st2 = c.layers[2](y, ps.layers.layer_2, st.layers.layer_2)
33+
return y, (; layers=(; layer_1=st1, layer_2=st2))
34+
end
35+
36+
struct ChainWrapper{L} <: AbstractLuxWrapperLayer{:layers}
37+
layers::L
38+
end
39+
40+
function (c::ChainWrapper)(x, ps, st)
41+
y, st1 = c.layers[1](x, ps.layer_1, st.layer_1)
42+
y, st2 = c.layers[2](y, ps.layer_2, st.layer_2)
43+
return y, (; layer_1=st1, layer_2=st2)
44+
end
45+
46+
struct Chain2{L1,L2} <: AbstractLuxContainerLayer{(:layer1, :layer2)}
47+
layer1::L1
48+
layer2::L2
49+
end
50+
51+
function (c::Chain2)(x, ps, st)
52+
y, st1 = c.layer1(x, ps.layer1, st.layer1)
53+
y, st2 = c.layer2(y, ps.layer2, st.layer2)
54+
return y, (; layer1=st1, layer2=st2)
55+
end

0 commit comments

Comments
 (0)