@@ -39,6 +39,52 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
3939tailmemaybe (:: Nothing ) = nothing
4040tailmemaybe (x:: Tuple ) = Base. tail (x)
4141
42+ """
43+ pullback(f, args...)
44+ pullback(f, ::Params)
45+
46+ Returns the value of the function `f` and a back-propagator function,
47+ which can be called to obtain a tuple containing `∂f/∂x` for each argument `x`,
48+ the derivative (for scalar `x`) or gradient.
49+
50+ ```julia
51+ y, back = pullback(f, args...)
52+ ∇ = back(seed)
53+ ```
54+
55+ `back` must be called with a start value `seed` matching the output of `f(args...)`.
56+ If `f(args...)` returns a number, `seed` should be a number.
57+ If `f(args...)` returns an array, `seed` should be an equally-sized array.
58+
59+ See also [`withgradient`](@ref) to obtain the value and gradients in one call,
60+ and [`gradient`](@ref) for obtaining just the gradients.
61+
62+ ```jldoctest; setup=:(using Zygote)
63+ julia> y, back = pullback(*, 2.0, 3.0, 5.0);
64+
65+ julia> y
66+ 30.0
67+
68+ julia> back(1.0)
69+ (15.0, 10.0, 6.0)
70+
71+ julia> back(2.0)
72+ (30.0, 20.0, 12.0)
73+
74+ julia> y, back = pullback(x -> [x, x], 1.0);
75+
76+ julia> y
77+ 2-element Vector{Float64}:
78+ 1.0
79+ 1.0
80+
81+ julia> back([1.0, 1.0])
82+ (2.0,)
83+
84+ julia> back([2.0, nothing])
85+ (2.0,)
86+ ```
87+ """
4288@inline pullback (f, args... ) = pullback (f, Context (), args... )
4389function pullback (f, cx:: AContext , args... )
4490 y, back = _pullback (cx, f, args... )
@@ -76,6 +122,7 @@ _project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)
76122
77123Returns a tuple containing `∂f/∂x` for each argument `x`,
78124the derivative (for scalar `x`) or the gradient.
125+ If no gradient is defined, `∂f/∂x` will be `nothing`.
79126
80127`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
81128
113160 withgradient(f, ::Params)
114161
115162Returns both the value of the function and the [`gradient`](@ref),
116- as a named tuple.
163+ as a named tuple.
117164
118165```jldoctest; setup=:(using Zygote)
119166julia> y, ∇ = withgradient(/, 1, 2)
308355 Grads(...)
309356
310357Dictionary-like container returned when taking gradients with
311- respect to implicit parameters. For an array `W`, appearing
358+ respect to implicit parameters. For an array `W`, appearing
312359within `Params([W, A, B...])`, the gradient is `g[W]`.
313360"""
314361struct Grads
@@ -325,7 +372,7 @@ const ADictOrGrads = Union{AbstractDict, Grads}
325372
326373# Dictionary interface.
327374# Don't use the IdDict directly since it may contain some spurious pairs.
328- Base. haskey (gs:: Grads , x) = x ∈ gs. params
375+ Base. haskey (gs:: Grads , x) = x ∈ gs. params
329376Base. keys (gs:: Grads ) = gs. params
330377Base. values (gs:: Grads ) = (gs. grads[p] for p in gs. params)
331378
@@ -385,7 +432,7 @@ broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
385432broadcasted (f, gs:: Grads , a:: Numeric ) = map (x -> f (x, a), gs)
386433
387434function materialize! (gs1:: Grads , gs2:: Grads )
388- issetequal (gs1. params, gs2. params) ||
435+ issetequal (gs1. params, gs2. params) ||
389436 throw (ArgumentError (" Expected Grads objects with the same Params." ))
390437 for p in gs1. params
391438 gs1[p] = gs2[p]
0 commit comments