Skip to content

Commit 1853e69

Browse files
authored
Custom rule docs (#1796)
* Custom rule docs * fix * fix * more fix * fix * rules * fix
1 parent 949d1ad commit 1853e69

File tree

3 files changed

+141
-4
lines changed

3 files changed

+141
-4
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
34
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
45
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
56
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"

docs/src/index.md

Lines changed: 139 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,9 @@ julia> dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0];
145145
julia> autodiff(ForwardWithPrimal, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2)))
146146
((var"1" = -800.0, var"2" = 400.0), 400.0)
147147
```
148+
## Convenience functions (gradient, jacobian, hessian)
148149

149-
## Gradient Convenience functions
150+
### Gradient Convenience functions
150151

151152
!!! note
152153
While the convenience functions discussed below use [`autodiff`](@ref) internally, they are generally more limited in their functionality. Beyond that, these convenience functions may also come with performance penalties; especially if one makes a closure of a multi-argument function instead of calling the appropriate multi-argument [`autodiff`](@ref) function directly.
@@ -189,7 +190,7 @@ julia> # in forward mode, we can also optionally pass a chunk size
189190
([-400.0, 200.0],)
190191
```
191192

192-
## Jacobian Convenience functions
193+
### Jacobian Convenience functions
193194

194195
The function [`jacobian`](@ref) computes the Jacobian of a function vector input and vector return.
195196
Like [`autodiff`](@ref) and [`gradient`](@ref), the mode (forward or reverse) is determined by the first argument.
@@ -221,7 +222,7 @@ julia> jacobian(Forward, foo, [1.0, 2.0], chunk=Val(2))
221222
([-400.0 200.0; 2.0 1.0],)
222223
```
223224

224-
## Hessian Vector Product Convenience functions
225+
### Hessian Vector Product Convenience functions
225226

226227
Enzyme provides convenience functions for second-order derivative computations, like [`hvp`](@ref) to compute Hessian vector products. Mathematically, this computes $H(x) v$, where $H$ is the hessian operator.
227228

@@ -273,3 +274,138 @@ julia> grad
273274
2.880510859951098
274275
1.920340573300732
275276
```
277+
278+
## Defining rules
279+
280+
While Enzyme will automatically generate derivative functions for you, there may be instances in which it is necessary or helpful to define custom derivative rules. Enzyme has three primary ways for defining derivative rules: inactive annotations, [`EnzymeRules.@easy_rule`](@ref) macro definitions, general purpose derivative rules, and importing from `ChainRules`.
281+
282+
### Inactive Annotations
283+
284+
The simplest custom derivative is simply telling Enzyme that a given function does not need to be differentiated. For example, consider computing `det(Unitary Matrix)`. The determinant is always 1 so the derivative is always zero. Without this high level mathematical insight, the default rule Enzyme generates will add up a bunch of numbers that eventually come to zero. Instead of unnecessarily doing this work, we can just tell Enzyme that the derivative is always zero.
285+
286+
In autodiff-parlance we are telling Enzyme that the given result is `inactive` (aka makes no impact on the derivative). This can be done as follows:
287+
288+
```julia
289+
290+
# Our existing function and types
291+
struct UnitaryMatrix
292+
...
293+
end
294+
295+
det(::UnitaryMatrix) = ...
296+
297+
using Enzyme.EnzymeRules
298+
299+
EnzymeRules.inactive(::typeof(det), ::UnitaryMatrix) = true
300+
```
301+
302+
Specifically, we define a new overload of the method [`EnzymeRules.inactive`](@ref) where the first argument is the type of the function being marked inactive, and the corresponding arguments match the arguments we want to overload the method for. This enables us, for example, to only mark the determinant of the `UnitaryMatrix` class here as inactive, and not the determinant of a general Matrix.
303+
304+
Enzyme also supports a second way to mark things inactive, where the marker is "less strong" and not guaranteed to apply if other optimizations might otherwise simplify the code first.
305+
306+
```julia
307+
EnzymeRules.inactive_noinl(::typeof(det), ::UnitaryMatrix) = true
308+
```
309+
310+
### Easy Rules
311+
312+
The recommended way for writing rules for most use cases is through the [`EnzymeRules.@easy_rule`](@ref) macro. This macro enables users to write derivatives for any functions which only read from their arguments (e.g. do not overwrite memory), and has numbers, matricies of numbers, or tuples thereof as arguments/result types.
313+
314+
When writing an [`EnzymeRules.@easy_rule`](@ref) one first describes the function signature one wants the derivative rule to apply to. In each subsequent line, one should write a tuple, where each element of the tuple represents the derivative of the corresponding input argument. In that sense writing an [`EnzymeRules.@easy_rule`](@ref) is equivalent to specifying the Jacobian. Inside of this tuple, one can call arbitrary Julia code.
315+
316+
One can also define certain arguments as not having a derivative via `@Constant`.
317+
318+
For more information see the [`EnzymeRules.@easy_rule`](@ref) documentation.
319+
320+
```jldoctest easyrules
321+
using Enzyme
322+
323+
function f(x, y)
324+
return (x*x, cos(y) * x)
325+
end
326+
327+
Enzyme.EnzymeRules.@easy_rule(f(x,y),
328+
# df1/dx, #df1/dy
329+
(2*x, @Constant),
330+
# df2/dx, #df2/dy
331+
(cos(y), x * sin(y))
332+
)
333+
334+
function g(x, y)
335+
return f(x, y)[2]
336+
end
337+
338+
Enzyme.gradient(Reverse, g, 2.0, 3.0)
339+
340+
# output
341+
(-0.9899924966004454, 0.2822400161197344)
342+
```
343+
344+
Enzyme will automatically generate efficient derivatives for forward mode, reverse mode, batched forward and reverse mode, overwritten data, inactive inputs, and more from the given specification macro.
345+
346+
### General Purpose EnzymeRules
347+
348+
Finally Enzyme supports general-purpose EnzymeRules. For a given function, one can specify arbitrary behavior to occur when differentiting a given function. This is useful if you want to write efficient derivatives for mutating code, are handling funky behavior like GPU/distributed runtime calls, and more.
349+
350+
Like before, Enzyme takes a specification of the function the rule applies to, and passes various configuration data for full user-level customization.
351+
352+
```jldoctest genrules
353+
using Enzyme
354+
355+
function mysin(x)
356+
return sin(x)
357+
end
358+
359+
function Enzyme.EnzymeRules.forward(config, ::Const{typeof(mysin)}, ::Type, x)
360+
# If we don't need the original result, let's avoid computing it (and print)
361+
if !needs_primal(config)
362+
println("Avoiding computing sin!")
363+
return cos(x.val) * x.dval
364+
else
365+
println("Still computing sin")
366+
return Duplicated(sin(x.val), cos(x.val) * x.dval)
367+
end
368+
end
369+
370+
function mysquare(x)
371+
y = mysin(x)
372+
return y*y
373+
end
374+
375+
# Prints "Avoiding computing sin!"
376+
Enzyme.gradient(Forward, mysin, 2.0);
377+
378+
# Prints "Still computing sin =/" as d/dx sin(x)^2 = 2 * sin(x) * sin'(x)
379+
# so the original result is still needed
380+
Enzyme.gradient(Forward, mysquare, 2.0);
381+
382+
# output
383+
Avoiding computing sin!
384+
Still computing sin
385+
(-0.7568024953079283,)
386+
```
387+
388+
For more information, see [the custom rule docs](@ref custom_rules), [`EnzymeRules.forward`](@ref), [`EnzymeRules.augmented_primal`](@ref), and [`EnzymeRules.reverse`](@ref).
389+
390+
### Importing ChainRules
391+
392+
Enzyme can also import rules from the `ChainRules` ecosystem. This is often helpful when first getting started, though it will generally be much more efficient to write either an [`EnzymeRules.@easy_rule`](@ref) or general custom rule.
393+
394+
Enzyme can import the forward rule, reverse rule, or both.
395+
396+
```jldoctest chainrule
397+
using Enzyme, ChainRulesCore
398+
399+
f(x) = sin(x)
400+
ChainRulesCore.@scalar_rule f(x) (cos(x),)
401+
402+
# Import the reverse rule for float32
403+
Enzyme.@import_rrule typeof(f) Float32
404+
405+
# Import the forward rule for float32
406+
Enzyme.@import_frule typeof(f) Float32
407+
408+
# output
409+
```
410+
411+
See the docs on [`Enzyme.@import_frule`](@ref) and [`Enzyme.@import_rrule`](@ref) for more information.

examples/custom_rule.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# # Enzyme custom rules tutorial
1+
# # [Enzyme custom rules tutorial](@id custom_rules)
22
#
33
# !!! note "More Examples"
44
# The tutorial below focuses on a simple setting to illustrate the basic concepts of writing custom rules.

0 commit comments

Comments
 (0)