Skip to content

Commit 38ebc73

Browse files
committed
un-collapse nothings in gradient
1 parent 392a1f9 commit 38ebc73

File tree

4 files changed

+13
-6
lines changed

4 files changed

+13
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Requires = "1.1"
5757
SpecialFunctions = "1.6, 2"
5858
Statistics = "1"
5959
Tracker = "0.2"
60-
ZygoteRules = "0.2.4"
60+
ZygoteRules = "0.2.5"
6161
julia = "1.6"
6262

6363
[extras]

src/compiler/interface.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ sensitivity(y::Complex) = error("Output is complex, so the gradient is not defin
6767
sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
6868
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")
6969

70+
# Preserves output as tuple when gradients are collapsed
71+
_project_all(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N)
72+
_project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)
73+
7074
"""
7175
gradient(f, args...)
7276
@@ -95,7 +99,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
9599
function gradient(f, args...)
96100
y, back = pullback(f, args...)
97101
grad = back(sensitivity(y))
98-
isnothing(grad) ? nothing : map(_project, args, grad)
102+
return _project_all(args, grad)
99103
end
100104

101105
# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
@@ -161,7 +165,7 @@ function withgradient(f, args...)
161165
else
162166
back(sensitivity(y))
163167
end
164-
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
168+
results = _project_all(args, grad)
165169
(val=y, grad=results)
166170
end
167171

@@ -421,6 +425,9 @@ function pullback(f, ps::Params)
421425
end
422426
end
423427

428+
# No conversion required here
429+
_project_all(_, dx::Grads) = dx
430+
424431
# Code Reflection
425432

426433
function code_ir(f, T)

test/lib/number.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
@test gradient(floor, 1) === (0.0,)
44
@test gradient(ceil, 1) === (0.0,)
55
@test gradient(round, 1) === (0.0,)
6-
@test gradient(hash, 1) === nothing
7-
@test gradient(div, 1, 2) === nothing
6+
@test gradient(hash, 1) === (nothing,)
7+
@test gradient(div, 1, 2) === (nothing, nothing)
88
end
99

1010
@testset "basics" begin

test/structures.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,5 @@ end
6464
end
6565

6666
m, b = Zygote._pullback(Zygote.Context(), nameof, M)
67-
@test b(m) == (nothing, nothing)
67+
@test b(m) === nothing
6868
end

0 commit comments

Comments
 (0)