|
358 | 358 | end |
359 | 359 | end # cumprod |
360 | 360 |
|
361 | | - @testset "accumulate(f, ::Array)" begin |
| 361 | + @testset "accumulate(f, ::Vector)" begin |
362 | 362 | # `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`. |
363 | 363 | # The rule is now attached there, as this is the simplest way to handle `init` keyword. |
364 | | - @eval using Base: _accumulate! |
365 | 364 |
|
366 | 365 | # Simple |
367 | 366 | y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1)) |
|
371 | 370 | @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}} |
372 | 371 | @test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented |
373 | 372 |
|
374 | | - y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) |
375 | | - @test y2 ≈ accumulate(/, [1 2; 3 4]) |
376 | | - @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 |
| 373 | + # y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing) |
| 374 | + # @test y2 ≈ accumulate(/, [1 2; 3 4.0]) |
| 375 | + # @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 |
377 | 376 |
|
378 | 377 | # Test execution order |
379 | 378 | c3 = Counter() |
@@ -403,35 +402,11 @@ end |
403 | 402 | # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string |
404 | 403 |
|
405 | 404 | # Finite differencing |
406 | | - test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) |
407 | | - test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
408 | | - test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
| 405 | + # test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) |
| 406 | + test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(rand())) |
| 407 | + # test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
| 408 | + test_rrule(_accumulate!, /, randn(4) ⊢ NoTangent(), 1 .+ rand(4), nothing, nothing) |
| 409 | + # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
| 410 | + test_rrule(_accumulate!, ^, randn(6) ⊢ NoTangent(), 1 .+ rand(6), nothing, Some(rand())) |
409 | 411 | end |
410 | | - @testset "accumulate(f, ::Tuple)" begin |
411 | | - # Simple |
412 | | - y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) |
413 | | - @test y1 == (1, 2, 6, 24) |
414 | | - @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) |
415 | | - |
416 | | - # Finite differencing |
417 | | - test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) |
418 | | - test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) |
419 | | - |
420 | | - test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing) |
421 | | - test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand())) |
422 | | - # if VERSION >= v"1.5" |
423 | | - # test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
424 | | - # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
425 | | - # end |
426 | | - end |
427 | | - # VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin |
428 | | - # # Simple |
429 | | - # y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) |
430 | | - # @test y1 == (1, 2, 6, 24) |
431 | | - # @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) |
432 | | - |
433 | | - # # Finite differencing |
434 | | - # test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) |
435 | | - # test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) |
436 | | - # end |
437 | 412 | end |
0 commit comments