@@ -2,7 +2,13 @@ module Training
 
 using Adapt: Adapt
 using ADTypes:
-    AbstractADType, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote, AutoMooncake
+    AbstractADType,
+    AutoEnzyme,
+    AutoReverseDiff,
+    AutoTracker,
+    AutoZygote,
+    AutoMooncake,
+    AutoReactant
 using SciMLPublic: @public
 using ConcreteStructs: @concrete
 using FastClosures: @closure
@@ -161,7 +167,8 @@
 
 @concrete struct ReactantBackend
     return_gradients <: StaticBool
-    sync::Bool
+    sync <: Union{Bool,Missing}
+    compile_options
     ad <: AutoEnzyme
 end
 
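The two new `ReactantBackend` fields are populated from user-facing keywords rather than constructed directly. As a rough end-to-end sketch of how they surface through the training API (hypothetical model and data; it also assumes `Reactant.CompileOptions()` default-constructs, which should be checked against the Reactant docs):

```julia
using Lux, Optimisers, Random, Reactant

dev = reactant_device()
model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
ts = Training.TrainState(model, ps, st, Adam(0.001f0))
x, y = dev(randn(Float32, 4, 32)), dev(randn(Float32, 2, 32))

# `sync` and `compile_options` are forwarded into the internal ReactantBackend above.
# Assumption: Reactant.CompileOptions() with no arguments yields default options.
_, loss, _, ts = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), ts;
    sync=true, compile_options=Reactant.CompileOptions(),
)
```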
@@ -247,10 +254,15 @@ const SYNC_DOCSTRING = """
   Reactant Backend.
 """
 
+const COMPILE_OPTIONS_DOCSTRING = """
+- `compile_options`: Compile options for the Reactant-compiled function. See
+  `Reactant.CompileOptions` for more details. This is only used for the Reactant Backend.
+"""
+
 """
     compute_gradients(
         ad::AbstractADType, objective_function::Function, data, ts::TrainState;
-        sync::Bool=false
+        sync::Union{Missing,Bool}=missing, compile_options::Union{Nothing,Reactant.CompileOptions}=nothing
     )
 
 Compute the gradients of the objective function wrt parameters stored in `ts`.
@@ -279,6 +291,7 @@ Compute the gradients of the objective function wrt parameters stored in `ts`.
 ## Keyword Arguments
 
 $(SYNC_DOCSTRING)
+$(COMPILE_OPTIONS_DOCSTRING)
 
 ## Return
 
@@ -304,10 +317,10 @@ A 4-Tuple containing:
     returned in step `i + 1` might be aliased by the old gradients. If you want to prevent
     this, simply use `copy(grads)` or `deepcopy(grads)` to make a copy of the gradients.
 """
-function compute_gradients(ad, obj_fn::F, data, ts::TrainState; sync::Bool=false) where {F}
+function compute_gradients(ad, obj_fn::F, data, ts::TrainState; kwargs...) where {F}
     dev_type = get_device_type((ts.parameters, ts.states))
     return compute_gradients_impl_with_allocator_cache(
-        maybe_wrap_adtype(ad, dev_type; sync), ts.allocator_cache, obj_fn, data, ts
+        maybe_wrap_adtype(ad, dev_type; kwargs...), ts.allocator_cache, obj_fn, data, ts
     )
 end
 
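Because the implementation now just splats `kwargs...` through to `maybe_wrap_adtype`, non-Reactant backends accept and ignore the new keywords, while Reactant calls pick them up. A short sketch, where `x_cpu`, `y_cpu`, `ts_cpu` are a hypothetical CPU counterpart of the earlier setup and `x`, `y`, `ts` come from the first sketch:

```julia
# CPU/GPU backends: same call as before, no new keywords needed.
grads, loss, stats, ts_cpu = Training.compute_gradients(
    AutoZygote(), MSELoss(), (x_cpu, y_cpu), ts_cpu
)

# Reactant device: `sync` (and `compile_options`) ride along via kwargs...
grads, loss, stats, ts = Training.compute_gradients(
    AutoEnzyme(), MSELoss(), (x, y), ts; sync=true
)
```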
@@ -346,14 +359,33 @@ end
 maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend
 maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
 function maybe_wrap_adtype(
-    ad::AbstractADType,
+    ad::AutoEnzyme,
+    ::Type{ReactantDevice};
+    return_gradients::Utils.BoolType=True(),
+    sync::Union{Missing,Bool}=missing,
+    compile_options=nothing,
+)
+    return ReactantBackend(static(return_gradients), sync, compile_options, ad)
+end
+function maybe_wrap_adtype(
+    ad::AutoReactant,
     ::Type{ReactantDevice};
     return_gradients::Utils.BoolType=True(),
-    sync::Bool=false,
+    sync::Union{Missing,Bool}=missing,
+    compile_options=nothing,
 )
-    ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync, ad)
-    throw(ArgumentError("Computing gradients for models on XLA is supported only with \
-        Enzyme.jl (`AutoEnzyme`)."))
+    return ReactantBackend(static(return_gradients), sync, compile_options, ad.mode)
+end
+function maybe_wrap_adtype(ad::AutoReactant, ::Type{T}; kwargs...) where {T}
+    throw(ArgumentError("`AutoReactant` only supports ReactantDevice but got `$(T)`"))
+end
+function maybe_wrap_adtype(ad::AbstractADType, ::Type{ReactantDevice}; kwargs...)
+    throw(
+        ArgumentError(
+            "Computing gradients for models with Reactant is supported only with \
+            Enzyme.jl (`AutoEnzyme` or `AutoReactant`)."
+        ),
+    )
 end
 
 function generate_wrappers(::F, m, ps, st, data, ::False, ::StaticBool) where {F}
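The resulting dispatch surface, as an illustrative sketch (these are internal methods, reached here by qualifying through `Lux.Training`; it assumes `AutoReactant()` default-constructs with an Enzyme mode, which should be verified against ADTypes):

```julia
using ADTypes: AutoEnzyme, AutoReactant, AutoZygote
using MLDataDevices: CPUDevice, ReactantDevice
using Lux: Training

Training.maybe_wrap_adtype(AutoZygote(), CPUDevice)         # passthrough: AutoZygote()
Training.maybe_wrap_adtype(AutoEnzyme(), ReactantDevice)    # ReactantBackend wrapping ad
Training.maybe_wrap_adtype(AutoReactant(), ReactantDevice)  # ReactantBackend wrapping ad.mode
Training.maybe_wrap_adtype(AutoReactant(), CPUDevice)       # ArgumentError: Reactant only
Training.maybe_wrap_adtype(AutoZygote(), ReactantDevice)    # ArgumentError: needs Enzyme
```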
@@ -408,7 +440,9 @@ const RETURN_GRADIENTS_DOCSTRING = """
 
 """
     single_train_step!(
-        backend, obj_fn::F, data, ts::TrainState; return_gradients=True(), sync::Bool=false
+        backend, obj_fn::F, data, ts::TrainState;
+        return_gradients=True(), sync::Union{Missing,Bool}=missing,
+        compile_options::Union{Nothing,Reactant.CompileOptions}=nothing,
     )
 
 Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
@@ -419,6 +453,7 @@ updates the parameters using [`apply_gradients!`](@ref). All backends supported
 
 $(RETURN_GRADIENTS_DOCSTRING)
 $(SYNC_DOCSTRING)
+$(COMPILE_OPTIONS_DOCSTRING)
 
 ## Return
 
@@ -427,16 +462,9 @@ only the parameters in `ts` are updated inplace. Users should be using the retur
 object for further training steps, else there is no caching and performance will be
 suboptimal (and absolutely terrible for backends like `AutoReactant`).
 """
-function single_train_step!(
-    backend,
-    obj_fn::F,
-    data,
-    ts::TrainState;
-    return_gradients::Utils.BoolType=True(),
-    sync::Bool=false,
-) where {F}
+function single_train_step!(backend, obj_fn::F, data, ts::TrainState; kwargs...) where {F}
     backend = maybe_wrap_adtype(
-        backend, get_device_type((ts.parameters, ts.states)); return_gradients, sync
+        backend, get_device_type((ts.parameters, ts.states)); kwargs...
     )
     return single_train_step_impl_with_allocator_cache!(
         backend, ts.allocator_cache, obj_fn, data, ts
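A minimal training-loop sketch with the simplified keyword path (hypothetical `dataloader`; reuses the bindings from the first sketch, and again assumes `Reactant.CompileOptions()` default-constructs). Note that the updated `TrainState` must be threaded through each step, as the docstring above stresses:

```julia
function train!(ts, dataloader; nepochs=10)
    for _ in 1:nepochs, (x, y) in dataloader
        # All keywords, including `compile_options`, ride along via kwargs...
        _, loss, _, ts = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), ts;
            compile_options=Reactant.CompileOptions(),
        )
    end
    return ts
end
```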
@@ -445,7 +473,9 @@
 
 """
     single_train_step(
-        backend, obj_fn::F, data, ts::TrainState; return_gradients=True(), sync::Bool=false
+        backend, obj_fn::F, data, ts::TrainState;
+        return_gradients=True(), sync::Union{Missing,Bool}=missing,
+        compile_options::Union{Nothing,Reactant.CompileOptions}=nothing,
     )
 
 Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
@@ -458,21 +488,15 @@ In most cases you should use [`single_train_step!`](@ref) instead of this functi
 
 $(RETURN_GRADIENTS_DOCSTRING)
 $(SYNC_DOCSTRING)
+$(COMPILE_OPTIONS_DOCSTRING)
 
 ## Return
 
 Returned values are the same as [`single_train_step!`](@ref).
 """
-function single_train_step(
-    backend,
-    obj_fn::F,
-    data,
-    ts::TrainState;
-    return_gradients::Utils.BoolType=True(),
-    sync::Bool=false,
-) where {F}
+function single_train_step(backend, obj_fn::F, data, ts::TrainState; kwargs...) where {F}
     backend = maybe_wrap_adtype(
-        backend, get_device_type((ts.parameters, ts.states)); return_gradients, sync
+        backend, get_device_type((ts.parameters, ts.states)); kwargs...
     )
     return single_train_step_impl(backend, obj_fn, data, ts)
 end
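The non-mutating variant accepts the same keywords. A sketch of using it when the original parameters must be preserved, reusing `x`, `y`, `ts` from the first sketch:

```julia
# `ts` keeps its old parameters; `ts_new` holds the updated ones.
grads, loss, stats, ts_new = Training.single_train_step(
    AutoEnzyme(), MSELoss(), (x, y), ts; sync=true
)
```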