Skip to content

Commit bc78d43

Browse files
authored
Merge branch 'master' into bc/rm-boundscheck
2 parents efcc64b + f24b9b2 commit bc78d43

File tree

7 files changed

+51
-47
lines changed

7 files changed

+51
-47
lines changed

Project.toml

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.6.65"
3+
version = "0.6.67"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -18,14 +18,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1919
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2020
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
21+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2122
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2223
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
23-
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2424
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2525
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2626
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2727
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2828

29+
[weakdeps]
30+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
31+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
32+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
33+
34+
[extensions]
35+
ZygoteColorsExt = "Colors"
36+
ZygoteDistancesExt = "Distances"
37+
ZygoteTrackerExt = "Tracker"
38+
2939
[compat]
3040
AbstractFFTs = "1.3.1"
3141
ChainRules = "1.44.1"
@@ -38,22 +48,18 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
3848
ForwardDiff = "0.10"
3949
GPUArrays = "8.4.2, 9"
4050
GPUArraysCore = "0.1.1"
41-
IRTools = "0.4.4"
51+
IRTools = "0.4.11"
4252
LogExpFunctions = "0.3.1"
4353
MacroTools = "0.5"
4454
NaNMath = "0.3, 1"
45-
Requires = "1.1"
4655
PrecompileTools = "1"
56+
Requires = "1.1"
4757
SpecialFunctions = "1.6, 2"
58+
Statistics = "1"
4859
Tracker = "0.2"
49-
ZygoteRules = "0.2.1"
60+
ZygoteRules = "0.2.4"
5061
julia = "1.6"
5162

52-
[extensions]
53-
ZygoteColorsExt = "Colors"
54-
ZygoteDistancesExt = "Distances"
55-
ZygoteTrackerExt = "Tracker"
56-
5763
[extras]
5864
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5965
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
@@ -68,8 +74,3 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
6874

6975
[targets]
7076
test = ["ChainRulesTestUtils", "Conda", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PyCall", "Test"]
71-
72-
[weakdeps]
73-
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
74-
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
75-
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

src/compiler/reverse.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,6 @@ Variable(a::Alpha) = Variable(a.id)
248248
sig(b::IRTools.Block) = unique([arg for br in branches(b) for arg in br.args if arg isa Variable])
249249
sig(pr::Primal) = Dict(b.id => sig(b) for b in blocks(pr.ir))
250250

251-
# TODO unreachables?
252251
function adjointcfg(pr::Primal)
253252
ir = empty(pr.ir)
254253
return!(ir, nothing)
@@ -261,7 +260,9 @@ function adjointcfg(pr::Primal)
261260
push!(rb, xcall(Base, :(!==), alpha(pr.branches[b.id]), BranchNumber(i)))
262261
branch!(rb, preds[i].id, unless = cond)
263262
end
264-
if !isempty(branches(b)) && branches(b)[end] == IRTools.unreachable
263+
if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable)
264+
# If `b` is unreachable, then no context produced by the primal should end up branching to `rb`
265+
push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable`
265266
branch!(rb, 0)
266267
end
267268
end

src/deprecated.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,9 @@ macro nograd(ex)
6565
end
6666
return blk
6767
end
68+
69+
# Internal function used by some downstream packages.
70+
# Removing this completely would require some tricky registry changes,
71+
# but leaving it as a vestigial function is much easier.
72+
# See https://github.com/FluxML/Zygote.jl/pull/1328 for more context.
73+
function ∇getindex end

src/lib/array.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,6 @@ end
4141
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
4242
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)
4343

44-
@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)
45-
46-
@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)
47-
48-
∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
49-
if inds isa NTuple{N,Int} && T <: Number
50-
dx = OneElement(dy, inds, axes(x))
51-
elseif inds isa NTuple{<:Any, Integer}
52-
dx = _zero(x, typeof(dy))
53-
dx[inds...] = dy
54-
else
55-
dx = _zero(x, eltype(dy))
56-
dxv = view(dx, inds...)
57-
dxv .= accum.(dxv, _droplike(dy, dxv))
58-
end
59-
return (_project(x, dx), map(_->nothing, inds)...)
60-
end
61-
6244
"""
6345
OneElement(val, ind, axes) <: AbstractArray
6446

test/compiler.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,30 @@ end
227227
@test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] fill(0.5773502691896258, 3, 400)
228228

229229
# Tests adapted from https://github.com/dfdx/Umlaut.jl/pull/35
230-
@eval _boundscheck_foo(x) = ifelse($(Expr(:boundscheck)), 2x, x)
230+
@eval _has_boundscheck(x) = ifelse($(Expr(:boundscheck)), 2x, x)
231231

232232
@testset "Meta Expr handling" begin
233-
y, (dx,) = withgradient(_boundscheck_foo, 1)
233+
y, (dx,) = withgradient(_has_boundscheck, 1)
234234
@test y == 2
235235
@test dx == 2
236236
end
237+
238+
# issue 1118 & 1380
239+
function f_1380(x)
240+
if rand(Bool)
241+
return x
242+
else
243+
return 2x
244+
end
245+
246+
# unreachable
247+
return nothing
248+
end
249+
250+
@testset "unreachable block" begin
251+
y, back = Zygote.pullback(f_1380, 1.)
252+
# There should not be a compiler error
253+
local g
254+
@test_nowarn g = back(1.)
255+
@test only(g) (1., 2.)
256+
end

test/gradcheck.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,11 @@ end
174174

175175
# Ensure that nothings work with numeric types.
176176
_, back = Zygote.pullback(getindex, randn(4), [1])
177-
@test back([nothing]) == (zeros(4), nothing)
177+
@test back([nothing]) === nothing
178178

179179
# Ensure that nothings work with non-numeric types.
180180
_, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
181-
@test back([nothing]) == (nothing, nothing)
181+
@test back([nothing]) === nothing
182182
end
183183

184184
@testset "view" begin

test/utils.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,8 @@ using ForwardDiff
22
using Zygote: hessian_dual, hessian_reverse
33

44
@testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse]
5-
6-
if hess == hessian_dual
7-
@test hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0]
8-
@test hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0] # original docstring version
9-
else
10-
@test_broken hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0] # can't differentiate ∇getindex
11-
@test_broken hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0]
12-
end
5+
@test hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0]
6+
@test hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0] # original docstring version
137
@test hess(x -> sum(x.^3), [1 2; 3 4]) Diagonal([6, 18, 12, 24])
148
@test hess(sin, pi/2) -1
159

0 commit comments

Comments
 (0)