Skip to content

Commit 2b056ef

Browse files
committed
Arrayify methods
1 parent 4fd5436 commit 2b056ef

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

test/mooncake.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,21 @@ using VectorInterface
44
using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec
55
using Test, TestExtras
66
using Mooncake
7+
import Mooncake: arrayify
78
using Random
89

910
rng = Random.default_rng()
1011

1112
precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32))
1213
precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = sqrt(eps(Float64))
1314

15+
function Mooncake.arrayify(A_dA::Mooncake.CoDual{<:MinimalVec})
16+
return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).data.vec)
17+
end
18+
function Mooncake.arrayify(A_dA::Mooncake.Dual{<:MinimalVec})
19+
return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).fields.vec)
20+
end
21+
1422
eltypes = (Float32, Float64, ComplexF64)
1523

1624
@testset "scale ($T)" for T in eltypes

0 commit comments

Comments
 (0)