Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.46"
version = "0.6.47"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterface/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ DifferentiationInterface
Context
Constant
Cache
ConstantOrCache
```

## First order
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f)
end

@inline function _translate(
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.PrepContext}
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.GeneralizedConstantOrCache}
) where {B}
if B == 1
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function DI.prepare_pushforward_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
cache = if x isa Number || y isa Number
nothing
Expand Down Expand Up @@ -89,7 +89,7 @@ function DI.pushforward(
) where {SIG,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
ty = map(tx) do dx
finite_difference_jvp(fc, x, dx, prep.cache; relstep, absstep, dir)
end
Expand All @@ -106,7 +106,7 @@ function DI.value_and_pushforward(
) where {SIG,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
ty = map(tx) do dx
finite_difference_jvp(fc, x, dx, prep.cache, y; relstep, absstep, dir)
Expand All @@ -128,7 +128,7 @@ function DI.prepare_derivative_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
cache = if y isa Number
nothing
Expand Down Expand Up @@ -161,7 +161,7 @@ function DI.derivative(
) where {SIG,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return finite_difference_derivative(fc, x, fdtype(backend); relstep, absstep, dir)
end

Expand All @@ -174,7 +174,7 @@ function DI.value_and_derivative(
) where {SIG,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
return (
y,
Expand All @@ -195,7 +195,7 @@ function DI.derivative(
) where {SIG,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
end

Expand All @@ -209,7 +209,7 @@ function DI.derivative!(
) where {SIG,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir)
end

Expand All @@ -221,7 +221,7 @@ function DI.value_and_derivative(
contexts::Vararg{DI.Context,C},
) where {SIG,C}
DI.check_prep(f, prep, backend, x, contexts...)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
(; relstep, absstep, dir) = prep
y = fc(x)
return (y, finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir))
Expand All @@ -237,7 +237,7 @@ function DI.value_and_derivative!(
) where {SIG,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return (
fc(x), finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir)
)
Expand All @@ -257,7 +257,7 @@ function DI.prepare_gradient_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
df = zero(y) .* x
cache = GradientCache(df, x, fdtype(backend))
Expand All @@ -284,7 +284,7 @@ function DI.gradient(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
end

Expand All @@ -297,7 +297,7 @@ function DI.value_and_gradient(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return fc(x), finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
end

Expand All @@ -311,7 +311,7 @@ function DI.gradient!(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir)
end

Expand All @@ -325,7 +325,7 @@ function DI.value_and_gradient!(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return (
fc(x), finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir)
)
Expand All @@ -345,7 +345,7 @@ function DI.prepare_jacobian_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
x1 = similar(x)
fx = similar(y)
Expand Down Expand Up @@ -374,7 +374,7 @@ function DI.jacobian(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return finite_difference_jacobian(fc, x, prep.cache; relstep, absstep, dir)
end

Expand All @@ -386,7 +386,7 @@ function DI.value_and_jacobian(
contexts::Vararg{DI.Context,C},
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
(; relstep, absstep, dir) = prep
y = fc(x)
return (y, finite_difference_jacobian(fc, x, prep.cache, y; relstep, absstep, dir))
Expand All @@ -402,7 +402,7 @@ function DI.jacobian!(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return copyto!(
jac,
finite_difference_jacobian(
Expand All @@ -421,7 +421,7 @@ function DI.value_and_jacobian!(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
return (
y,
Expand Down Expand Up @@ -450,7 +450,7 @@ function DI.prepare_hessian_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
df = zero(y) .* x
gradient_cache = GradientCache(df, x, fdtype(backend))
Expand Down Expand Up @@ -481,7 +481,7 @@ function DI.hessian(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep_h, absstep_h) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return finite_difference_hessian(
fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h
)
Expand All @@ -497,7 +497,7 @@ function DI.hessian!(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep_h, absstep_h) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return finite_difference_hessian!(
hess, fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h
)
Expand All @@ -512,7 +512,7 @@ function DI.value_gradient_and_hessian(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep_g, absstep_g, relstep_h, absstep_h) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
grad = finite_difference_gradient(
fc, x, prep.gradient_cache; relstep=relstep_g, absstep=absstep_g
)
Expand All @@ -533,7 +533,7 @@ function DI.value_gradient_and_hessian!(
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
(; relstep_g, absstep_g, relstep_h, absstep_h) = prep
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
finite_difference_gradient!(
grad, fc, x, prep.gradient_cache; relstep=relstep_g, absstep=absstep_g
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function DI.pushforward(
) where {SIG,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
ty = map(tx) do dx
dy = similar(y)
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir)
Expand All @@ -100,7 +100,7 @@ function DI.value_and_pushforward(
) where {SIG,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
ty = map(tx) do dx
dy = similar(y)
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir)
Expand All @@ -122,7 +122,7 @@ function DI.pushforward!(
) where {SIG,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
for b in eachindex(tx, ty)
dx, dy = tx[b], ty[b]
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir)
Expand All @@ -142,7 +142,7 @@ function DI.value_and_pushforward!(
) where {SIG,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
for b in eachindex(tx, ty)
dx, dy = tx[b], ty[b]
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir)
Expand Down Expand Up @@ -214,7 +214,7 @@ function DI.value_and_derivative(
) where {C}
DI.check_prep(f!, y, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
fc!(y, x)
der = finite_difference_gradient(fc!, x, prep.cache; relstep, absstep, dir)
return y, der
Expand All @@ -231,7 +231,7 @@ function DI.value_and_derivative!(
) where {C}
DI.check_prep(f!, y, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
fc!(y, x)
finite_difference_gradient!(der, fc!, x, prep.cache; relstep, absstep, dir)
return y, der
Expand All @@ -247,7 +247,7 @@ function DI.derivative(
) where {C}
DI.check_prep(f!, y, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
fc!(y, x)
der = finite_difference_gradient(fc!, x, prep.cache; relstep, absstep, dir)
return der
Expand All @@ -264,7 +264,7 @@ function DI.derivative!(
) where {C}
DI.check_prep(f!, y, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
finite_difference_gradient!(der, fc!, x, prep.cache; relstep, absstep, dir)
return der
end
Expand Down Expand Up @@ -336,7 +336,7 @@ function DI.value_and_jacobian(
) where {C}
DI.check_prep(f!, y, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
jac = similar(y, length(y), length(x))
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir)
fc!(y, x)
Expand All @@ -354,7 +354,7 @@ function DI.value_and_jacobian!(
) where {C}
DI.check_prep(f!, y, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir)
fc!(y, x)
return y, jac
Expand All @@ -370,7 +370,7 @@ function DI.jacobian(
) where {C}
DI.check_prep(f!, y, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
jac = similar(y, length(y), length(x))
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir)
return jac
Expand All @@ -387,7 +387,7 @@ function DI.jacobian!(
) where {C}
DI.check_prep(f!, y, prep, backend, x, contexts...)
(; relstep, absstep, dir) = prep
fc! = DI.with_contexts(f!, contexts...)
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir)
return jac
end
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function DI.pushforward(
contexts::Vararg{DI.Context,C},
) where {C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
ty = map(tx) do dx
jvp(backend.fdm, fc, (x, dx))
end
Expand Down Expand Up @@ -75,7 +75,7 @@ function DI.pullback(
contexts::Vararg{DI.Context,C},
) where {C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
tx = map(ty) do dy
only(j′vp(backend.fdm, fc, dy, x))
end
Expand Down Expand Up @@ -112,7 +112,7 @@ function DI.gradient(
contexts::Vararg{DI.Context,C},
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return only(grad(backend.fdm, fc, x))
end

Expand Down Expand Up @@ -169,7 +169,7 @@ function DI.jacobian(
contexts::Vararg{DI.Context,C},
) where {C}
DI.check_prep(f, prep, backend, x, contexts...)
fc = DI.with_contexts(f, contexts...)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
return only(jacobian(backend.fdm, fc, x))
end

Expand Down
Loading