@@ -119,8 +119,11 @@ function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O}
119119end
120120
121121function DI. prepare_gradient (
122- f:: F , backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} , x
123- ) where {F}
122+ f:: F ,
123+ backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
124+ x,
125+ contexts:: Vararg{DI.Constant,C} ,
126+ ) where {F,C}
124127 valB = to_val (DI. pick_batchsize (backend, x))
125128 shadows = create_shadows (valB, x)
126129 return EnzymeForwardGradientPrep (valB, shadows)
@@ -131,23 +134,31 @@ function DI.gradient(
131134 prep:: EnzymeForwardGradientPrep{B} ,
132135 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
133136 x,
134- ) where {F,B}
137+ contexts:: Vararg{DI.Constant,C} ,
138+ ) where {F,B,C}
135139 mode = forward_noprimal (backend)
136140 f_and_df = get_f_and_df (f, backend, mode)
137- derivs = gradient (mode, f_and_df, x; chunk= Val (B), shadows= prep. shadows)
138- return only (derivs)
141+ annotated_contexts = translate (backend, mode, Val (B), contexts... )
142+ derivs = gradient (
143+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep. shadows
144+ )
145+ return first (derivs)
139146end
140147
141148function DI. value_and_gradient (
142149 f:: F ,
143150 prep:: EnzymeForwardGradientPrep{B} ,
144151 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
145152 x,
146- ) where {F,B}
153+ contexts:: Vararg{DI.Constant,C} ,
154+ ) where {F,B,C}
147155 mode = forward_withprimal (backend)
148156 f_and_df = get_f_and_df (f, backend, mode)
149- (; derivs, val) = gradient (mode, f_and_df, x; chunk= Val (B), shadows= prep. shadows)
150- return val, only (derivs)
157+ annotated_contexts = translate (backend, mode, Val (B), contexts... )
158+ (; derivs, val) = gradient (
159+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep. shadows
160+ )
161+ return val, first (derivs)
151162end
152163
153164function DI. gradient! (
@@ -156,8 +167,9 @@ function DI.gradient!(
156167 prep:: EnzymeForwardGradientPrep{B} ,
157168 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
158169 x,
159- ) where {F,B}
160- return copyto! (grad, DI. gradient (f, prep, backend, x))
170+ contexts:: Vararg{DI.Constant,C} ,
171+ ) where {F,B,C}
172+ return copyto! (grad, DI. gradient (f, prep, backend, x, contexts... ))
161173end
162174
163175function DI. value_and_gradient! (
@@ -166,8 +178,9 @@ function DI.value_and_gradient!(
166178 prep:: EnzymeForwardGradientPrep{B} ,
167179 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
168180 x,
169- ) where {F,B}
170- y, new_grad = DI. value_and_gradient (f, prep, backend, x)
181+ contexts:: Vararg{DI.Constant,C} ,
182+ ) where {F,B,C}
183+ y, new_grad = DI. value_and_gradient (f, prep, backend, x, contexts... )
171184 return y, copyto! (grad, new_grad)
172185end
173186
@@ -185,9 +198,12 @@ function EnzymeForwardOneArgJacobianPrep(
185198end
186199
187200function DI. prepare_jacobian (
188- f:: F , backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} , x
189- ) where {F}
190- y = f (x)
201+ f:: F ,
202+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
203+ x,
204+ contexts:: Vararg{DI.Constant,C} ,
205+ ) where {F,C}
206+ y = f (x, map (DI. unwrap, contexts)... )
191207 valB = to_val (DI. pick_batchsize (backend, x))
192208 shadows = create_shadows (valB, x)
193209 return EnzymeForwardOneArgJacobianPrep (valB, shadows, length (y))
@@ -198,11 +214,15 @@ function DI.jacobian(
198214 prep:: EnzymeForwardOneArgJacobianPrep{B} ,
199215 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
200216 x,
201- ) where {F,B}
217+ contexts:: Vararg{DI.Constant,C} ,
218+ ) where {F,B,C}
202219 mode = forward_noprimal (backend)
203220 f_and_df = get_f_and_df (f, backend, mode)
204- derivs = jacobian (mode, f_and_df, x; chunk= Val (B), shadows= prep. shadows)
205- jac_tensor = only (derivs)
221+ annotated_contexts = translate (backend, mode, Val (B), contexts... )
222+ derivs = jacobian (
223+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep. shadows
224+ )
225+ jac_tensor = first (derivs)
206226 return maybe_reshape (jac_tensor, prep. output_length, length (x))
207227end
208228
@@ -211,11 +231,15 @@ function DI.value_and_jacobian(
211231 prep:: EnzymeForwardOneArgJacobianPrep{B} ,
212232 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
213233 x,
214- ) where {F,B}
234+ contexts:: Vararg{DI.Constant,C} ,
235+ ) where {F,B,C}
215236 mode = forward_withprimal (backend)
216237 f_and_df = get_f_and_df (f, backend, mode)
217- (; derivs, val) = jacobian (mode, f_and_df, x; chunk= Val (B), shadows= prep. shadows)
218- jac_tensor = only (derivs)
238+ annotated_contexts = translate (backend, mode, Val (B), contexts... )
239+ (; derivs, val) = jacobian (
240+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep. shadows
241+ )
242+ jac_tensor = first (derivs)
219243 return val, maybe_reshape (jac_tensor, prep. output_length, length (x))
220244end
221245
@@ -225,8 +249,9 @@ function DI.jacobian!(
225249 prep:: EnzymeForwardOneArgJacobianPrep ,
226250 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
227251 x,
228- ) where {F}
229- return copyto! (jac, DI. jacobian (f, prep, backend, x))
252+ contexts:: Vararg{DI.Constant,C} ,
253+ ) where {F,C}
254+ return copyto! (jac, DI. jacobian (f, prep, backend, x, contexts... ))
230255end
231256
232257function DI. value_and_jacobian! (
@@ -235,7 +260,8 @@ function DI.value_and_jacobian!(
235260 prep:: EnzymeForwardOneArgJacobianPrep ,
236261 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
237262 x,
238- ) where {F}
239- y, new_jac = DI. value_and_jacobian (f, prep, backend, x)
263+ contexts:: Vararg{DI.Constant,C} ,
264+ ) where {F,C}
265+ y, new_jac = DI. value_and_jacobian (f, prep, backend, x, contexts... )
240266 return y, copyto! (jac, new_jac)
241267end
0 commit comments