Skip to content

Commit 557b400

Browse files
authored
Merge pull request #8 from JuliaDiff/ox/optout
Don't return opted out rules. CRC 1.0 compat
2 parents 7033b82 + 4aaa97f commit 557b400

File tree

4 files changed

+52
-18
lines changed

4 files changed

+52
-18
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
name = "ChainRulesOverloadGeneration"
22
uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
3-
version = "0.1.3"
3+
version = "0.1.4"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77

88
[compat]
9-
ChainRulesCore = "0.10.4"
9+
ChainRulesCore = "1.0.0"
1010
julia = "1"
1111

1212
[extras]

src/ruleset_loading.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,38 @@ If you previously wrong an incorrect hook, you can use this to get rid of the ol
4848
"""
4949
clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind))
5050

51+
###########################################################################################
52+
5153
"""
5254
_rule_list(frule | rrule)
5355
5456
Returns a list of all the methods of the currently defined rules of the given kind.
55-
Excluding the fallback rule that returns `nothing` for every input;
56-
and excluding rules that require a particular `RuleConfig`.
57+
Excluding the fallback rule (that return `nothing` for every input) and `@opt_out` opted out
58+
rules, and excluding rules that require a particular `RuleConfig`.
5759
"""
58-
function _rule_list(rule_kind)
60+
function _rule_list(rule_kind::Union{typeof(frule), typeof(rrule)})
61+
opted_out = Set(arg_type_tuple(m.sig) for m in _no_rule_list(rule_kind))
5962
return Iterators.filter(methods(rule_kind)) do m
60-
return !_is_fallback(rule_kind, m) && !_requires_config(m)
63+
return !_requires_config(m) && arg_type_tuple(m.sig) opted_out
6164
end
6265
end
6366

64-
"check if this is the fallback-frule/rrule that always returns `nothing`"
65-
_is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Vararg{Any}}
66-
_is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}}
67-
6867
"check if this rule requires a particular configuation (`RuleConfig`)"
6968
_requires_config(m::Method) = m.sig <: Tuple{Any, RuleConfig, Vararg}
7069

7170

71+
_no_rule_list(::typeof(rrule)) = methods(ChainRulesCore.no_rrule)
72+
_no_rule_list(::typeof(frule)) = methods(ChainRulesCore.no_frule)
73+
74+
arg_type_tuple(d::DataType) = Tuple{d.parameters[2:end]...}
75+
function arg_type_tuple(d::UnionAll)
76+
body = Base.unwrap_unionall(d)
77+
body_tt = arg_type_tuple(body)
78+
return Base.rewrap_unionall(body_tt, d)
79+
end
80+
81+
######################################################################
82+
7283
const LAST_REFRESH_RRULE = Ref(0)
7384
const LAST_REFRESH_FRULE = Ref(0)
7485
last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE

test/ruleset_loading.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,6 @@
7070
end
7171
end
7272

73-
@testset "_is_fallback" begin
74-
_is_fallback = ChainRulesOverloadGeneration._is_fallback
75-
@test _is_fallback(rrule, first(methods(rrule, (Nothing,))))
76-
@test _is_fallback(frule, first(methods(frule, (Tuple{}, Nothing,))))
77-
end
78-
7973
@testset "_rule_list" begin
8074
_rule_list = ChainRulesOverloadGeneration._rule_list
8175
@testset "should not have frules that need RuleConfig" begin
@@ -112,5 +106,32 @@
112106
# Above would error if we were not handling UnionAll's right
113107
end
114108
end
109+
110+
111+
@testset "opting out" begin
112+
oa_id(x, y) = x
113+
@scalar_rule(oa_id(x::Number), 1)
114+
@opt_out ChainRulesCore.rrule(::typeof(oa_id), x::Float32)
115+
@opt_out ChainRulesCore.frule(::Any, ::typeof(oa_id), x::Float32)
116+
117+
# In theses tests we `@assert` the behavour that `methods` has
118+
# and then `@test` that `_rule_list` differs from that, in the way we want
119+
120+
@test !isempty([m for m in _rule_list(rrule) if m.sig <: Tuple{Any,typeof(oa_id),Number}])
121+
# Opted out
122+
@assert !isempty([m for m in methods(rrule) if m.sig <: Tuple{Any,typeof(oa_id),Float32}])
123+
@test isempty([m for m in _rule_list(rrule) if m.sig <: Tuple{Any,typeof(oa_id),Float32}])
124+
# fallback
125+
@test !isempty([m for m in methods(rrule) if m.sig == Tuple{typeof(rrule),Any,Vararg{Any}}])
126+
@test isempty([m for m in _rule_list(rrule) if m.sig == Tuple{typeof(rrule),Any,Vararg{Any}}])
127+
128+
@test !isempty([m for m in _rule_list(frule) if m.sig <: Tuple{Any,Any,typeof(oa_id),Number}])
129+
# Opted out
130+
@assert !isempty([m for m in methods(frule) if m.sig <: Tuple{Any,Any,typeof(oa_id),Float32}])
131+
@test isempty([m for m in _rule_list(frule) if m.sig <: Tuple{Any,Any,typeof(oa_id),Float32}])
132+
# fallback
133+
@assert !isempty([m for m in methods(frule) if m.sig == Tuple{typeof(frule),Any,Any,Vararg{Any}}])
134+
@test isempty([m for m in _rule_list(frule) if m.sig == Tuple{typeof(frule),Any,Any,Vararg{Any}}])
135+
end
115136
end
116137
end

test/runtests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ using ChainRulesOverloadGeneration
44
using Test
55

66
@testset "ChainRulesCore" begin
7-
include("ruleset_loading.jl")
8-
97
@testset "demos" begin
108
include("demos/forwarddiffzero.jl")
119
include("demos/reversediffzero.jl")
1210
end
11+
12+
# Do this after demos run, so that the simple demo code doesn't have to handle
13+
# anything weird we define for testing purposes
14+
include("ruleset_loading.jl")
1315
end

0 commit comments

Comments
 (0)