@@ -39,6 +39,52 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
3939tailmemaybe (:: Nothing ) = nothing
4040tailmemaybe (x:: Tuple ) = Base. tail (x)
4141
42+ """
43+ pullback(f, args...)
44+ pullback(f, ::Params)
45+
46+ Returns the value of the function `f` and a back-propagator function,
47+ which can be called to obtain a tuple containing `∂f/∂x` for each argument `x`,
48+ the derivative (for scalar `x`) or gradient.
49+
50+ ```julia
51+ y, back = pullback(f, args...)
52+ ∇ = back(seed)
53+ ```
54+
55+ `back` must be called with a start value `seed` matching the output of `f(args...)`.
56+ If `f(args...)` returns a number, `seed` should be a number.
57+ If `f(args...)` returns an array, `seed` should be an equally-sized array.
58+
59+ See also [`withgradient`](@ref) to obtain the value and gradients in one call,
60+ and [`gradient`](@ref) for obtaining just the gradients.
61+
62+ ```jldoctest; setup=:(using Zygote)
63+ julia> y, back = pullback(*, 2.0, 3.0, 5.0);
64+
65+ julia> y
66+ 30.0
67+
68+ julia> back(1.0)
69+ (15.0, 10.0, 6.0)
70+
71+ julia> back(2.0)
72+ (30.0, 20.0, 12.0)
73+
74+ julia> y, back = pullback(x -> [x, x], 1.0);
75+
76+ julia> y
77+ 2-element Vector{Float64}:
78+ 1.0
79+ 1.0
80+
81+ julia> back([1.0, 1.0])
82+ (2.0,)
83+
84+ julia> back([2.0, nothing])
85+ (2.0,)
86+ ```
87+ """
4288@inline pullback (f, args... ) = pullback (f, Context (), args... )
4389function pullback (f, cx:: AContext , args... )
4490 y, back = _pullback (cx, f, args... )
@@ -67,11 +113,16 @@ sensitivity(y::Complex) = error("Output is complex, so the gradient is not defin
67113sensitivity (y:: AbstractArray ) = error (" Output is an array, so the gradient is not defined. Perhaps you wanted jacobian." )
68114sensitivity (y) = error (" Output should be scalar; gradients are not defined for output $(repr (y)) " )
69115
116+ # Preserves output as tuple when gradients are collapsed
117+ _project_all (:: NTuple{N} , :: Nothing ) where {N} = ntuple (_ -> nothing , N)
118+ _project_all (x:: Tuple , dx:: Tuple ) = map (_project, x, dx)
119+
70120"""
71121 gradient(f, args...)
72122
73123Returns a tuple containing `∂f/∂x` for each argument `x`,
74124the derivative (for scalar `x`) or the gradient.
125+ If no gradient is defined, `∂f/∂x` will be `nothing`.
75126
76127`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
77128
@@ -95,7 +146,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
95146function gradient (f, args... )
96147 y, back = pullback (f, args... )
97148 grad = back (sensitivity (y))
98- isnothing (grad) ? nothing : map (_project, args, grad)
149+ return _project_all ( args, grad)
99150end
100151
101152# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
109160 withgradient(f, ::Params)
110161
111162Returns both the value of the function and the [`gradient`](@ref),
112- as a named tuple.
163+ as a named tuple.
113164
114165```jldoctest; setup=:(using Zygote)
115166julia> y, ∇ = withgradient(/, 1, 2)
@@ -161,7 +212,7 @@ function withgradient(f, args...)
161212 else
162213 back (sensitivity (y))
163214 end
164- results = isnothing (grad) ? map (_ -> nothing , args) : map (_project, args, grad)
215+ results = _project_all ( args, grad)
165216 (val= y, grad= results)
166217end
167218
304355 Grads(...)
305356
306357Dictionary-like container returned when taking gradients with
307- respect to implicit parameters. For an array `W`, appearing
358+ respect to implicit parameters. For an array `W`, appearing
308359within `Params([W, A, B...])`, the gradient is `g[W]`.
309360"""
310361struct Grads
@@ -321,7 +372,7 @@ const ADictOrGrads = Union{AbstractDict, Grads}
321372
322373# Dictionary interface.
323374# Don't use the IdDict directly since it may contain some spurious pairs.
324- Base. haskey (gs:: Grads , x) = x ∈ gs. params
375+ Base. haskey (gs:: Grads , x) = x ∈ gs. params
325376Base. keys (gs:: Grads ) = gs. params
326377Base. values (gs:: Grads ) = (gs. grads[p] for p in gs. params)
327378
@@ -381,7 +432,7 @@ broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
381432broadcasted (f, gs:: Grads , a:: Numeric ) = map (x -> f (x, a), gs)
382433
383434function materialize! (gs1:: Grads , gs2:: Grads )
384- issetequal (gs1. params, gs2. params) ||
435+ issetequal (gs1. params, gs2. params) ||
385436 throw (ArgumentError (" Expected Grads objects with the same Params." ))
386437 for p in gs1. params
387438 gs1[p] = gs2[p]
@@ -421,6 +472,9 @@ function pullback(f, ps::Params)
421472 end
422473end
423474
475+ # No conversion required here
476+ _project_all (_, dx:: Grads ) = dx
477+
424478# Code Reflection
425479
426480function code_ir (f, T)
0 commit comments