Skip to content

Commit 7033b82

Browse files
authored
Merge pull request #6 from JuliaDiff/ox/unionall
Handle rrule(f, x::T) where T
2 parents f73a842 + bcb8400 commit 7033b82

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesOverloadGeneration"
22
uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
3-
version = "0.1.2"
3+
version = "0.1.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ruleset_loading.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ _is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Var
6666
_is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}}
6767

6868
"check if this rule requires a particular configuation (`RuleConfig`)"
69-
_requires_config(m::Method) = m.sig.parameters[2] <: RuleConfig
69+
_requires_config(m::Method) = m.sig <: Tuple{Any, RuleConfig, Vararg}
70+
7071

7172
const LAST_REFRESH_RRULE = Ref(0)
7273
const LAST_REFRESH_FRULE = Ref(0)

test/ruleset_loading.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,27 @@
9090
end
9191

9292
@testset "should not have rrules that need RuleConfig" begin
93-
old_rrule_list = collect(_rule_list(rrule))
94-
function ChainRulesCore.rrule(
95-
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs
96-
)
97-
return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
93+
@testset "normal type sigs" begin
94+
old_rrule_list = collect(_rule_list(rrule))
95+
function ChainRulesCore.rrule(
96+
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs
97+
)
98+
return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
99+
end
100+
# New rule should not have appeared
101+
@test collect(_rule_list(rrule)) == old_rrule_list
102+
end
103+
@testset "UnionAll type sigs" begin
104+
old_rrule_list = collect(_rule_list(rrule))
105+
function ChainRulesCore.rrule(
106+
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f::F, xs
107+
) where F <: Function
108+
return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
109+
end
110+
# New rule should not have appeared
111+
@test collect(_rule_list(rrule)) == old_rrule_list
112+
# Above would error if we were not handling UnionAll's right
98113
end
99-
# New rule should not have appeared
100-
@test collect(_rule_list(rrule)) == old_rrule_list
101114
end
102115
end
103116
end

0 commit comments

Comments
 (0)