Skip to content

Commit c0075a6

Browse files
committed
test: batchnorm layers
1 parent 645ab1a commit c0075a6

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

test/reactant/layer_tests.jl

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,69 @@ end
100100
end
101101
end
102102

103-
@testitem "BatchNorm Layer" tags=[:reactant] setup=[SharedTestSetup] skip=:(Sys.iswindows()) begin
103+
@testitem "BatchNorm Layer" tags=[:reactant] setup=[
104+
SharedTestSetup, SharedReactantLayersTestSetup] skip=:(Sys.iswindows()) begin
104105
using Reactant, Lux, Random
105106

107+
@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
108+
if mode == "amdgpu"
109+
@warn "Skipping AMDGPU tests for Reactant"
110+
continue
111+
end
112+
113+
dev = reactant_device(; force=true)
114+
115+
if ongpu
116+
Reactant.set_default_backend("gpu")
117+
else
118+
Reactant.set_default_backend("cpu")
119+
end
120+
121+
@testset for track_stats in (true, false), affine in (true, false),
122+
act in (identity, tanh)
123+
124+
model = Chain(
125+
Dense(2 => 3, tanh),
126+
BatchNorm(3, act; track_stats, affine, init_bias=rand32, init_scale=rand32),
127+
Dense(3 => 2)
128+
)
129+
130+
x = rand(Float32, 2, 4)
131+
ps, st = Lux.setup(Random.default_rng(), model)
132+
133+
x_ra = x |> dev
134+
ps_ra = ps |> dev
135+
st_ra = st |> dev
136+
137+
y, st2 = model(x, ps, st)
138+
y_ra, st2_ra = @jit model(x_ra, ps_ra, st_ra)
139+
140+
@test yy_ra rtol=1e-3 atol=1e-3
141+
if track_stats
142+
@test st2.layer_2.running_meanst2_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
143+
@test st2.layer_2.running_varst2_ra.layer_2.running_var rtol=1e-3 atol=1e-3
144+
end
145+
146+
# TODO: Check for stablehlo.batch_norm_training once we emit it in LuxLib
147+
148+
@testset "gradient" begin
149+
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
150+
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
151+
@test ∂x_ra∂x atol=1e-2 rtol=1e-2
152+
@test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2)
153+
end
154+
155+
y2, st3 = model(x, ps, Lux.testmode(st2))
156+
y2_ra, st3_ra = @jit model(x_ra, ps_ra, Lux.testmode(st2_ra))
157+
158+
@test y2y2_ra rtol=1e-3 atol=1e-3
159+
if track_stats
160+
@test st3.layer_2.running_meanst3_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
161+
@test st3.layer_2.running_varst3_ra.layer_2.running_var rtol=1e-3 atol=1e-3
162+
end
163+
164+
hlo = @code_hlo model(x_ra, ps_ra, Lux.testmode(st_ra))
165+
@test contains(repr(hlo), "stablehlo.batch_norm_inference")
166+
end
167+
end
106168
end

0 commit comments

Comments
 (0)