Skip to content

Commit f8bc27b

Browse files
authored
Loosen type rules for Mooncake (#38)
* Loosen type rules for Mooncake * Arrayify methods * Tighten isprimitive for AbstractArray
1 parent 252809f commit f8bc27b

2 files changed

Lines changed: 16 additions & 8 deletions

File tree

ext/VectorInterfaceMooncakeExt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ _needs_tangent(::Type{T}) where {T <: Number} =
2424
# scale
2525
# -----
2626
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, Number}
27-
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number})
27+
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, α_Δα::CoDual{<:Number})
2828
# prepare arguments
2929
C, ΔC = arrayify(C_ΔC)
3030
α = primal(α_Δα)
@@ -43,7 +43,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra
4343
return C_ΔC, scale_pullback
4444
end
4545

46-
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, α_Δα::Dual{<:Number})
46+
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, α_Δα::Dual{<:Number})
4747
# prepare arguments
4848
C, ΔC = arrayify(C_ΔC)
4949
α, Δα = extract(α_Δα)
@@ -60,7 +60,7 @@ end
6060

6161
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, AbstractArray, Number}
6262

63-
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number})
63+
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number})
6464
# prepare arguments
6565
C, ΔC = arrayify(C_ΔC)
6666
A, ΔA = arrayify(A_ΔA)
@@ -81,7 +81,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra
8181
return C_ΔC, scale_pullback
8282
end
8383

84-
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number})
84+
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number})
8585
# prepare arguments
8686
C, ΔC = arrayify(C_ΔC)
8787
A, ΔA = arrayify(A_ΔA)
@@ -98,7 +98,7 @@ end
9898

9999
@is_primitive DefaultCtx Tuple{typeof(add!), AbstractArray, AbstractArray, Number, Number}
100100

101-
function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
101+
function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
102102
# prepare arguments
103103
C, ΔC = arrayify(C_ΔC)
104104
A, ΔA = arrayify(A_ΔA)
@@ -123,7 +123,7 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}
123123
return C_ΔC, add_pullback
124124
end
125125

126-
function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number})
126+
function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number})
127127
# prepare arguments
128128
C, ΔC = arrayify(C_ΔC)
129129
A, ΔA = arrayify(A_ΔA)
@@ -142,7 +142,7 @@ end
142142

143143
@is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray}
144144

145-
function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray}, B_ΔB::CoDual{<:AbstractArray})
145+
function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual)
146146
# prepare arguments
147147
A, ΔA = arrayify(A_ΔA)
148148
B, ΔB = arrayify(B_ΔB)
@@ -159,7 +159,7 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray
159159
return CoDual(s, NoFData()), inner_pullback
160160
end
161161

162-
function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractArray}, B_ΔB::Dual{<:AbstractArray})
162+
function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual, B_ΔB::Dual)
163163
# prepare arguments
164164
A, ΔA = arrayify(A_ΔA)
165165
B, ΔB = arrayify(B_ΔB)

test/mooncake.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,21 @@ using VectorInterface
44
using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec
55
using Test, TestExtras
66
using Mooncake
7+
import Mooncake: arrayify
78
using Random
89

910
rng = Random.default_rng()
1011

1112
precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32))
1213
precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = sqrt(eps(Float64))
1314

15+
function Mooncake.arrayify(A_dA::Mooncake.CoDual{<:MinimalVec})
16+
return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).data.vec)
17+
end
18+
function Mooncake.arrayify(A_dA::Mooncake.Dual{<:MinimalVec})
19+
return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).fields.vec)
20+
end
21+
1422
eltypes = (Float32, Float64, ComplexF64)
1523

1624
@testset "scale ($T)" for T in eltypes

0 commit comments

Comments
 (0)