Skip to content
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Zygote
using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
import ZygoteRules: ZygoteRules, @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
Expand Down
52 changes: 47 additions & 5 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,6 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
@inline wrap_chainrules_output(x) = x
@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
@inline wrap_chainrules_output(x::ChainRulesCore.NotImplemented) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
Expand All @@ -125,6 +121,8 @@ end
wrap_chainrules_output(dxs::AbstractArray{<:Number}) = dxs
wrap_chainrules_output(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
wrap_chainrules_output(dxs::AbstractArray) = map(wrap_chainrules_output, dxs)


#=
# As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers
@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
Expand Down Expand Up @@ -152,6 +150,7 @@ Convert `dx` from the format Zygote uses internally to differentials types Chain
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
# An array whose elements are all ChainRules zero tangents collapses to a single zero.
# `first` throws on an empty array, so an empty input falls back to `ZeroTangent()`,
# the same value the `AbstractArray{Nothing}` method returns.
@inline function wrap_chainrules_input(dxs::AbstractArray{T}) where {T<:AbstractZero}
  return isempty(dxs) ? ChainRules.ZeroTangent() : first(dxs)
end
@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, dxs)
# This produces Tangent{Any} since it does not get to see the primal, `x`.
Expand Down Expand Up @@ -186,9 +185,12 @@ Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
@inline function _project(x, dx)
wrap_chainrules_output(ProjectTo(x)(zygote2differential(dx, x)))
differential2zygote(ProjectTo(x)(zygote2differential(dx, x)))
end

_project(_, dx::Nothing) = nothing
_project(x::Tuple, dx::Tuple) = map(_project, x, dx)

# Restore splatted arrays
_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))

Expand Down Expand Up @@ -350,3 +352,43 @@ z2d(dx::NamedTuple{L,S}, primal::AbstractDict) where {L,S<:Tuple{Vararg{Union{Nu
end

z2d(dx::Ref, primal) = z2d(dx[], primal) # mutable structs


"""
    differential2zygote(dx)

Translate `dx` from the ChainRules differential types into the Zygote format.

Similar to `wrap_chainrules_output(dx)`, except that zero types are converted too:
an `AbstractZero` or a `NotImplemented` becomes `nothing`. A thunk is unthunked and the
result converted; any value without a more specific method is returned unchanged.
"""
@inline function differential2zygote(@nospecialize(x))
  return x
end
@inline function differential2zygote(::AbstractZero)
  return nothing
end
@inline function differential2zygote(::ChainRulesCore.NotImplemented)
  return nothing
end
# For now we are just not going to deal with thunks: force them and convert the result.
@inline function differential2zygote(x::AbstractThunk)
  return differential2zygote(unthunk(x))
end
for Container in (:Tuple, :NamedTuple)
  # One set of methods is generated per container type, rather than a single method on a
  # `Union` with an `if` inside: that avoids a branch that changes the output type, which
  # nested AD handles poorly.
  @eval begin
    # Plain containers are converted element by element.
    @inline differential2zygote(x::$Container) = map(differential2zygote, x)

    # A `Tangent` backed by this container is canonicalized, unwrapped to its backing
    # container and then converted. `backing` is a ChainRulesCore internal, but it is
    # probably safe enough, and it is the fastest route.
    @inline function differential2zygote(x::Tangent{<:Any, <:$Container})
      return differential2zygote(ChainRulesCore.backing(canonicalize(x)))
    end
  end
end
# Zygote convention: a tuple made up entirely of AbstractZero partials (i.e. from a
# multi-input function) collapses to just one `nothing`.
@inline function differential2zygote(::Tuple{Vararg{AbstractZero}})
  return nothing
end
# Edge case split off from the method above: the empty tuple stays an empty tuple.
@inline function differential2zygote(::Tuple{})
  return ()
end

# Arrays of numbers, and arrays of arrays of numbers, are returned untouched.
function differential2zygote(dxs::AbstractArray{<:Number})
  return dxs
end
function differential2zygote(dxs::AbstractArray{<:AbstractArray{<:Number}})
  return dxs
end
# Any other array is converted element by element.
function differential2zygote(dxs::AbstractArray)
  return map(differential2zygote, dxs)
end
# Dictionaries keep their keys; only the values are converted.
function differential2zygote(dxs::Dict)
  return Dict(key => differential2zygote(val) for (key, val) in dxs)
end

# Type-level predicate, mostly used in rule genfuncs: true when `T` is `Nothing` or a
# subtype of `AbstractZero`.
function _iszerotype(T)
  T === Nothing && return true
  return T <: AbstractZero
end

# Note: safe piracy to make @adjoint definitions work.
# Every `gradtupleN` hands an AbstractZero back unchanged.
for gradtuple in (:gradtuple0, :gradtuple1, :gradtuple2, :gradtuple3)
  @eval ZygoteRules.$gradtuple(x::AbstractZero) = x
end
10 changes: 5 additions & 5 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ tailmemaybe(x::Tuple) = Base.tail(x)
@inline pullback(f, args...) = pullback(f, Context(), args...)
function pullback(f, cx::AContext, args...)
y, back = _pullback(cx, f, args...)
y, Δ -> tailmemaybe(back(Δ))
wrapped_back(Δ) = tailmemaybe(differential2zygote(back(Δ)))
y, wrapped_back
end
function pullback(cx::Context, f, args...)
ChainRulesCore.ignore_derivatives() do
Expand Down Expand Up @@ -95,7 +96,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
function gradient(f, args...)
y, back = pullback(f, args...)
grad = back(sensitivity(y))
isnothing(grad) ? nothing : map(_project, args, grad)
return _project(args, grad)
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Expand Down Expand Up @@ -131,8 +132,7 @@ julia> res.grad[w]
function withgradient(f, args...)
y, back = pullback(f, args...)
grad = back(sensitivity(y))
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
(val=y, grad=results)
(val=y, grad=_project(args, grad))
end

# Param-style wrappers
Expand Down Expand Up @@ -184,7 +184,7 @@ Params(xs::Tuple) = Params(collect(xs))

Base.in(x, ps::Params) = x in ps.params

Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)
_project(::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)

function Base.union!(ps::Params, itrs...)
foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
Expand Down
1 change: 1 addition & 0 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
# With a single trailing argument, that argument is returned as is.
@inline function tuple_va(N, xs)
  return xs
end
# With more arguments, the first is peeled off and the remainder handled recursively.
@inline function tuple_va(N, x, xs...)
  return (x, tuple_va(N, xs...)...)
end
# A lone `nothing` is expanded into an `N`-tuple of `nothing`s.
@inline function tuple_va(::Val{N}, ::Nothing) where N
  return ntuple(_ -> nothing, Val(N))
end
# A lone AbstractZero is likewise repeated `N` times.
@inline function tuple_va(::Val{N}, x::AbstractZero) where N
  return ntuple(_ -> x, Val(N))
end

iscall(x, m::Module, n::Symbol) = isexpr(x, :call) && x.args[1] == GlobalRef(m, n)

Expand Down
3 changes: 3 additions & 0 deletions src/compiler/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ function funcname(T)
end

Base.show(io::IO, j::Pullback{S}) where S = print(io, "∂($(funcname(S.parameters[1])))")
# Compact printing for Pullback *types*. When the first type parameter `S` is bound it is
# printed and the remaining parameters are elided as `...`; otherwise a generic
# placeholder is printed.
function Base.show(io::IO, P::Type{<:Pullback{S}}) where S
  if @isdefined(S)
    print(io, "Pullback{", S, ", ...}")
  else
    print(io, "Pullback{S, T}")
  end
end

56 changes: 28 additions & 28 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,6 @@ end
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)

@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)

∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
if inds isa NTuple{N,Int} && T <: Number
dx = OneElement(dy, inds, axes(x))
elseif inds isa NTuple{<:Any, Integer}
dx = _zero(x, typeof(dy))
dx[inds...] = dy
else
dx = _zero(x, eltype(dy))
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
return (_project(x, dx), map(_->nothing, inds)...)
end

"""
OneElement(val, ind, axes) <: AbstractArray

Expand Down Expand Up @@ -247,10 +229,10 @@ reconstruct_if_dict(x̄, _keys::Nothing) = x̄

function reconstruct_if_dict(x̄, _keys)
# This reverses `collect_if_dict`, which returns `_keys::Nothing` if x is not a Dict
@assert x̄ isa AbstractVector{<:Union{Nothing, NamedTuple{(:first,:second)}}}
@assert x̄ isa AbstractVector # {<:Union{Nothing, AbstractZero, NamedTuple{(:first,:second)}}}
# we don't compute gradients with respect to keys
# @assert all(x -> x === nothing || x[1] == 0 || x[1] === nothing, x̄)
d̄ = Dict(k => isnothing(x) ? nothing : x[2] for (x, k) in zip(x̄, _keys))
d̄ = Dict(k => x === nothing || x isa AbstractZero ? x : x[2] for (x, k) in zip(x̄, _keys))
return d̄
end

Expand Down Expand Up @@ -296,8 +278,9 @@ _ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)
nd = _ndims(xs[n])
dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
d += nd
first(dy)[n] === nothing && return nothing
init = zero.(first(dy)[n]) # allows for tuples, which accum can add:
dy_1n = first(dy)[n]
(dy_1n === nothing || dy_1n isa AbstractZero) && return dy_1n
init = zero.(dy_1n) # allows for tuples, which accum can add:
red = mapreduce(StaticGetter{n}(), accum, dy; dims=dims, init=init)
return _project(xs[n], reshape(red, axes(xs[n])))
end
Expand Down Expand Up @@ -332,8 +315,16 @@ function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs)
end

@adjoint real(x::AbstractArray) = real(x), r̄ -> (real(r̄),)
@adjoint conj(x::AbstractArray) = conj(x), r̄ -> (conj(r̄),)
# Adjoint of `real` on arrays: an AbstractZero cotangent is passed through unchanged,
# any other cotangent has `real` applied to it.
@adjoint function real(x::AbstractArray)
  function real_array_pullback(r̄)
    return (real(r̄),)
  end
  function real_array_pullback(r̄::AbstractZero)
    return (r̄,)
  end
  return real(x), real_array_pullback
end
# Adjoint of `conj` on arrays: an AbstractZero cotangent is passed through unchanged,
# any other cotangent is conjugated.
@adjoint function conj(x::AbstractArray)
  function conj_array_pullback(r̄)
    return (conj(r̄),)
  end
  function conj_array_pullback(r̄::AbstractZero)
    return (r̄,)
  end
  return conj(x), conj_array_pullback
end
# Adjoint of `imag` on arrays. An AbstractZero cotangent is forwarded unchanged, in the
# same way as the `real` and `conj` array adjoints, rather than being broadcast over.
# Any other cotangent `ī` yields `complex.(0, real.(ī))`.
@adjoint function imag(x::AbstractArray)
  imag_array_pullback(ī::AbstractZero) = (ī,)
  imag_array_pullback(ī) = (complex.(0, real.(ī)),)
  return imag(x), imag_array_pullback
end


Expand Down Expand Up @@ -445,6 +436,7 @@ _symmetric_back(Δ::LowerTriangular, uplo) = collect(uplo == 'U' ? transpose(Δ)

@adjoint function Symmetric(A::AbstractMatrix, uplo=:U)
S = Symmetric(A, uplo)
back(Δ::AbstractZero) = (Δ, nothing)
back(Δ::AbstractMatrix) = (_symmetric_back(Δ, S.uplo), nothing)
back(Δ::NamedTuple) = (_symmetric_back(Δ.data, S.uplo), nothing)
return S, back
Expand All @@ -469,15 +461,23 @@ end

# Adjoint of the `Hermitian` constructor. The pullback dispatches on the cotangent:
#   * AbstractZero   -> forwarded unchanged
#   * AbstractMatrix -> mapped through `_hermitian_back`
#   * NamedTuple     -> its `data` field is mapped through `_hermitian_back`
# The `uplo` argument never receives a gradient.
@adjoint function LinearAlgebra.Hermitian(A::AbstractMatrix, uplo=:U)
  H = Hermitian(A, uplo)
  function back(Δ::NamedTuple)
    return (_hermitian_back(Δ.data, H.uplo), nothing)
  end
  function back(Δ::AbstractMatrix)
    return (_hermitian_back(Δ, H.uplo), nothing)
  end
  function back(Δ::AbstractZero)
    return (Δ, nothing)
  end
  return H, back
end
end

@adjoint convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array} = convert(R, A),
Δ -> (nothing, convert(S, Δ),)
@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
Δ -> (convert(S, Δ),)
# Adjoint of converting a Hermitian/Symmetric wrapper to an `Array` type. The type
# argument gets no gradient. An AbstractZero cotangent is forwarded unchanged; any other
# cotangent is converted to `S`, the wrapper's second type parameter.
@adjoint function convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array}
  function convert_Array_HermOrSym_callback(Δ)
    return (nothing, convert(S, Δ))
  end
  function convert_Array_HermOrSym_callback(Δ::AbstractZero)
    return (nothing, Δ)
  end
  return convert(R, A), convert_Array_HermOrSym_callback
end

# Adjoint of `Matrix` applied to a Hermitian/Symmetric wrapper. An AbstractZero cotangent
# is forwarded unchanged; any other cotangent is converted to `S`, the wrapper's second
# type parameter.
@adjoint function Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S}
  function Matrix_HermOrSym_pullback(Δ)
    return (convert(S, Δ),)
  end
  function Matrix_HermOrSym_pullback(Δ::AbstractZero)
    return (Δ,)
  end
  return Matrix(A), Matrix_HermOrSym_pullback
end

@adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix)
X = lyap(A, C)
Expand Down
4 changes: 4 additions & 0 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ end
@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
bc_fwd_back(ȳ::AbstractZero) = ȳ
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
Expand All @@ -297,6 +298,7 @@ end
@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
bc_fwd_back(ȳ::AbstractZero) = ȳ
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))
Expand All @@ -311,6 +313,7 @@ end
@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
bc_fwd_back(ȳ::AbstractZero) = ȳ
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out))
Expand All @@ -335,6 +338,7 @@ end
@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
bc_fwd_back(ȳ::AbstractZero) = ȳ
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out))
Expand Down
33 changes: 22 additions & 11 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ end
first(xs), Δ -> ((Δ, drest...),)
end

@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)
# Adjoint of `Base.tail` on tuples. An AbstractZero cotangent is forwarded unchanged;
# otherwise a `nothing` is prepended for the dropped first element.
@adjoint function Base.tail(xs::Tuple)
  function Tuple_tail_pullback(x̄s)
    return ((nothing, x̄s...),)
  end
  function Tuple_tail_pullback(x̄s::AbstractZero)
    return (x̄s,)
  end
  return tail(xs), Tuple_tail_pullback
end

_empty(x) = length(x)
_empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x)
Expand Down Expand Up @@ -229,11 +233,14 @@ end
val = getfield(x, f)
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
# Const properties on modules are considered non-differentiable
x isa Module && isconst(x, f) && return
if isimmutable(x)
dx = (; nt_nothing(x)..., pair(Val(f), Δ, x)...)
(_project(x, dx), nothing)
else
dx = grad_mut(__context__, x)
# @show dx
dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...)
return (dx,nothing)
end
Expand Down Expand Up @@ -305,24 +312,28 @@ end
end

# TODO captured mutables + multiple calls to `back`
@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
!ismutabletype(T) && Δ == Nothing && return :nothing
Δ = G == Nothing ? :Δ :
Δ <: RefValue ? :(back.g[]) :
:(accum(back.g[], Δ))
# Pullback for `Jnew{T,G,false}`.
#
# Inside the generator, `Δ` is first the *type* of the incoming cotangent and is then
# rebound to the expression that produces the cotangent in the generated body:
#   * `G` is a zero type   -> the incoming cotangent `Δ` itself
#   * `Δ <: RefValue`      -> `back.g[]`
#   * otherwise            -> `accum(back.g[], Δ)`
# The generated code returns `nothing` followed by one entry per field of `T`, read off
# the chosen cotangent.
@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue,AbstractZero}) where {T,G}
  # Immutable `T` with a zero cotangent: pass the zero straight through.
  !ismutabletype(T) && _iszerotype(Δ) && return :Δ
  Δ = if _iszerotype(G)
    # An empty branch here would evaluate to `nothing` and generate `x̄ = nothing`,
    # so the incoming cotangent must be named explicitly.
    :Δ
  elseif Δ <: RefValue
    :(back.g[])
  else
    :(accum(back.g[], Δ))
  end
  quote
    x̄ = $Δ
    # When `G` is not a zero type, overwrite `back.g[]` with `nt_nothing` of the cotangent.
    $(_iszerotype(G) || :(back.g[] = nt_nothing($Δ)))
    (nothing, $(map(f -> :(x̄.$f), fieldnames(T))...))
  end
end

@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
!ismutabletype(T) && Δ == Nothing && return :nothing
Δ = G == Nothing ? :Δ : :(back.g)
@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue,AbstractZero}) where {T,G}
!ismutabletype(T) && _iszerotype(Δ) && return :Δ
Δ = _iszerotype(G) ? :Δ : :(back.g)
quote
x̄ = $Δ
$(G == Nothing || :($Δ = nt_nothing($Δ)))
$(_iszerotype(G) || :($Δ = nt_nothing($Δ)))
(nothing, ($(map(f -> :(x̄.$f), fieldnames(T))...),))
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ using Zygote: ZygoteRuleConfig
@test Zygote.gradient(f_notimplemented, 0.1) === (nothing,)
@test Zygote.gradient(x -> f_notimplemented(x[1]), 0.1) === (nothing,)
if isdefined(Base, :only)
@test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === (nothing,)
@test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
@test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === ((nothing,),)
@test_broken Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
end
end

Expand Down
7 changes: 4 additions & 3 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Zygote, Test, Random, LinearAlgebra, Statistics, SparseArrays, FillArrays,
AbstractFFTs, FFTW, Distances
AbstractFFTs, FFTW, Distances, ChainRulesCore
using Zygote: gradient
using Base.Broadcast: broadcast_shape
using Distributed: pmap, CachingPool, workers
Expand Down Expand Up @@ -38,6 +38,7 @@ _joinreim(A) = A

function _dropimaggrad(A)
back(Δ) = real(Δ)
back(Δ::AbstractZero) = Δ
back(Δ::Nothing) = nothing
return Zygote.hook(back, A)
end
Expand Down Expand Up @@ -174,11 +175,11 @@ end

# Ensure that nothings work with numeric types.
_, back = Zygote.pullback(getindex, randn(4), [1])
@test back([nothing]) == (zeros(4), nothing)
@test back([nothing]) === nothing

# Ensure that nothings work with non-numeric types.
_, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
@test back([nothing]) == (nothing, nothing)
@test back([nothing]) === nothing
end

@testset "view" begin
Expand Down
Loading