
Commit 54f1e80

Merge pull request #1495 from FluxML/bc/gradtuple
Un-collapse nothings in `gradient`
2 parents 5b61724 + 070466d

4 files changed: +64 -10 lines

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ Requires = "1.1"
 SpecialFunctions = "1.6, 2"
 Statistics = "1"
 Tracker = "0.2"
-ZygoteRules = "0.2.4"
+ZygoteRules = "0.2.5"
 julia = "1.6"
 
 [extras]

src/compiler/interface.jl

Lines changed: 60 additions & 6 deletions

@@ -39,6 +39,52 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
 tailmemaybe(::Nothing) = nothing
 tailmemaybe(x::Tuple) = Base.tail(x)
 
+"""
+    pullback(f, args...)
+    pullback(f, ::Params)
+
+Returns the value of the function `f` and a back-propagator function,
+which can be called to obtain a tuple containing `∂f/∂x` for each argument `x`,
+the derivative (for scalar `x`) or gradient.
+
+```julia
+y, back = pullback(f, args...)
+∇ = back(seed)
+```
+
+`back` must be called with a start value `seed` matching the output of `f(args...)`.
+If `f(args...)` returns a number, `seed` should be a number.
+If `f(args...)` returns an array, `seed` should be an equally-sized array.
+
+See also [`withgradient`](@ref) to obtain the value and gradients in one call,
+and [`gradient`](@ref) for obtaining just the gradients.
+
+```jldoctest; setup=:(using Zygote)
+julia> y, back = pullback(*, 2.0, 3.0, 5.0);
+
+julia> y
+30.0
+
+julia> back(1.0)
+(15.0, 10.0, 6.0)
+
+julia> back(2.0)
+(30.0, 20.0, 12.0)
+
+julia> y, back = pullback(x -> [x, x], 1.0);
+
+julia> y
+2-element Vector{Float64}:
+ 1.0
+ 1.0
+
+julia> back([1.0, 1.0])
+(2.0,)
+
+julia> back([2.0, nothing])
+(2.0,)
+```
+"""
 @inline pullback(f, args...) = pullback(f, Context(), args...)
 function pullback(f, cx::AContext, args...)
   y, back = _pullback(cx, f, args...)
@@ -67,11 +113,16 @@ sensitivity(y::Complex) = error("Output is complex, so the gradient is not defin
 sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
 sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")
 
+# Preserves output as tuple when gradients are collapsed
+_project_all(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N)
+_project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)
+
 """
     gradient(f, args...)
 
 Returns a tuple containing `∂f/∂x` for each argument `x`,
 the derivative (for scalar `x`) or the gradient.
+If no gradient is defined, `∂f/∂x` will be `nothing`.
 
 `f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
 
@@ -95,7 +146,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
 function gradient(f, args...)
   y, back = pullback(f, args...)
   grad = back(sensitivity(y))
-  isnothing(grad) ? nothing : map(_project, args, grad)
+  return _project_all(args, grad)
 end
 
 # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
@@ -109,7 +160,7 @@ end
     withgradient(f, ::Params)
 
 Returns both the value of the function and the [`gradient`](@ref),
-as a named tuple. 
+as a named tuple.
 
 ```jldoctest; setup=:(using Zygote)
 julia> y, ∇ = withgradient(/, 1, 2)
@@ -161,7 +212,7 @@ function withgradient(f, args...)
   else
     back(sensitivity(y))
   end
-  results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
+  results = _project_all(args, grad)
   (val=y, grad=results)
 end
 
@@ -304,7 +355,7 @@ end
    Grads(...)
 
 Dictionary-like container returned when taking gradients with
-respect to implicit parameters. For an array `W`, appearing 
+respect to implicit parameters. For an array `W`, appearing
 within `Params([W, A, B...])`, the gradient is `g[W]`.
 """
 struct Grads
@@ -321,7 +372,7 @@ const ADictOrGrads = Union{AbstractDict, Grads}
 
 # Dictionary interface.
 # Don't use the IdDict directly since it may contain some spurious pairs.
-Base.haskey(gs::Grads, x) = x ∈ gs.params 
+Base.haskey(gs::Grads, x) = x ∈ gs.params
 Base.keys(gs::Grads) = gs.params
 Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)
 
@@ -381,7 +432,7 @@ broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
 broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs)
 
 function materialize!(gs1::Grads, gs2::Grads)
-  issetequal(gs1.params, gs2.params) || 
+  issetequal(gs1.params, gs2.params) ||
     throw(ArgumentError("Expected Grads objects with the same Params."))
   for p in gs1.params
     gs1[p] = gs2[p]
@@ -421,6 +472,9 @@ function pullback(f, ps::Params)
   end
 end
 
+# No conversion required here
+_project_all(_, dx::Grads) = dx
+
 # Code Reflection
 
 function code_ir(f, T)
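
Two notes on this file. The hunks whose removed and added lines look identical (the `withgradient` and `Grads` docstrings, `haskey`, `issetequal`) appear to be trailing-whitespace cleanups only. More importantly, the user-visible effect of the new `_project_all` methods is that `gradient` no longer collapses an all-`nothing` gradient into a bare `nothing`. A minimal sketch of the before/after behavior, mirroring the test changes below (output comments assume this commit is applied):

```julia
using Zygote

# `hash` and `div` have no defined gradients, so the pullback returns a
# collapsed `nothing`; `_project_all(args, nothing)` now expands that
# into one `nothing` per argument:
gradient(hash, 1)      # (nothing,)          -- was `nothing` before
gradient(div, 1, 2)    # (nothing, nothing)  -- was `nothing` before

# Differentiable calls are unchanged: `_project` is still mapped over
# the (argument, gradient) pairs as before.
gradient(*, 2.0, 3.0)  # (3.0, 2.0)
```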

test/lib/number.jl

Lines changed: 2 additions & 2 deletions

@@ -3,8 +3,8 @@
   @test gradient(floor, 1) === (0.0,)
   @test gradient(ceil, 1) === (0.0,)
   @test gradient(round, 1) === (0.0,)
-  @test gradient(hash, 1) === nothing
-  @test gradient(div, 1, 2) === nothing
+  @test gradient(hash, 1) === (nothing,)
+  @test gradient(div, 1, 2) === (nothing, nothing)
 end
 
 @testset "basics" begin
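
These updated expectations also show why the un-collapsed shape is friendlier to callers: code that indexes the result of `gradient` no longer has to special-case `nothing`. A small illustration, assuming the new behavior:

```julia
using Zygote

# Before this commit, `gradient(hash, 1)` returned `nothing`, so indexing
# the result threw (`nothing` has no `getindex`). With the tuple shape
# preserved, per-argument access is always well-defined:
dx = gradient(hash, 1)[1]   # `nothing`, rather than an error
```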

test/structures.jl

Lines changed: 1 addition & 1 deletion

@@ -64,5 +64,5 @@ end
   end
 
   m, b = Zygote._pullback(Zygote.Context(), nameof, M)
-  @test b(m) == (nothing, nothing)
+  @test b(m) === nothing
 end
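
Note the asymmetry this test now pins down: the internal `_pullback` still returns a collapsed `nothing` (its gradient tuple also carries a leading slot for the function itself), and only the public entry points expand it via `_project_all`. A sketch of the two levels, assuming `M` stands in for the struct defined earlier in this test file:

```julia
using Zygote

struct M end  # hypothetical stand-in for the test file's struct

# Internal API: the collapsed form survives, function slot included ...
m, back = Zygote._pullback(Zygote.Context(), nameof, M)
@assert back(m) === nothing

# ... while the public API un-collapses it into one entry per argument:
@assert gradient(hash, 1) === (nothing,)
```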
