22
33using MLDataDevices
44
# FIXME: upstream maybe?
# Fallback device-type queries used by the tests below: a `Function` argument reports
# `Nothing` as its device type.
function MLDataDevices.Internal.get_device_type(::Function)
    return Nothing
end

# FIXME: upstream maybe?
# Catch-all fallback: any argument not matched by a more specific method reports `Nothing`.
function MLDataDevices.Internal.get_device_type(_)
    return Nothing
end
77
88function loss_loop(cell, x, p, st)
99 (y, carry), st_ = cell(x, p, st)
4343 @jet rnncell((x, carry), ps, st)
4444
4545 if train_state
46- @test hasproperty(ps, :train_state )
46+ @test hasproperty(ps, :hidden_state )
4747 else
48- @test ! hasproperty(ps, :train_state )
48+ @test ! hasproperty(ps, :hidden_state )
4949 end
5050
5151 @test_gradients(loss_loop, rnncell, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
9595 @jet lstmcell(x, ps, st)
9696 @jet lstmcell((x, carry), ps, st)
9797
98- @test ! hasproperty(ps, :train_state )
99- @test ! hasproperty(ps, :train_memory )
98+ @test ! hasproperty(ps, :hidden_state )
99+ @test ! hasproperty(ps, :memory )
100100
101101 @test_gradients(loss_loop, lstmcell, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
102102 end
198198 @jet grucell(x, ps, st)
199199 @jet grucell((x, carry), ps, st)
200200
201- @test ! hasproperty(ps, :train_state )
201+ @test ! hasproperty(ps, :hidden_state )
202202
203203 @test_gradients(loss_loop, grucell, x, ps, st; atol= 1e-3 , rtol= 1e-3 )
204204 end
@@ -276,94 +276,138 @@ end
276276 st__ = Lux. update_state(st, :carry, nothing )
277277 @test st__. carry === nothing
278278
279- @test_gradients(loss_loop_no_carry, rnn, x, ps, st; atol= 1e-3 , rtol= 1e-3 )
279+ @test_gradients(loss_loop_no_carry, rnn, x, ps, st; atol= 1e-3 , rtol= 1e-3 ,
280+ soft_fail= [AutoFiniteDiff()])
280281 end
281282 end
282283 end
283284end
284285
285- @testitem " Recurrence" setup= [SharedTestSetup] tags= [:recurrent_layers] begin
@testsetup module RecurrenceTestSetup

using LuxTestUtils, StableRNGs, Test, Lux

"""
    test_recurrence_layer(mode, aType, dev, ongpu, ordering, _cell, use_bias, train_state)

Run the `Recurrence` checks for one configuration.

A cell is built as `_cell(3 => 5; use_bias, train_state)` and wrapped twice: once in a
`Recurrence` that returns only the final output (`rnn`) and once with
`return_sequence=true` (`rnn_seq`). Both wrappers are exercised on batched inputs
(output size `(5, 2)`) and on inputs without a batch dimension (output size `(5,)`),
each with 4 time steps, and gradients with respect to the parameters are checked for
both wrappers.

# Arguments
- `mode`, `ongpu`: accepted so callers can forward a whole `MODES` entry; not used by
  the body.
- `aType`: applied to the generated inputs (`x |> aType`).
- `dev`: applied to the result of `Lux.setup` (`Lux.setup(rng, rnn) |> dev`).
- `ordering`: forwarded to `Recurrence`; `TimeLastIndex` array inputs are permuted so
  that the last two dimensions are swapped.
- `_cell`: cell constructor, called as `_cell(3 => 5; use_bias, train_state)`.
- `use_bias`, `train_state`: forwarded to the cell constructor.
"""
function test_recurrence_layer(
        mode, aType, dev, ongpu, ordering, _cell, use_bias, train_state)
    rng = StableRNG(12345)

    cell = _cell(3 => 5; use_bias, train_state)
    rnn = Recurrence(cell; ordering)
    display(rnn)
    rnn_seq = Recurrence(cell; ordering, return_sequence=true)
    display(rnn_seq)

    # Batched Time Series
    @testset " typeof(x): $(typeof(x)) " for x in (
        randn(rng, Float32, 3, 4, 2) |> aType,
        Tuple(randn(rng, Float32, 3, 2) for _ in 1:4) .|> aType,
        [randn(rng, Float32, 3, 2) for _ in 1:4] .|> aType)
        # Fix data ordering for testing
        if ordering isa TimeLastIndex && x isa AbstractArray && ndims(x) ≥ 2
            x = permutedims(x,
                (ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1))
        end

        ps, st = Lux.setup(rng, rnn) |> dev
        y, _ = rnn(x, ps, st)
        y_, _ = rnn_seq(x, ps, st)

        @test size(y) == (5, 2)
        @test length(y_) == 4
        @test all(x -> size(x) == (5, 2), y_)

        __f = ps -> sum(abs2, first(rnn(x, ps, st)))
        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])

        __f = ps -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, ps, st)))
        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
    end

    # Batched Time Series without data batches
    @testset " typeof(x): $(typeof(x)) " for x in (
        randn(rng, Float32, 3, 4) |> aType,
        Tuple(randn(rng, Float32, 3) for _ in 1:4) .|> aType,
        [randn(rng, Float32, 3) for _ in 1:4] .|> aType)
        ps, st = Lux.setup(rng, rnn) |> dev
        y, _ = rnn(x, ps, st)
        y_, _ = rnn_seq(x, ps, st)

        @test size(y) == (5,)
        @test length(y_) == 4
        @test all(x -> size(x) == (5,), y_)

        # A matrix input must give the same results as the same data reshaped to three
        # dimensions.
        if x isa AbstractMatrix && ordering isa BatchLastIndex
            x2 = reshape(x, Val(3))
            y2, _ = rnn(x2, ps, st)
            @test y == vec(y2)
            y2_, _ = rnn_seq(x2, ps, st)
            @test all(x -> x[1] == vec(x[2]), zip(y_, y2_))
        end

        __f = ps -> sum(abs2, first(rnn(x, ps, st)))
        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])

        # Differentiate through `rnn_seq` here: the check above already covers `rnn`, and
        # the sequence-returning wrapper needs its own gradient check on unbatched input.
        __f = ps -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, ps, st)))
        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
    end
end

# Every (ordering, cell type, use_bias, train_state) combination: 2 × 3 × 2 × 2 = 24.
const ALL_TEST_CONFIGS = Iterators.product(
    (BatchLastIndex(), TimeLastIndex()),
    (RNNCell, LSTMCell, GRUCell),
    (true, false),
    (true, false))

# The configurations split into chunks of `ceil(24 / 4) = 6`, i.e. four blocks, so that
# each block can be run by its own test item.
const TEST_BLOCKS = collect(Iterators.partition(
    ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 4)))

export TEST_BLOCKS, test_recurrence_layer

end
367+
# Runs `test_recurrence_layer` for every configuration in the first chunk of
# `TEST_BLOCKS`, once per entry of `MODES`.
@testitem " Recurrence: Group 1" setup=[
    RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
    configs = TEST_BLOCKS[1]

    @testset " $(mode) " for (mode, aType, dev, ongpu) in MODES
        @testset for (ordering, cell, use_bias, train_state) in configs
            test_recurrence_layer(
                mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
        end
    end
end
377+
# Runs `test_recurrence_layer` for every configuration in the second chunk of
# `TEST_BLOCKS`, once per entry of `MODES`.
@testitem " Recurrence: Group 2" setup=[
    RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
    configs = TEST_BLOCKS[2]

    @testset " $(mode) " for (mode, aType, dev, ongpu) in MODES
        @testset for (ordering, cell, use_bias, train_state) in configs
            test_recurrence_layer(
                mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
        end
    end
end
387+
# Runs `test_recurrence_layer` for every configuration in the third chunk of
# `TEST_BLOCKS`, once per entry of `MODES`.
@testitem " Recurrence: Group 3" setup=[
    RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
    configs = TEST_BLOCKS[3]

    @testset " $(mode) " for (mode, aType, dev, ongpu) in MODES
        @testset for (ordering, cell, use_bias, train_state) in configs
            test_recurrence_layer(
                mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
        end
    end
end
397+
# Runs `test_recurrence_layer` for every configuration in the fourth chunk of
# `TEST_BLOCKS`, once per entry of `MODES`.
@testitem " Recurrence: Group 4" setup=[
    RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
    configs = TEST_BLOCKS[4]

    @testset " $(mode) " for (mode, aType, dev, ongpu) in MODES
        @testset for (ordering, cell, use_bias, train_state) in configs
            test_recurrence_layer(
                mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
        end
    end
end
365407
366- # Ordering Check: https://github.com/LuxDL/Lux.jl/issues/302
408+ @testitem " Recurrence Ordering Check #302" setup= [SharedTestSetup] tags= [:recurrent_layers] begin
409+ rng = StableRNG(12345 )
410+ @testset " $mode " for (mode, aType, dev, ongpu) in MODES
367411 encoder = Recurrence(
368412 RNNCell(1 => 1 , identity;
369413 init_weight= (rng, args... ; kwargs... ) -> ones(args... ; kwargs... ),
378422 end
379423end
380424
381- @testitem " Bidirectional" setup= [SharedTestSetup] tags= [:recurrent_layers] begin
425+ @testitem " Bidirectional" setup= [SharedTestSetup, RecurrentLayersSetup ] tags= [:recurrent_layers] begin
382426 rng = StableRNG(12345 )
383427
384428 @testset " $mode " for (mode, aType, dev, ongpu) in MODES
@@ -405,17 +449,18 @@ end
405449 @test size(y_[1 ]) == (4 ,)
406450 @test all(x -> size(x) == (5 , 2 ), y_[1 ])
407451
408- __f = p -> sum(Base. Fix1(sum, abs2), first(bi_rnn(x, p, st)))
409- @test_gradients(__f, ps; atol= 1e-3 , rtol= 1e-3 , broken_backends= [AutoEnzyme()])
452+ __f = (bi_rnn, x, ps, st) -> sum(Base. Fix1(sum, abs2), first(bi_rnn(x, ps, st)))
453+ @test_gradients(__f, bi_rnn, x, ps, st; atol= 1e-3 , rtol= 1e-3 ,
454+ broken_backends= [AutoEnzyme()])
410455
411- __f = p -> begin
412- (y1, y2), st_ = bi_rnn_no_merge(x, p , st)
456+ __f = (bi_rnn_no_merge, x, ps, st) -> begin
457+ (y1, y2), st_ = bi_rnn_no_merge(x, ps , st)
413458 return sum(Base. Fix1(sum, abs2), y1) + sum(Base. Fix1(sum, abs2), y2)
414459 end
415- @test_gradients(__f, ps; atol= 1e-3 , rtol= 1e-3 , broken_backends= [AutoEnzyme()])
460+ @test_gradients(__f, bi_rnn_no_merge, x, ps, st; atol= 1e-3 ,
461+ rtol= 1e-3 , broken_backends= [AutoEnzyme()])
416462
417- @testset " backward_cell: $_backward_cell " for _backward_cell in (
418- RNNCell, LSTMCell, GRUCell)
463+ @testset for _backward_cell in (RNNCell, LSTMCell, GRUCell)
419464 cell = _cell(3 => 5 )
420465 backward_cell = _backward_cell(3 => 5 )
421466 bi_rnn = BidirectionalRNN(cell, backward_cell)
@@ -439,16 +484,18 @@ end
439484 @test size(y_[1 ]) == (4 ,)
440485 @test all(x -> size(x) == (5 , 2 ), y_[1 ])
441486
442- __f = p -> sum(Base. Fix1(sum, abs2), first(bi_rnn(x, p, st)))
443- @test_gradients(__f, ps; atol= 1e-3 , rtol= 1e-3 ,
487+ __f = (bi_rnn, x, ps, st) -> sum(
488+ Base. Fix1(sum, abs2), first(bi_rnn(x, ps, st)))
489+ @test_gradients(__f, bi_rnn, x, ps, st; atol= 1e-3 ,
490+ rtol= 1e-3 ,
444491 broken_backends= [AutoEnzyme()])
445492
446- __f = p -> begin
447- (y1, y2), st_ = bi_rnn_no_merge(x, p , st)
493+ __f = (bi_rnn_no_merge, x, ps, st) -> begin
494+ (y1, y2), st_ = bi_rnn_no_merge(x, ps , st)
448495 return sum(Base. Fix1(sum, abs2), y1) + sum(Base. Fix1(sum, abs2), y2)
449496 end
450- @test_gradients(__f, ps ; atol= 1e-3 , rtol = 1e-3 ,
451- broken_backends= [AutoEnzyme()])
497+ @test_gradients(__f, bi_rnn_no_merge, x, ps, st ; atol= 1e-3 ,
498+ rtol = 1e-3 , broken_backends= [AutoEnzyme()])
452499 end
453500 end
454501 end
0 commit comments