Skip to content

Commit 4fd5436

Browse files
committed
Loosen type rules for Mooncake
1 parent 252809f commit 4fd5436

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

ext/VectorInterfaceMooncakeExt.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ _needs_tangent(::Type{T}) where {T <: Number} =
2323

2424
# scale
2525
# -----
26-
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, Number}
27-
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number})
26+
@is_primitive DefaultCtx Tuple{typeof(scale!), Any, Number}
27+
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, α_Δα::CoDual{<:Number})
2828
# prepare arguments
2929
C, ΔC = arrayify(C_ΔC)
3030
α = primal(α_Δα)
@@ -43,7 +43,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra
4343
return C_ΔC, scale_pullback
4444
end
4545

46-
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, α_Δα::Dual{<:Number})
46+
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, α_Δα::Dual{<:Number})
4747
# prepare arguments
4848
C, ΔC = arrayify(C_ΔC)
4949
α, Δα = extract(α_Δα)
@@ -58,9 +58,9 @@ function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray},
5858
return C_ΔC
5959
end
6060

61-
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, AbstractArray, Number}
61+
@is_primitive DefaultCtx Tuple{typeof(scale!), Any, Any, Number}
6262

63-
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number})
63+
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number})
6464
# prepare arguments
6565
C, ΔC = arrayify(C_ΔC)
6666
A, ΔA = arrayify(A_ΔA)
@@ -81,7 +81,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra
8181
return C_ΔC, scale_pullback
8282
end
8383

84-
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number})
84+
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number})
8585
# prepare arguments
8686
C, ΔC = arrayify(C_ΔC)
8787
A, ΔA = arrayify(A_ΔA)
@@ -96,9 +96,9 @@ end
9696
# add
9797
# ---
9898

99-
@is_primitive DefaultCtx Tuple{typeof(add!), AbstractArray, AbstractArray, Number, Number}
99+
@is_primitive DefaultCtx Tuple{typeof(add!), Any, Any, Number, Number}
100100

101-
function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
101+
function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
102102
# prepare arguments
103103
C, ΔC = arrayify(C_ΔC)
104104
A, ΔA = arrayify(A_ΔA)
@@ -123,7 +123,7 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}
123123
return C_ΔC, add_pullback
124124
end
125125

126-
function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number})
126+
function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number})
127127
# prepare arguments
128128
C, ΔC = arrayify(C_ΔC)
129129
A, ΔA = arrayify(A_ΔA)
@@ -140,9 +140,9 @@ end
140140
# inner
141141
# -----
142142

143-
@is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray}
143+
@is_primitive DefaultCtx Tuple{typeof(inner), Any, Any}
144144

145-
function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray}, B_ΔB::CoDual{<:AbstractArray})
145+
function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual)
146146
# prepare arguments
147147
A, ΔA = arrayify(A_ΔA)
148148
B, ΔB = arrayify(B_ΔB)
@@ -159,7 +159,7 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray
159159
return CoDual(s, NoFData()), inner_pullback
160160
end
161161

162-
function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractArray}, B_ΔB::Dual{<:AbstractArray})
162+
function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual, B_ΔB::Dual)
163163
# prepare arguments
164164
A, ΔA = arrayify(A_ΔA)
165165
B, ΔB = arrayify(B_ΔB)

0 commit comments

Comments
 (0)