Skip to content

Commit f73a842

Browse files
authored
Merge pull request #4 from JuliaDiff/ox/config
Handle RuleConfig and move to new world
2 parents 9bf7935 + 03b8233 commit f73a842

File tree

7 files changed

+48
-18
lines changed

7 files changed

+48
-18
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
name = "ChainRulesOverloadGeneration"
22
uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
3-
version = "0.1.1"
3+
version = "0.1.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77

88
[compat]
9-
ChainRulesCore = "0.9, 0.10"
9+
ChainRulesCore = "0.10.4"
1010
julia = "1"
1111

1212
[extras]

docs/Manifest.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[ChainRulesCore]]
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
14+
git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba"
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.9.44"
16+
version = "0.10.4"
1717

1818
[[ChainRulesOverloadGeneration]]
1919
deps = ["ChainRulesCore"]
2020
path = ".."
2121
uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
22-
version = "0.1.0"
22+
version = "0.1.2"
2323

2424
[[Compat]]
2525
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -40,10 +40,10 @@ deps = ["Random", "Serialization", "Sockets"]
4040
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
4141

4242
[[DocStringExtensions]]
43-
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
44-
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
43+
deps = ["LibGit2"]
44+
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
4545
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
46-
version = "0.8.4"
46+
version = "0.8.5"
4747

4848
[[DocThemeIndigo]]
4949
deps = ["Sass"]

src/ruleset_loading.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,21 @@ clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind))
5252
_rule_list(frule | rrule)
5353
5454
Returns a list of all the methods of the currently defined rules of the given kind.
55-
Excluding the fallback rule that returns `nothing` for every input.
55+
Excluding the fallback rule that returns `nothing` for every input;
56+
and excluding rules that require a particular `RuleConfig`.
5657
"""
57-
function _rule_list end
58-
_rule_list(rule_kind) = (m for m in methods(rule_kind) if !_is_fallback(rule_kind, m))
58+
function _rule_list(rule_kind)
59+
return Iterators.filter(methods(rule_kind)) do m
60+
return !_is_fallback(rule_kind, m) && !_requires_config(m)
61+
end
62+
end
5963

6064
"check if this is the fallback-frule/rrule that always returns `nothing`"
61-
_is_fallback(rule_kind, m::Method) = m.sig === Tuple{typeof(rule_kind), Any, Vararg{Any}}
65+
_is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Vararg{Any}}
66+
_is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}}
67+
68+
"check if this rule requires a particular configuation (`RuleConfig`)"
69+
_requires_config(m::Method) = m.sig.parameters[2] <: RuleConfig
6270

6371
const LAST_REFRESH_RRULE = Ref(0)
6472
const LAST_REFRESH_FRULE = Ref(0)

test/demos/forwarddiffzero.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function define_dual_overload(sig)
4444
# we use the function call overloading form as it lets us avoid namespacing issues
4545
# as we can directly interpolate the function type into to the AST.
4646
function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...)
47-
ȧrgs = (NO_FIELDS, partial.(dual_args)...)
47+
ȧrgs = (NoTangent(), partial.(dual_args)...)
4848
args = (op, primal.(dual_args)...)
4949
y, ẏ = frule(ȧrgs, args...; kwargs...)
5050
return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error.

test/demos/reversediffzero.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ end
116116
function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number)
117117
function times_pullback(ΔΩ)
118118
# we will use thunks here to show we handle them fine.
119-
return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
119+
return (NoTangent(), @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
120120
end
121121
return x * y, times_pullback
122122
end

test/ruleset_loading.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,31 @@
7373
@testset "_is_fallback" begin
7474
_is_fallback = ChainRulesOverloadGeneration._is_fallback
7575
@test _is_fallback(rrule, first(methods(rrule, (Nothing,))))
76-
@test _is_fallback(frule, first(methods(frule, (Nothing,))))
76+
@test _is_fallback(frule, first(methods(frule, (Tuple{}, Nothing,))))
77+
end
78+
79+
@testset "_rule_list" begin
80+
_rule_list = ChainRulesOverloadGeneration._rule_list
81+
@testset "should not have frules that need RuleConfig" begin
82+
old_frule_list = collect(_rule_list(frule))
83+
function ChainRulesCore.frule(
84+
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, dargs, sum, f, xs
85+
)
86+
return 1.0, 1.0 # this will not be call so return doesn't matter
87+
end
88+
# New rule should not have appeared
89+
@test collect(_rule_list(frule)) == old_frule_list
90+
end
91+
92+
@testset "should not have rrules that need RuleConfig" begin
93+
old_rrule_list = collect(_rule_list(rrule))
94+
function ChainRulesCore.rrule(
95+
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs
96+
)
97+
return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
98+
end
99+
# New rule should not have appeared
100+
@test collect(_rule_list(rrule)) == old_rrule_list
101+
end
77102
end
78103
end

test/runtests.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
using ChainRulesCore
22
using ChainRulesOverloadGeneration
3-
# resolve conflicts while this code exists in both.
4-
const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
5-
const refresh_rules = ChainRulesOverloadGeneration.refresh_rules
63

74
using Test
85

0 commit comments

Comments
 (0)