While the convenience functions discussed below use [`autodiff`](@ref) internally, they are generally more limited in their functionality. They may also come with performance penalties, especially if one makes a closure of a multi-argument function instead of calling the appropriate multi-argument [`autodiff`](@ref) function directly.
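To make the closure point concrete, here is a minimal sketch that differentiates the same function both ways. The function `f` and the helper names `grad_via_closure` and `grad_via_autodiff` are hypothetical and only serve as illustration; the sketch says nothing about how large the performance difference is in practice.

```julia
using Enzyme

# Hypothetical two-argument function used only for illustration.
f(x, y) = x^2 * y

# Convenience path: close over `y` so that `gradient` sees a one-argument function.
grad_via_closure(x, y) = gradient(Reverse, z -> f(z, y), x)[1]

# Direct path: pass both arguments to `autodiff` and annotate `y` as constant.
# The first entry of the result holds one derivative per argument
# (`nothing` for `Const` arguments).
grad_via_autodiff(x, y) = autodiff(Reverse, f, Active, Active(x), Const(y))[1][1]

d_closure = grad_via_closure(2.0, 3.0)
d_direct = grad_via_autodiff(2.0, 3.0)
```

Both calls should agree on the derivative with respect to `x`, which is `2 * x * y` analytically.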
@@ -189,7 +190,7 @@ julia> # in forward mode, we can also optionally pass a chunk size
([-400.0, 200.0],)
```

-## Jacobian Convenience functions
+### Jacobian Convenience functions

The function [`jacobian`](@ref) computes the Jacobian of a function with vector input and vector return.
Like [`autodiff`](@ref) and [`gradient`](@ref), the mode (forward or reverse) is determined by the first argument.
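As a small sketch of the call shape, the example below computes a forward-mode Jacobian of a hypothetical function `h` (not from the Enzyme docs). It assumes that recent Enzyme releases wrap the result in a tuple, as `gradient` does in the output shown above, while older releases return the matrix directly; the `unwrap` helper is an illustrative workaround that handles both.

```julia
using Enzyme

# Hypothetical vector-to-vector function R^2 -> R^2, used only for illustration.
h(x) = [x[1] * x[2], x[2]^2]

x = [2.0, 3.0]

# The first argument selects the mode, as with `autodiff` and `gradient`.
res = jacobian(Forward, h, x)

# Assumption: newer versions return a tuple with one Jacobian per argument,
# older versions return the matrix itself.
unwrap(r) = r isa Tuple ? r[1] : r

# Row i holds the derivatives of output i with respect to each input.
J = unwrap(res)
```

Analytically, the Jacobian of `h` is `[x[2] x[1]; 0 2x[2]]`.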
Enzyme provides convenience functions for second-order derivative computations, like [`hvp`](@ref) to compute Hessian-vector products. Mathematically, this computes $H(x) v$, where $H$ is the Hessian operator.
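A minimal sketch of a Hessian-vector product follows, assuming the three-argument form `hvp(f, x, v)`. The function `q` is hypothetical and chosen so that the result is easy to check by hand.

```julia
using Enzyme

# Hypothetical scalar-valued function used only for illustration.
q(x) = x[1]^2 * x[2]

x = [1.0, 2.0]
v = [1.0, 0.0]

# Computes H(x) * v without forming the full Hessian H(x).
Hv = Enzyme.hvp(q, x, v)
```

By hand, the Hessian of `q` is `[2x[2] 2x[1]; 2x[1] 0]`, so with `v = [1, 0]` the product is its first column.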
@@ -273,3 +274,138 @@ julia> grad
 2.880510859951098
 1.920340573300732
```
+
+## Defining rules
+
+While Enzyme will automatically generate derivative functions for you, there may be instances in which it is necessary or helpful to define custom derivative rules. Enzyme has four primary ways of defining derivative rules: inactive annotations, [`EnzymeRules.@easy_rule`](@ref) macro definitions, general-purpose derivative rules, and importing from `ChainRules`.
+
+### Inactive Annotations
+
+The simplest custom derivative rule is telling Enzyme that a given function does not need to be differentiated. For example, consider computing `det(Unitary Matrix)`. The determinant is always 1, so the derivative is always zero. Without this high-level mathematical insight, the default rule Enzyme generates will add up a bunch of numbers that eventually come to zero. Instead of unnecessarily doing this work, we can just tell Enzyme that the derivative is always zero.
+
+In autodiff parlance, we are telling Enzyme that the given result is `inactive` (i.e., it has no impact on the derivative). This can be done as follows:
+Specifically, we define a new method of [`EnzymeRules.inactive`](@ref) where the first argument is the type of the function being marked inactive, and the remaining arguments match the arguments we want the rule to apply to. This enables us, for example, to mark only the determinant of the `UnitaryMatrix` type here as inactive, and not the determinant of a general `Matrix`.
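The same overload pattern can be sketched on a simpler scalar function. Everything named here (`unit_scale`, `model`) is hypothetical and not taken from the Enzyme docs, and the sketch assumes that the existence of the `inactive` method is what marks the function, so the method simply returns `nothing`.

```julia
using Enzyme

# Hypothetical helper: always returns +1.0 or -1.0 for nonzero input,
# so its derivative is zero wherever it is defined.
unit_scale(x) = x / abs(x)

# Mark `unit_scale` as inactive for `Float64` arguments only; other argument
# types are still differentiated as usual.
# Assumption: only the existence of this method matters, hence `nothing`.
Enzyme.EnzymeRules.inactive(::typeof(unit_scale), ::Float64) = nothing

model(x) = x^2 * unit_scale(x)

# d/dx model(x) = 2x * unit_scale(x) for x != 0
dx = autodiff(Reverse, model, Active, Active(2.0))[1][1]
```

Because the derivative of `unit_scale` is mathematically zero, the annotation should not change the result here; it only spares Enzyme from differentiating through the division.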
+
+Enzyme also supports a second way to mark things inactive, where the marker is "less strong" and not guaranteed to apply if other optimizations might otherwise simplify the code first.
+The recommended way to write rules for most use cases is the [`EnzymeRules.@easy_rule`](@ref) macro. This macro enables users to write derivatives for any function that only reads from its arguments (i.e., does not overwrite memory) and has numbers, matrices of numbers, or tuples thereof as argument and result types.
+
+When writing an [`EnzymeRules.@easy_rule`](@ref), one first describes the function signature the derivative rule should apply to. Each subsequent line is a tuple for one output, where each element of the tuple is the derivative of that output with respect to the corresponding input argument. In that sense, writing an [`EnzymeRules.@easy_rule`](@ref) is equivalent to specifying the Jacobian. Inside this tuple, one can call arbitrary Julia code.
+
+One can also define certain arguments as not having a derivative via `@Constant`.
+
+For more information see the [`EnzymeRules.@easy_rule`](@ref) documentation.
+
+```jldoctest easyrules
+using Enzyme
+
+function f(x, y)
+    return (x*x, cos(y) * x)
+end
+
+Enzyme.EnzymeRules.@easy_rule(f(x,y),
+    # df1/dx, df1/dy
+    (2*x, @Constant),
+    # df2/dx, df2/dy (note the minus sign: d/dy cos(y) = -sin(y))
+    (cos(y), -x * sin(y))
+)
+
+function g(x, y)
+    return f(x, y)[2]
+end
+
+Enzyme.gradient(Reverse, g, 2.0, 3.0)
+
+# output
+(-0.9899924966004454, -0.2822400161197344)
+```
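Read as a matrix, the tuples passed to the macro are the rows of the Jacobian of `f`. For `f(x, y) = (x^2, x cos(y))` as defined above, the Jacobian worked out by hand (not taken from the Enzyme docs) is:

```math
J_f(x, y) =
\begin{pmatrix}
\partial f_1 / \partial x & \partial f_1 / \partial y \\
\partial f_2 / \partial x & \partial f_2 / \partial y
\end{pmatrix}
=
\begin{pmatrix}
2x & 0 \\
\cos(y) & -x \sin(y)
\end{pmatrix}
```

Since `g` selects the second output of `f`, its gradient is the second row of this matrix, which at `(2.0, 3.0)` evaluates to `(cos(3), -2 sin(3))`.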
+
+Enzyme will automatically generate efficient derivatives for forward mode, reverse mode, batched forward and reverse mode, overwritten data, inactive inputs, and more from the given specification macro.
+
+### General Purpose EnzymeRules
+
+Finally, Enzyme supports general-purpose EnzymeRules. For a given function, one can specify arbitrary behavior to occur when differentiating it. This is useful if you want to write efficient derivatives for mutating code, need to handle unusual behavior like GPU or distributed runtime calls, and more.
+
+Like before, Enzyme takes a specification of the function the rule applies to, and passes various configuration data for full user-level customization.
+
+```jldoctest genrules
+using Enzyme
+
+function mysin(x)
+    return sin(x)
+end
+
+function Enzyme.EnzymeRules.forward(config, ::Const{typeof(mysin)}, ::Type, x)
+    # If we don't need the original result, let's avoid computing it (and print)
+# Prints "Still computing sin =/" as d/dx sin(x)^2 = 2 * sin(x) * sin'(x)
+# so the original result is still needed
+Enzyme.gradient(Forward, mysquare, 2.0);
+
+# output
+Avoiding computing sin!
+Still computing sin
+(-0.7568024953079283,)
+```
+
+For more information, see [the custom rule docs](@ref custom_rules), [`EnzymeRules.forward`](@ref), [`EnzymeRules.augmented_primal`](@ref), and [`EnzymeRules.reverse`](@ref).
+
+### Importing ChainRules
+
+Enzyme can also import rules from the `ChainRules` ecosystem. This is often helpful when first getting started, though it will generally be much more efficient to write either an [`EnzymeRules.@easy_rule`](@ref) or general custom rule.
+
+Enzyme can import the forward rule, reverse rule, or both.
+
+```jldoctest chainrule
+using Enzyme, ChainRulesCore
+
+f(x) = sin(x)
+ChainRulesCore.@scalar_rule f(x) (cos(x),)
+
+# Import the reverse rule for Float32
+Enzyme.@import_rrule typeof(f) Float32
+
+# Import the forward rule for Float32
+Enzyme.@import_frule typeof(f) Float32
+
+# output
+```
+
+See the docs on [`Enzyme.@import_frule`](@ref) and [`Enzyme.@import_rrule`](@ref) for more information.