
Commit 04ef36a

Author: Avik Pal
Message: test: more recurrent testing fixes
Parent: f135943

1 file changed: test/layers/recurrent_tests.jl (+147, −100)
@@ -2,8 +2,8 @@

 using MLDataDevices

-MLDataDevices.get_device_type(::Function) = Nothing # FIXME: upstream maybe?
-MLDataDevices.get_device_type(_) = Nothing # FIXME: upstream maybe?
+MLDataDevices.Internal.get_device_type(::Function) = Nothing # FIXME: upstream maybe?
+MLDataDevices.Internal.get_device_type(_) = Nothing # FIXME: upstream maybe?

 function loss_loop(cell, x, p, st)
     (y, carry), st_ = cell(x, p, st)
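
The two redefinitions above are deliberate, test-local type piracy: they give non-array arguments (activation functions, cells, closures) a "no device" answer when the test helpers query device placement. The commit moves them onto MLDataDevices.Internal.get_device_type, leaving the public MLDataDevices.get_device_type untouched. A minimal sketch of the behavior the fallback provides, assuming the overrides above are loaded (not part of the commit itself):

    using MLDataDevices

    # With the ::Function fallback in place, querying the device type of a
    # plain function returns Nothing ("no device") rather than requiring a
    # dedicated method per argument type:
    MLDataDevices.Internal.get_device_type(identity)  # Nothing
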
@@ -43,9 +43,9 @@ end
     @jet rnncell((x, carry), ps, st)

     if train_state
-        @test hasproperty(ps, :train_state)
+        @test hasproperty(ps, :hidden_state)
     else
-        @test !hasproperty(ps, :train_state)
+        @test !hasproperty(ps, :hidden_state)
     end

     @test_gradients(loss_loop, rnncell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
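
The renamed assertions reflect the parameter naming: the keyword train_state=true asks the cell to learn its initial state, and the resulting trainable parameter is stored as hidden_state (there is no parameter literally named train_state). A small sketch of what the test asserts, using the same constructor as above:

    using Lux, StableRNGs

    rnncell = RNNCell(3 => 5; train_state=true)
    ps, st = Lux.setup(StableRNG(12345), rnncell)
    hasproperty(ps, :hidden_state)  # true: the trainable initial hidden state
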
@@ -95,8 +95,8 @@ end
     @jet lstmcell(x, ps, st)
     @jet lstmcell((x, carry), ps, st)

-    @test !hasproperty(ps, :train_state)
-    @test !hasproperty(ps, :train_memory)
+    @test !hasproperty(ps, :hidden_state)
+    @test !hasproperty(ps, :memory)

     @test_gradients(loss_loop, lstmcell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
 end
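
The same naming scheme drives the LSTM checks: train_state and train_memory are constructor keywords, while the parameters they create when enabled are hidden_state and memory (the negated assertions above cover the disabled case). A sketch of the enabled case, with the parameter names as asserted above:

    lstmcell = LSTMCell(3 => 5; train_state=true, train_memory=true)
    ps, _ = Lux.setup(StableRNG(12345), lstmcell)
    hasproperty(ps, :hidden_state), hasproperty(ps, :memory)  # (true, true)
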
@@ -198,7 +198,7 @@ end
    @jet grucell(x, ps, st)
    @jet grucell((x, carry), ps, st)

-    @test !hasproperty(ps, :train_state)
+    @test !hasproperty(ps, :hidden_state)

    @test_gradients(loss_loop, grucell, x, ps, st; atol=1e-3, rtol=1e-3)
 end
@@ -276,94 +276,138 @@ end
                 st__ = Lux.update_state(st, :carry, nothing)
                 @test st__.carry === nothing

-                @test_gradients(loss_loop_no_carry, rnn, x, ps, st; atol=1e-3, rtol=1e-3)
+                @test_gradients(loss_loop_no_carry, rnn, x, ps, st; atol=1e-3, rtol=1e-3,
+                    soft_fail=[AutoFiniteDiff()])
             end
         end
     end
 end

-@testitem "Recurrence" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
+@testsetup module RecurrenceTestSetup
+
+using LuxTestUtils, StableRNGs, Test, Lux
+
+function test_recurrence_layer(
+        mode, aType, dev, ongpu, ordering, _cell, use_bias, train_state)
     rng = StableRNG(12345)

-    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
-        @testset for ordering in (BatchLastIndex(), TimeLastIndex())
-            @testset for _cell in (RNNCell, LSTMCell, GRUCell)
-                @testset for use_bias in (true, false), train_state in (true, false)
-                    cell = _cell(3 => 5; use_bias, train_state)
-                    rnn = Recurrence(cell; ordering)
-                    rnn_seq = Recurrence(cell; ordering, return_sequence=true)
-                    display(rnn)
-
-                    # Batched Time Series
-                    @testset "typeof(x): $(typeof(x))" for x in (
-                        randn(rng, Float32, 3, 4, 2) |> aType,
-                        Tuple(randn(rng, Float32, 3, 2) for _ in 1:4) .|> aType,
-                        [randn(rng, Float32, 3, 2) for _ in 1:4] .|> aType)
-                        # Fix data ordering for testing
-                        if ordering isa TimeLastIndex && x isa AbstractArray && ndims(x) ≥ 2
-                            x = permutedims(x,
-                                (ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1))
-                        end
-
-                        ps, st = Lux.setup(rng, rnn) |> dev
-                        y, st_ = rnn(x, ps, st)
-                        y_, st__ = rnn_seq(x, ps, st)
-
-                        @jet rnn(x, ps, st)
-                        @jet rnn_seq(x, ps, st)
-
-                        @test size(y) == (5, 2)
-                        @test length(y_) == 4
-                        @test all(x -> size(x) == (5, 2), y_)
-
-                        __f = p -> sum(first(rnn(x, p, st)))
-                        @test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
-                            skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])
-
-                        __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st)))
-                        @test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
-                            skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])
-                    end
-
-                    # Batched Time Series without data batches
-                    @testset "typeof(x): $(typeof(x))" for x in (
-                        randn(rng, Float32, 3, 4) |> aType,
-                        Tuple(randn(rng, Float32, 3) for _ in 1:4) .|> aType,
-                        [randn(rng, Float32, 3) for _ in 1:4] .|> aType)
-                        ps, st = Lux.setup(rng, rnn) |> dev
-                        y, st_ = rnn(x, ps, st)
-                        y_, st__ = rnn_seq(x, ps, st)
-
-                        @jet rnn(x, ps, st)
-                        @jet rnn_seq(x, ps, st)
-
-                        @test size(y) == (5,)
-                        @test length(y_) == 4
-                        @test all(x -> size(x) == (5,), y_)
-
-                        if x isa AbstractMatrix && ordering isa BatchLastIndex
-                            x2 = reshape(x, Val(3))
-
-                            y2, _ = rnn(x2, ps, st)
-                            @test y == vec(y2)
-
-                            y2_, _ = rnn_seq(x2, ps, st)
-                            @test all(x -> x[1] == vec(x[2]), zip(y_, y2_))
-                        end
-
-                        __f = p -> sum(first(rnn(x, p, st)))
-                        @test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
-                            skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])
-
-                        __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st)))
-                        @test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
-                            skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])
-                    end
-                end
-            end
+    cell = _cell(3 => 5; use_bias, train_state)
+    rnn = Recurrence(cell; ordering)
+    display(rnn)
+    rnn_seq = Recurrence(cell; ordering, return_sequence=true)
+    display(rnn_seq)
+
+    # Batched Time Series
+    @testset "typeof(x): $(typeof(x))" for x in (
+        randn(rng, Float32, 3, 4, 2) |> aType,
+        Tuple(randn(rng, Float32, 3, 2) for _ in 1:4) .|> aType,
+        [randn(rng, Float32, 3, 2) for _ in 1:4] .|> aType)
+        # Fix data ordering for testing
+        if ordering isa TimeLastIndex && x isa AbstractArray && ndims(x) ≥ 2
+            x = permutedims(x,
+                (ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1))
+        end
+
+        ps, st = Lux.setup(rng, rnn) |> dev
+        y, st_ = rnn(x, ps, st)
+        y_, st__ = rnn_seq(x, ps, st)
+
+        @test size(y) == (5, 2)
+        @test length(y_) == 4
+        @test all(x -> size(x) == (5, 2), y_)
+
+        __f = ps -> sum(abs2, first(rnn(x, ps, st)))
+        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
+
+        __f = ps -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, ps, st)))
+        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
+    end
+
+    # Batched Time Series without data batches
+    @testset "typeof(x): $(typeof(x))" for x in (
+        randn(rng, Float32, 3, 4) |> aType,
+        Tuple(randn(rng, Float32, 3) for _ in 1:4) .|> aType,
+        [randn(rng, Float32, 3) for _ in 1:4] .|> aType)
+        ps, st = Lux.setup(rng, rnn) |> dev
+        y, st_ = rnn(x, ps, st)
+        y_, st__ = rnn_seq(x, ps, st)
+
+        @test size(y) == (5,)
+        @test length(y_) == 4
+        @test all(x -> size(x) == (5,), y_)
+
+        if x isa AbstractMatrix && ordering isa BatchLastIndex
+            x2 = reshape(x, Val(3))
+            y2, _ = rnn(x2, ps, st)
+            @test y == vec(y2)
+            y2_, _ = rnn_seq(x2, ps, st)
+            @test all(x -> x[1] == vec(x[2]), zip(y_, y2_))
+        end
+
+        __f = ps -> sum(abs2, first(rnn(x, ps, st)))
+        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
+
+        __f = ps -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, ps, st)))
+        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
+    end
+end
+
+const ALL_TEST_CONFIGS = Iterators.product(
+    (BatchLastIndex(), TimeLastIndex()),
+    (RNNCell, LSTMCell, GRUCell),
+    (true, false),
+    (true, false))
+
+const TEST_BLOCKS = collect(Iterators.partition(
+    ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 4)))
+
+export TEST_BLOCKS, test_recurrence_layer
+
+end
+
+@testitem "Recurrence: Group 1" setup=[
+    RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
+    @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
+        @testset for (ordering, cell, use_bias, train_state) in TEST_BLOCKS[1]
+            test_recurrence_layer(
+                mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
+        end
+    end
+end
+
+@testitem "Recurrence: Group 2" setup=[
+    RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
+    @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
+        @testset for (ordering, cell, use_bias, train_state) in TEST_BLOCKS[2]
+            test_recurrence_layer(
+                mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
+        end
+    end
+end
+
+@testitem "Recurrence: Group 3" setup=[
+    RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
+    @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
+        @testset for (ordering, cell, use_bias, train_state) in TEST_BLOCKS[3]
+            test_recurrence_layer(
+                mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
+        end
+    end
+end
+
+@testitem "Recurrence: Group 4" setup=[
+    RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
+    @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
+        @testset for (ordering, cell, use_bias, train_state) in TEST_BLOCKS[4]
+            test_recurrence_layer(
+                mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
         end
+    end
+end

-# Ordering Check: https://github.com/LuxDL/Lux.jl/issues/302
+@testitem "Recurrence Ordering Check #302" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
+    rng = StableRNG(12345)
+    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         encoder = Recurrence(
             RNNCell(1 => 1, identity;
                 init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...),
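
The restructuring above replaces one monolithic "Recurrence" @testitem with a @testsetup module plus four independent @testitems, so the test runner can schedule the groups in parallel and a failure points at a smaller block. The splitting arithmetic, as a standalone sketch in plain Base Julia (stand-in symbols for the actual types):

    configs = Iterators.product(
        (:BatchLastIndex, :TimeLastIndex),  # ordering
        (:RNNCell, :LSTMCell, :GRUCell),    # cell type
        (true, false),                      # use_bias
        (true, false))                      # train_state
    blocks = collect(Iterators.partition(configs, ceil(Int, length(configs) / 4)))
    length(blocks), length(first(blocks))   # (4, 6): four groups of six configs
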
@@ -378,7 +422,7 @@ end
     end
 end

-@testitem "Bidirectional" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
+@testitem "Bidirectional" setup=[SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
     rng = StableRNG(12345)

     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
@@ -405,17 +449,18 @@ end
         @test size(y_[1]) == (4,)
         @test all(x -> size(x) == (5, 2), y_[1])

-        __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
-        @test_gradients(__f, ps; atol=1e-3, rtol=1e-3, broken_backends=[AutoEnzyme()])
+        __f = (bi_rnn, x, ps, st) -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, ps, st)))
+        @test_gradients(__f, bi_rnn, x, ps, st; atol=1e-3, rtol=1e-3,
+            broken_backends=[AutoEnzyme()])

-        __f = p -> begin
-            (y1, y2), st_ = bi_rnn_no_merge(x, p, st)
+        __f = (bi_rnn_no_merge, x, ps, st) -> begin
+            (y1, y2), st_ = bi_rnn_no_merge(x, ps, st)
             return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
         end
-        @test_gradients(__f, ps; atol=1e-3, rtol=1e-3, broken_backends=[AutoEnzyme()])
+        @test_gradients(__f, bi_rnn_no_merge, x, ps, st; atol=1e-3,
+            rtol=1e-3, broken_backends=[AutoEnzyme()])

-        @testset "backward_cell: $_backward_cell" for _backward_cell in (
-            RNNCell, LSTMCell, GRUCell)
+        @testset for _backward_cell in (RNNCell, LSTMCell, GRUCell)
             cell = _cell(3 => 5)
             backward_cell = _backward_cell(3 => 5)
             bi_rnn = BidirectionalRNN(cell, backward_cell)
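
The gradient checks switch from a closure over the model to a loss that takes everything as an explicit argument, matching the loss_loop style at the top of the file: @test_gradients then receives the model, input, parameters, and state directly rather than only ps. Both styles side by side (a sketch; model, x, ps, st assumed in scope):

    # closure style: only ps is visible to the gradient checker
    __f = p -> sum(Base.Fix1(sum, abs2), first(model(x, p, st)))

    # explicit-argument style: non-array arguments are passed through,
    # so the checker can see them and treat them as non-differentiable
    __f = (model, x, ps, st) -> sum(Base.Fix1(sum, abs2), first(model(x, ps, st)))
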
@@ -439,16 +484,18 @@ end
             @test size(y_[1]) == (4,)
             @test all(x -> size(x) == (5, 2), y_[1])

-            __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
-            @test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
+            __f = (bi_rnn, x, ps, st) -> sum(
+                Base.Fix1(sum, abs2), first(bi_rnn(x, ps, st)))
+            @test_gradients(__f, bi_rnn, x, ps, st; atol=1e-3,
+                rtol=1e-3,
                 broken_backends=[AutoEnzyme()])

-            __f = p -> begin
-                (y1, y2), st_ = bi_rnn_no_merge(x, p, st)
+            __f = (bi_rnn_no_merge, x, ps, st) -> begin
+                (y1, y2), st_ = bi_rnn_no_merge(x, ps, st)
                 return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
             end
-            @test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
-                broken_backends=[AutoEnzyme()])
+            @test_gradients(__f, bi_rnn_no_merge, x, ps, st; atol=1e-3,
+                rtol=1e-3, broken_backends=[AutoEnzyme()])
         end
     end
 end
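
A note on the recurring loss expression: Base.Fix1(sum, abs2) is the partially applied function a -> sum(abs2, a), so summing it over a sequence of outputs totals the squared entries across every timestep. A self-contained example in plain Julia:

    ys = [randn(Float32, 5, 2) for _ in 1:4]  # e.g. a return_sequence output
    total = sum(Base.Fix1(sum, abs2), ys)     # == sum(sum(abs2, y) for y in ys)
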
