Skip to content

Commit b170521

Browse files
authored
Merge pull request #1006 from FluxML/ox/typeonlyrrules
use rrules even when all the arguments are types
2 parents 87e2f12 + b9f186f commit b170521

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/compiler/interface2.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ function edge!(m::IRTools.Meta, edge::Core.MethodInstance)
77
end
88

99
@generated function _pullback(ctx::AContext, f, args...)
10-
T = Tuple{f,args...}
11-
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))
12-
10+
# Try using ChainRulesCore
1311
if is_kwfunc(f, args...)
1412
# if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
1513
cr_T = Tuple{ZygoteRuleConfig{ctx}, args[2:end]...}
@@ -20,9 +18,12 @@ end
2018
end
2119

2220
hascr, cr_edge = has_chain_rrule(cr_T)
23-
2421
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))
2522

23+
# No ChainRule, going to have to work it out.
24+
T = Tuple{f,args...}
25+
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))
26+
2627
g = try _generate_pullback_via_decomposition(T) catch e e end
2728
g === nothing && return :(f(args...), Pullback{$T}((f,)))
2829
meta, forw, _ = g

test/chainrules.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,24 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
214214
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2), 10.4)
215215
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2; kw=2.0), 10.4)
216216
end
217+
218+
@testset "Type only rrule" begin
219+
struct StructForTestingTypeOnlyRRules{T}
220+
x::T
221+
end
222+
StructForTestingTypeOnlyRRules() = StructForTestingTypeOnlyRRules(1.0)
223+
224+
function ChainRulesCore.rrule(P::Type{<:StructForTestingTypeOnlyRRules})
225+
# notice here we mess with the primal doing 2.0 rather than 1.0, this is for testing purposes
226+
# and also because apparently people actually want to do this. Weird, but 🤷
227+
# https://github.com/SciML/SciMLBase.jl/issues/69#issuecomment-865639754
228+
P(2.0), _ -> (NoTangent(),)
229+
end
230+
231+
@assert StructForTestingTypeOnlyRRules().x == 1.0
232+
aug_primal_val, _ = Zygote.pullback(x->StructForTestingTypeOnlyRRules(), 1.2)
233+
@test aug_primal_val.x == 2.0
234+
end
217235
end
218236

219237
@testset "ChainRulesCore.rrule_via_ad" begin

0 commit comments

Comments
 (0)