Skip to content

Commit 070466d

Browse files
committed
Add docs
Also adds a docstring for `pullback`, which we've been missing for some time.
1 parent 38ebc73 commit 070466d

File tree

1 file changed

+51
-4
lines changed

1 file changed

+51
-4
lines changed

src/compiler/interface.jl

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,52 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
3939
tailmemaybe(::Nothing) = nothing
4040
tailmemaybe(x::Tuple) = Base.tail(x)
4141

42+
"""
43+
pullback(f, args...)
44+
pullback(f, ::Params)
45+
46+
Returns the value of the function `f` and a back-propagator function,
47+
which can be called to obtain a tuple containing `∂f/∂x` for each argument `x`,
48+
the derivative (for scalar `x`) or gradient.
49+
50+
```julia
51+
y, back = pullback(f, args...)
52+
∇ = back(seed)
53+
```
54+
55+
`back` must be called with a start value `seed` matching the output of `f(args...)`.
56+
If `f(args...)` returns a number, `seed` should be a number.
57+
If `f(args...)` returns an array, `seed` should be an equally-sized array.
58+
59+
See also [`withgradient`](@ref) to obtain the value and gradients in one call,
60+
and [`gradient`](@ref) for obtaining just the gradients.
61+
62+
```jldoctest; setup=:(using Zygote)
63+
julia> y, back = pullback(*, 2.0, 3.0, 5.0);
64+
65+
julia> y
66+
30.0
67+
68+
julia> back(1.0)
69+
(15.0, 10.0, 6.0)
70+
71+
julia> back(2.0)
72+
(30.0, 20.0, 12.0)
73+
74+
julia> y, back = pullback(x -> [x, x], 1.0);
75+
76+
julia> y
77+
2-element Vector{Float64}:
78+
1.0
79+
1.0
80+
81+
julia> back([1.0, 1.0])
82+
(2.0,)
83+
84+
julia> back([2.0, nothing])
85+
(2.0,)
86+
```
87+
"""
4288
@inline pullback(f, args...) = pullback(f, Context(), args...)
4389
function pullback(f, cx::AContext, args...)
4490
y, back = _pullback(cx, f, args...)
@@ -76,6 +122,7 @@ _project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)
76122
77123
Returns a tuple containing `∂f/∂x` for each argument `x`,
78124
the derivative (for scalar `x`) or the gradient.
125+
If no gradient is defined, `∂f/∂x` will be `nothing`.
79126
80127
`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
81128
@@ -113,7 +160,7 @@ end
113160
withgradient(f, ::Params)
114161
115162
Returns both the value of the function and the [`gradient`](@ref),
116-
as a named tuple.
163+
as a named tuple.
117164
118165
```jldoctest; setup=:(using Zygote)
119166
julia> y, ∇ = withgradient(/, 1, 2)
@@ -308,7 +355,7 @@ end
308355
Grads(...)
309356
310357
Dictionary-like container returned when taking gradients with
311-
respect to implicit parameters. For an array `W`, appearing
358+
respect to implicit parameters. For an array `W`, appearing
312359
within `Params([W, A, B...])`, the gradient is `g[W]`.
313360
"""
314361
struct Grads
@@ -325,7 +372,7 @@ const ADictOrGrads = Union{AbstractDict, Grads}
325372

326373
# Dictionary interface.
327374
# Don't use the IdDict directly since it may contain some spurious pairs.
328-
Base.haskey(gs::Grads, x) = x gs.params
375+
Base.haskey(gs::Grads, x) = x gs.params
329376
Base.keys(gs::Grads) = gs.params
330377
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)
331378

@@ -385,7 +432,7 @@ broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
385432
broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs)
386433

387434
function materialize!(gs1::Grads, gs2::Grads)
388-
issetequal(gs1.params, gs2.params) ||
435+
issetequal(gs1.params, gs2.params) ||
389436
throw(ArgumentError("Expected Grads objects with the same Params."))
390437
for p in gs1.params
391438
gs1[p] = gs2[p]

0 commit comments

Comments
 (0)