Skip to content

Commit 531da8b

Browse files
authored
Merge pull request #1002 from FluxML/ox/clean
Make interface2.jl code around generating pullbacks via decomposition
2 parents 13647cd + 8d9fac7 commit 531da8b

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/compiler/emit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ end
9595

9696
varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing
9797

98-
function _lookup_grad(T)
98+
function _generate_pullback_via_decomposition(T)
9999
(m = meta(T)) === nothing && return
100100
va = varargs(m.method, length(T.parameters))
101101
forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T)

src/compiler/interface2.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ end
2323

2424
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))
2525

26-
g = try _lookup_grad(T) catch e e end
27-
!(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
26+
g = try _generate_pullback_via_decomposition(T) catch e e end
27+
g === nothing && return :(f(args...), Pullback{$T}((f,)))
2828
meta, forw, _ = g
2929
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
3030
forw = varargs!(meta, forw, 3)
@@ -37,7 +37,8 @@ end
3737

3838
@generated function (j::Pullback{T})(Δ) where T
3939
ignore_sig(T) && return :nothing
40-
g = try _lookup_grad(T)
40+
g = try
41+
_generate_pullback_via_decomposition(T)
4142
catch e
4243
rethrow(CompileError(T,e))
4344
end

0 commit comments

Comments
 (0)