@@ -500,14 +500,45 @@ end
500500 @test 150_000_000 > @allocated gradient (loss, ones (1000 ,1000 ))
501501end
502502
503- @testset " tuples & broadcasting" begin
504- @test gradient (x -> sum (x .+ ones (2 ,2 )), (1 ,2 )) == ((2 ,2 ),)
505- @test gradient (x -> sum (x .+ ones (2 ,2 )), (1 ,)) == ((4 ,),)
506- @test gradient (x -> sum (x .+ ones (2 ,1 )), (1 ,2 )) == ((1 ,1 ),)
507-
508- # https://github.com/FluxML/Zygote.jl/issues/975
509- gt = gradient ((x,p) -> prod (x .^ p), [3 ,4 ], (1 ,2 ))
510- gv = gradient ((x,p) -> prod (x .^ p), [3 ,4 ], [1 ,2 ])
511- @test gt[1 ] == gv[1 ]
512- @test collect (gt[2 ]) ≈ gv[2 ]
503+ @testset " tricky broadcasting" begin
504+ @test gradient (x -> sum (x .+ ones (2 ,2 )), (1 ,2 )) == ((2 ,2 ),)
505+ @test gradient (x -> sum (x .+ ones (2 ,2 )), (1 ,)) == ((4 ,),)
506+ @test gradient (x -> sum (x .+ ones (2 ,1 )), (1 ,2 )) == ((1 ,1 ),)
507+
508+ # https://github.com/FluxML/Zygote.jl/issues/975
509+ gt = gradient ((x,p) -> prod (x .^ p), [3 ,4 ], (1 ,2 ))
510+ gv = gradient ((x,p) -> prod (x .^ p), [3 ,4 ], [1 ,2 ])
511+ @test gt[1 ] == gv[1 ]
512+ @test collect (gt[2 ]) ≈ gv[2 ]
513+
514+ # closure captures y -- can't use ForwardDiff
515+ @test gradient ((x,y) -> sum ((z-> z^ 2 + y[1 ]). (x)), [1 ,2 ,3 ], [4 ,5 ]) == ([2 , 4 , 6 ], [3 , 0 ])
516+ @test gradient ((x,y) -> sum ((z-> z^ 2 + y[1 ]), x), [1 ,2 ,3 ], [4 ,5 ]) == ([2 , 4 , 6 ], [3 , 0 ])
517+ @test gradient ((x,y) -> sum (map ((z-> z^ 2 + y[1 ]), x)), [1 ,2 ,3 ], [4 ,5 ]) == ([2 , 4 , 6 ], [3 , 0 ])
518+ @test gradient ((x,y) -> mapreduce ((z-> z^ 2 + y[1 ]), + , x), [1 ,2 ,3 ], [4 ,5 ]) == ([2 , 4 , 6 ], [3 , 0 ])
519+
520+ # type unstable
521+ @test gradient (xs -> sum ((x -> x< 2 ? false : x^ 2 ). (xs)), [1 ,2 ,3 ])[1 ][2 : 3 ] == [4 , 6 ]
522+ @test gradient (xs -> sum ((x -> x< 2 ? false : x^ 2 ), xs), [1 ,2 ,3 ])[1 ][2 : 3 ] == [4 , 6 ]
523+ @test gradient (xs -> sum (map ((x -> x< 2 ? false : x^ 2 ), xs)), [1 ,2 ,3 ])[1 ][2 : 3 ] == [4 , 6 ]
524+ @test gradient (xs -> mapreduce ((x -> x< 2 ? false : x^ 2 ), + , xs), [1 ,2 ,3 ])[1 ][2 : 3 ] == [4 , 6 ]
525+
526+ # with Ref, Val, Symbol
527+ @test gradient (x -> sum (x .+ Ref (x[1 ])), [1 ,2 ,3 ]) == ([4 ,1 ,1 ],)
528+ @test gradient (x -> sum (x .+ (x[1 ],)), [1 ,2 ,3 ]) == ([4 ,1 ,1 ],)
529+ @test gradient (x -> sum ((first∘ tuple). (x, :ignore )), [1 ,2 ,3 ]) == ([1 ,1 ,1 ],)
530+ @test gradient (x -> sum ((first∘ tuple). (x, Symbol)), [1 ,2 ,3 ]) == ([1 ,1 ,1 ],)
531+ _f (x,:: Val{y} = Val (2 )) where {y} = x/ y
532+ @test gradient (x -> sum (_f .(x, Val (2 ))), [1 ,2 ,3 ]) == ([0.5 , 0.5 , 0.5 ],)
533+ @test gradient (x -> sum (_f .(x)), [1 ,2 ,3 ]) == ([0.5 , 0.5 , 0.5 ],)
534+ @test gradient (x -> sum (map (_f, x)), [1 ,2 ,3 ]) == ([0.5 , 0.5 , 0.5 ],)
535+
536+ @test gradient (x -> sum (x ./ [1 ,2 ,4 ]), [1 ,2 ,pi ]) == ([1.0 , 0.5 , 0.25 ],)
537+ @test gradient (x -> sum (map (/ , x, [1 ,2 ,4 ])), [1 ,2 ,pi ]) == ([1.0 , 0.5 , 0.25 ],)
538+
539+ # negative powers
540+ @test gradient ((x,p) -> sum (x .^ p), [1.0 ,2.0 ,4.0 ], [1 ,- 1 ,2 ])[1 ] ≈ [1.0 , - 0.25 , 8.0 ]
541+ @test gradient ((x,p) -> sum (x .^ p), [1.0 ,2.0 ,4.0 ], - 1 )[1 ] ≈ [- 1.0 , - 0.25 , - 0.0625 ]
542+ @test gradient ((x,p) -> sum (z -> z^ p, x), [1.0 ,2.0 ,4.0 ], - 1 )[1 ] ≈ [- 1.0 , - 0.25 , - 0.0625 ]
543+ @test gradient ((x,p) -> mapreduce (z -> z^ p, + , x), [1.0 ,2.0 ,4.0 ], - 1 )[1 ] ≈ [- 1.0 , - 0.25 , - 0.0625 ]
513544end
0 commit comments