Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ authors = ["Valentin Churavy <v.churavy@gmail.com>"]
version = "0.1.0"

[deps]
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
DifferentiationInterface = "0.7.14"
Enzyme = "0.13.50"
Krylov = "0.10.1"
LinearAlgebra = "1.10"
Expand Down
198 changes: 11 additions & 187 deletions src/Ariadne.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,204 +4,28 @@ export newton_krylov, newton_krylov!

using Krylov
using LinearAlgebra, SparseArrays
using Enzyme

##
# JacobianOperator
##
import LinearAlgebra: mul!

"""
    init_cache(x)

Return a zero-initialised copy of `x` (via `Enzyme.make_zero`) to be used as its shadow,
or `nothing` when `Enzyme.Compiler.guaranteed_const` reports that `typeof(x)` needs none.
"""
function init_cache(x)
    Enzyme.Compiler.guaranteed_const(typeof(x)) && return nothing
    return Enzyme.make_zero(x)
end

"""
    maybe_duplicated(x, x′)

Wrap `x` in the Enzyme annotation matching its cache: `Const(x)` when there is no shadow
(`x′ === nothing`), otherwise `Duplicated(x, x′)` after resetting the shadow with
`Enzyme.remake_zero!`.
"""
function maybe_duplicated(x::T, x′::Union{Nothing, T}) where {T}
    isnothing(x′) && return Const(x)
    # Reset the cached shadow so this call does not see leftovers from a previous one.
    Enzyme.remake_zero!(x′)
    return Duplicated(x, x′)
end

abstract type AbstractJacobianOperator end

# Interface (methods listed for an operator `J::AbstractJacobianOperator`):
# Base.size(J::AbstractJacobianOperator)
# Base.eltype(J::AbstractJacobianOperator)
# Base.length(J::AbstractJacobianOperator)
# mul!(out, J::AbstractJacobianOperator, v)
# LinearAlgebra.adjoint(J::AbstractJacobianOperator)
# LinearAlgebra.transpose(J::AbstractJacobianOperator)
# mul!(out, J′::Union{Adjoint{<:Any, <:AbstractJacobianOperator}, Transpose{<:Any, <:AbstractJacobianOperator}}, v)

"""
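    JacobianOperator

Matrix-free Jacobian of `f!(res, u, p)` with respect to `u`: an efficient implementation
of the products `J(f, u, p) * v` and `J(f, u, p)' * v`, both available through `mul!`.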
"""
struct JacobianOperator{F, F′, A, P, P′} <: AbstractJacobianOperator
    f::F # in-place residual function, called as F!(res, u, p)
    f′::F′ # cache: shadow of `f` built by `init_cache`, or `nothing`
    res::A # residual buffer; its length is the number of rows of the operator
    u::A # state variable; its length is the number of columns of the operator
    p::P # parameters handed to `f`
    p′::P′ # cache: shadow of `p` built by `init_cache`, or `nothing`
end

"""
JacobianOperator(f::F, res, u, p; assume_p_const::Bool = false)

Creates a Jacobian operator for `f!(res, u, p)` where `res` is the residual,
`u` is the state variable, and `p` are the parameters.

If `assume_p_const` is `true`, the parameters `p` are assumed to be constant
during the Jacobian computation, which can improve performance by not requiring the
shadow for `p`.
"""
function JacobianOperator(f::F, res, u, p; assume_p_const::Bool = false) where {F}
    # Shadows are allocated once here and reused by every product.
    shadow_f = init_cache(f)
    # Skipping the shadow marks `p` as constant for every later product.
    shadow_p = assume_p_const ? nothing : init_cache(p)
    return JacobianOperator(f, shadow_f, res, u, p, shadow_p)
end

# The non-batched operator always reports a batch size of one.
batch_size(::JacobianOperator) = 1

# Array-like queries: the operator behaves like a `length(res) × length(u)` matrix.
function Base.size(J::JacobianOperator)
    nrows = length(J.res)
    ncols = length(J.u)
    return (nrows, ncols)
end

function Base.eltype(J::JacobianOperator)
    return eltype(J.u)
end

function Base.length(J::JacobianOperator)
    return prod(size(J))
end

# Jacobian-vector product `out = J * v` through Enzyme's forward mode.
# `out` and `v` are reshaped to the shapes of `J.res` and `J.u`, so flat vectors work
# even when the residual/state are higher-dimensional arrays. Returns `nothing`.
function mul!(out, J::JacobianOperator, v)
    f = maybe_duplicated(J.f, J.f′)
    res = Duplicated(J.res, reshape(out, size(J.res)))
    u = Duplicated(J.u, reshape(v, size(J.u)))
    p = maybe_duplicated(J.p, J.p′)
    autodiff(Forward, f, Const, res, u, p)
    return nothing
end

# Lazy wrappers only; the products themselves are implemented by `mul!` on the wrapper.
function LinearAlgebra.adjoint(J::JacobianOperator)
    return Adjoint(J)
end

function LinearAlgebra.transpose(J::JacobianOperator)
    return Transpose(J)
end

# Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)
# or just reverse mode

# Product with the adjoint/transpose of the Jacobian through Enzyme's reverse mode.
# The result is accumulated into `out`; `v` is left untouched. Returns `nothing`.
function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}}, v)
    op = parent(J′)
    # If `out` is non-zero we might get spurious gradients
    fill!(out, 0)
    f = maybe_duplicated(op.f, op.f′)
    # TODO: provide cache for `copy(v)`
    # Enzyme zeros input derivatives and that confuses the solvers.
    seed = reshape(copy(v), size(op.res))
    res = Duplicated(op.res, seed)
    u = Duplicated(op.u, reshape(out, size(op.u)))
    p = maybe_duplicated(op.p, op.p′)
    autodiff(Reverse, f, Const, res, u, p)
    return nothing
end


"""
    init_cache(x, ::Val{N})

Batched variant: return an `N`-tuple of independent zero-initialised copies of `x`
(each made with `Enzyme.make_zero`), or `nothing` when
`Enzyme.Compiler.guaranteed_const` reports that `typeof(x)` needs no shadow.
"""
function init_cache(x, ::Val{N}) where {N}
    Enzyme.Compiler.guaranteed_const(typeof(x)) && return nothing
    return ntuple(_ -> Enzyme.make_zero(x), Val(N))
end

"""
    maybe_duplicated(x, x′, ::Val{N})

Batched variant: return `Const(x)` when there are no shadows (`x′ === nothing`),
otherwise reset the `N` shadows with `Enzyme.remake_zero!` and return
`BatchDuplicated(x, x′)`.
"""
function maybe_duplicated(x::T, x′::Union{Nothing, NTuple{N, T}}, ::Val{N}) where {T, N}
    isnothing(x′) && return Const(x)
    # Reset the cached shadows so this call does not see leftovers from a previous one.
    Enzyme.remake_zero!(x′)
    return BatchDuplicated(x, x′)
end

"""
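    BatchedJacobianOperator{N}

Jacobian operator for `f!(res, u, p)` that keeps `N` shadow copies of `f` and `p`, so
that `mul!` can process `N` vectors (the columns of a matrix) in a single batched call.
The `mul!` methods are only defined on Julia 1.11 and newer.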
"""
struct BatchedJacobianOperator{N, F, A, P} <: AbstractJacobianOperator
    f::F # F!(res, u, p)
    f′::Union{Nothing, NTuple{N, F}} # cache: `N` shadows of `f`, or `nothing`
    res::A
    u::A
    p::P
    p′::Union{Nothing, NTuple{N, P}} # cache: `N` shadows of `p`, or `nothing`
    # The batch size `N` is not derivable from the arguments, so it is supplied
    # explicitly: `BatchedJacobianOperator{N}(f, res, u, p)`.
    function BatchedJacobianOperator{N}(f::F, res, u, p) where {F, N}
        f′ = init_cache(f, Val(N))
        p′ = init_cache(p, Val(N))
        # `A` is taken from `u`; `res` has to be convertible to that same type.
        return new{N, F, typeof(u), typeof(p)}(f, f′, res, u, p, p′)
    end
end

# The batch size is carried in the type parameter `N`.
batch_size(::BatchedJacobianOperator{N}) where {N} = N
include("operators/enzyme.jl")
include("operators/di.jl")

# Array-like queries: the operator behaves like a `length(res) × length(u)` matrix.
function Base.size(J::BatchedJacobianOperator)
    nrows = length(J.res)
    ncols = length(J.u)
    return (nrows, ncols)
end

function Base.eltype(J::BatchedJacobianOperator)
    return eltype(J.u)
end

function Base.length(J::BatchedJacobianOperator)
    return prod(size(J))
end

# Lazy wrappers only; the products themselves are implemented by `mul!` on the wrapper.
function LinearAlgebra.adjoint(J::BatchedJacobianOperator)
    return Adjoint(J)
end

function LinearAlgebra.transpose(J::BatchedJacobianOperator)
    return Transpose(J)
end

if VERSION >= v"1.11.0"

"""
    tuple_of_vectors(M::Matrix, shape)

Split `M` into a tuple with one array per column, each reshaped to `shape`.
The arrays are wrapped around `M`'s own memory (no copy), so writing to them
writes to `M`.
"""
function tuple_of_vectors(M::Matrix{T}, shape) where {T}
    nrows, ncols = size(M)
    # Column `j` starts at linear index `nrows * (j - 1) + 1` of the column-major storage.
    column(j) = Base.wrap(Array, memoryref(M.ref, nrows * (j - 1) + 1), (nrows,))
    return ntuple(j -> reshape(column(j), shape), ncols)
end

# Batched Jacobian-vector product: column `i` of `Out` receives `J` applied to column `i`
# of `V`, all columns being pushed through one batched forward-mode call.
# Returns `nothing`.
function mul!(Out, J::BatchedJacobianOperator{N}, V) where {N}
    @assert size(Out, 2) == size(V, 2)
    tangents_out = tuple_of_vectors(Out, size(J.res))
    tangents_in = tuple_of_vectors(V, size(J.u))

    @assert N == length(tangents_out)
    f = maybe_duplicated(J.f, J.f′, Val(N))
    res = BatchDuplicated(J.res, tangents_out)
    u = BatchDuplicated(J.u, tangents_in)
    p = maybe_duplicated(J.p, J.p′, Val(N))
    autodiff(Forward, f, Const, res, u, p)
    return nothing
end

# Batched product with the adjoint/transpose of the Jacobian: all columns of `V` are
# pulled back through one batched reverse-mode call and accumulated into the columns of
# `Out`. `V` itself is left untouched. Returns `nothing`.
function mul!(Out, J′::Union{Adjoint{<:Any, <:BatchedJacobianOperator{N}}, Transpose{<:Any, <:BatchedJacobianOperator{N}}}, V) where {N}
    op = parent(J′)
    @assert size(Out, 2) == size(V, 2)

    # If `out` is non-zero we might get spurious gradients
    fill!(Out, 0)

    # TODO: provide cache for `copy(v)`
    # Enzyme zeros input derivatives and that confuses the solvers.
    seeds = copy(V)

    cotangents_out = tuple_of_vectors(Out, size(op.u))
    cotangents_in = tuple_of_vectors(seeds, size(op.res))

    @assert N == length(cotangents_out)

    f = maybe_duplicated(op.f, op.f′, Val(N))
    res = BatchDuplicated(op.res, cotangents_in)
    u = BatchDuplicated(op.u, cotangents_out)
    p = maybe_duplicated(op.p, op.p′, Val(N))
    autodiff(Reverse, f, Const, res, u, p)
    return nothing
end
end # VERSION >= v"1.11.0"
# `JacobianOperator` is kept as a name for the Enzyme-backed operator.
const JacobianOperator = EnzymeJacobianOperator

function Base.collect(JOp::Union{Adjoint{<:Any, <:AbstractJacobianOperator}, Transpose{<:Any, <:AbstractJacobianOperator}, AbstractJacobianOperator})
N, M = size(JOp)
Expand Down
44 changes: 44 additions & 0 deletions src/operators/di.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import DifferentiationInterface as DI

"""
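    DIJacobianOperator

Jacobian operator for `f!(res, u, p)` whose Jacobian-vector products are computed
through DifferentiationInterface with a user-selected AD backend.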
"""
struct DIJacobianOperator{F, A, P, Prep, B} <: AbstractJacobianOperator
    f::F # F!(res, u, p)
    res::A # residual buffer; its length is the number of rows of the operator
    u::A # state variable; its length is the number of columns of the operator
    p::P # parameters handed to `f`
    # `prep` and `backend` get their own type parameters: untyped fields are `Any`,
    # which would make every `mul!` call dynamically dispatched.
    prep::Prep # preparation object from `DI.prepare_pushforward`
    backend::B # AD backend passed to DifferentiationInterface
end

"""
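    DIJacobianOperator(backend, f::F, res, u, p)

Creates a Jacobian operator for `f!(res, u, p)` where `res` is the residual,
`u` is the state variable, and `p` are the parameters. `backend` is the AD backend
that DifferentiationInterface uses to prepare and evaluate the pushforward.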
"""
function DIJacobianOperator(backend, f::F, res, u, p) where {F}
    # Dummy tangent shaped like `u`; it is only used to prepare the pushforward.
    placeholder_tangent = zero(u)
    prep = DI.prepare_pushforward(
        f, res, backend, u, (placeholder_tangent,), DI.ConstantOrCache(p)
    )
    return DIJacobianOperator(f, res, u, p, prep, backend)
end

# Array-like queries: the operator behaves like a `length(res) × length(u)` matrix.
function Base.size(J::DIJacobianOperator)
    nrows = length(J.res)
    ncols = length(J.u)
    return (nrows, ncols)
end

function Base.eltype(J::DIJacobianOperator)
    return eltype(J.u)
end

function Base.length(J::DIJacobianOperator)
    return prod(size(J))
end

# Jacobian-vector product `out = J * v` through `DI.pushforward!`, reusing the
# preparation stored in `J.prep`. Returns `nothing`.
#
# `out` and `v` are reshaped to the shapes of `J.res` and `J.u`, as the Enzyme-backed
# `JacobianOperator` does. Without this, flat vectors handed over by a linear solver
# would not match the array shapes the pushforward was prepared with whenever
# `res`/`u` are not plain vectors. For vectors of matching length the reshape is a no-op.
function mul!(out, J::DIJacobianOperator, v)
    tangent_out = reshape(out, size(J.res))
    tangent_in = reshape(v, size(J.u)) # TODO: Must we zero this?
    DI.pushforward!(
        J.f,
        J.res,
        (tangent_out,),
        J.prep,
        J.backend,
        J.u,
        (tangent_in,),
        DI.ConstantOrCache(J.p)
    )
    return nothing
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Ariadne = "0be81120-40bf-4f8b-adf0-26103efb66f1"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
27 changes: 13 additions & 14 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ let x₀ = [3.0, 5.0]
@test stats.solved
end

import Ariadne: JacobianOperator, BatchedJacobianOperator
import Ariadne: JacobianOperator
using Enzyme, LinearAlgebra
using ADTypes

@testset "Jacobian" begin
@testset "Enzyme: JacobianOperator" begin
J_Enz = jacobian(Forward, x -> F(x, nothing), [3.0, 5.0]) |> only
J = JacobianOperator(F!, zeros(2), [3.0, 5.0], nothing)

Expand All @@ -52,19 +53,17 @@ using Enzyme, LinearAlgebra
@test out ≈ J_Enz * v

@test collect(transpose(J)) == transpose(collect(J))
end

# Batched
if VERSION >= v"1.11.0"
J = BatchedJacobianOperator{2}(F!, zeros(2), [3.0, 5.0], nothing)

V = [1.0 0.0; 0.0 1.0]
Out = similar(V)
mul!(Out, J, V)
@testset "DifferentiationInterface: JacobianOperator" begin
backend = ADTypes.AutoEnzyme()
J = Ariadne.DIJacobianOperator(backend, F!, zeros(2), [3.0, 5.0], nothing)

@test Out == J_Enz
@test size(J) == (2, 2)
@test length(J) == 4
@test eltype(J) == Float64

mul!(Out, transpose(J), V)
@test Out == J_Enz'
# @test Out == collect(transpose(J))
end
out = [NaN, NaN]
mul!(out, J, [1.0, 0.0])
@test out == [6.0, 7.38905609893065]
end
Loading