Skip to content

Commit cf73f95

Browse files
committed
Coverage and docs
1 parent b0e8697 commit cf73f95

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

DifferentiationInterface/src/utils/context.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,46 @@ adapt_eltype(c::ConstantOrCache, ::Type) = c
112112

113113
## Internal contexts for passing stuff around
114114

115+
"""
116+
FunctionContext
117+
118+
Private type of [`Context`](@ref) argument used for passing functions inside second-order differentiation.
119+
120+
Behaves differently for Enzyme only, where the function can be annotated.
121+
"""
115122
struct FunctionContext{T} <: GeneralizedConstant
116123
data::T
117124
end
118125

126+
"""
127+
BackendContext
128+
129+
Private type of [`Context`](@ref) argument used for passing backends inside second-order differentiation.
130+
"""
119131
struct BackendContext{T} <: GeneralizedConstant
120132
data::T
121133
end
122134

135+
"""
136+
PrepContext
137+
138+
Private type of [`Context`](@ref) argument used for passing preparation results inside second-order differentiation.
139+
140+
Conceptually similar to [`ConstantOrCache`](@ref) because we assume that preparation was performed with the right types so we don't change anything.
141+
"""
123142
struct PrepContext{T} <: GeneralizedConstantOrCache
124143
data::T
125144
end
126145

127146
## Context manipulation
128147

148+
"""
149+
Rewrap
150+
151+
Utility for recording context types of additional arguments (e.g. `Constant` or `Cache`) and re-wrapping them into their types after they have been unwrapped.
152+
153+
Useful for second-order differentiation.
154+
"""
129155
struct Rewrap{C,T}
130156
context_makers::T
131157
function Rewrap(contexts::Vararg{Context,C}) where {C}
@@ -144,6 +170,14 @@ end
144170

145171
## Closures
146172

173+
"""
174+
FixTail
175+
176+
Closure around a function `f` and a set of tail argument `tail_args` such that
177+
```
178+
(ft::FixTail)(args...) = ft.f(args..., ft.tail_args...)
179+
```
180+
"""
147181
struct FixTail{F,A<:Tuple}
148182
f::F
149183
tail_args::A
@@ -156,5 +190,10 @@ function (ft::FixTail)(args::Vararg{Any,N}) where {N}
156190
return ft.f(args..., ft.tail_args...)
157191
end
158192

193+
"""
194+
fix_tail(f, tail_args...)
195+
196+
Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are no tail arguments.
197+
"""
159198
@inline fix_tail(f::F) where {F} = f
160199
fix_tail(f::F, args::Vararg{Any,N}) where {F,N} = FixTail(f, args...)

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ end
6767
)
6868

6969
test_differentiation(
70-
second_order_hvp_backends;
70+
second_order_hvp_backends,
71+
default_scenarios(; include_constantorcachified=true);
7172
excluded=vcat(FIRST_ORDER, :hessian, :second_derivative),
7273
logging=LOGGING,
7374
)

0 commit comments

Comments
 (0)