
Commit 18a6f2a

Authored by: bors[bot] and oxinabox
Merge #1004
1004: ensure `sum(f,x)` works on GPU r=DhairyaLGandhi a=oxinabox

Seems like there are some pains from #990 re: GPU. In particular we broke DiffEqSensitivity:
https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877

@ChrisRackauckas's "M"WE is:

```
using DiffEqFlux, OrdinaryDiffEq, DiffEqSensitivity
using CUDA, Test, Zygote
CUDA.allowscalar(false)

H = CuArray(rand(Float32, 2, 2))
ann = FastChain(FastDense(1, 4, tanh))
p = initial_params(ann)

function func(x, p, t)
    ann([t], p)[1] * H * x
end

x0 = CuArray(rand(Float32, 2))
x1 = CuArray(rand(Float32, 2))

prob = ODEProblem(func, x0, (0.0f0, 1.0f0))

function evolve(p)
    solve(prob, Tsit5(), p=p, save_start=false, save_everystep=false,
          abstol=1e-4, reltol=1e-4,
          sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())).u[1]
end

function cost(p)
    x = evolve(p)
    c = sum(abs, x - x1)
    #println(c)
    c
end

grad = Zygote.gradient(cost, p)[1]
@test !iszero(grad[1])
@test iszero(grad[2:4])
@test !iszero(grad[5])
@test iszero(grad[6:end])
```

I am hoping we can get it to fail with just `sum(f, xs)` (which I have added to tests).
I can't run GPU locally, which makes testing this hard. If I have to I will spin up an EC2 instance, but I would really rather not.

I think what is going on, from looking at [the logs](https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877), is that the error happens during the forward pass. In particular here:
https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/Base/mapreduce.jl#L46

I think this is why Zygote implemented the pullback of `sum(f, x)` as `sum(f.(x))` (which is slower and more allocate-y than our newer version): so that it could hit the code Zygote has special-cased for CUDA, which runs `f` in forward mode.
(Which means it doesn't need the `Context` object containing the `IdDict`.)

So I think the short-term solution is probably to add the old rule for `sum` back in (but for `CuArray` only) here:
https://github.com/FluxML/Zygote.jl/blob/531da8bb7753f46294bc13f9d2a2fdd54917f926/src/lib/broadcast.jl#L244

```
# Make sure sum(f, ::CuArray) uses forward mode broadcast AD defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU safe
@adjoint function sum(f, xs::CuArray; kws...)
  @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
```

In the longer term, we will probably default to running the `f` from `sum(f, xs)` in forward mode anyway. So Zygote's rule config can be updated to say that it uses ForwardDiff.jl for its `frule_via_ad`.

Co-authored-by: Lyndon White <[email protected]>
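The workaround above relies on differentiating `sum(f, xs)` as if it were `sum(f.(xs))`, so the two must agree. A minimal CPU-only sketch of that equivalence (an illustration, not from the commit; assumes only the Zygote package, no GPU):

```
using Zygote

# Hypothetical check: the broadcast form and the sum(f, x) form
# produce identical gradients, which is what the CuArray adjoint relies on.
xs = Float32[-1.5, -9.0, 2.4]
g1 = Zygote.gradient(x -> sum(abs, x), xs)[1]
g2 = Zygote.gradient(x -> sum(abs.(x)), xs)[1]
@assert g1 == g2   # both are sign.(xs)
```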
2 parents 9759239 + bd8c5fb commit 18a6f2a

File tree

3 files changed: +20 −2 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.13"
+version = "0.6.14"

 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
```

src/lib/broadcast.jl

Lines changed: 8 additions & 1 deletion
```diff
@@ -254,7 +254,14 @@ end
   placeholder = similar(xs)
   sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
 end
+
+# Make sure sum(f, ::CuArray) uses the forward-mode broadcast AD defined above
+# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
+@adjoint function sum(f, xs::CUDA.CuArray; kws...)
+  @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
+  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
+end
+
 @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.CuArray}
   Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
 end
```

test/cuda.jl

Lines changed: 11 additions & 0 deletions
```diff
@@ -26,6 +26,17 @@ end
   @test g_gpu |> collect ≈ g
 end

+@testset "sum(f, x)" begin
+  a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
+  a_gpu = a |> cu
+
+  f(x) = sum(abs, x)
+  g = gradient(f, a)[1]
+  g_gpu = gradient(f, a_gpu)[1]
+  @test g_gpu isa CuArray
+  @test g_gpu |> collect ≈ g
+end
+
 @testset "jacobian" begin
   v1 = cu(collect(1:3f0))
```

0 commit comments