@@ -84,13 +84,15 @@ struct HVPGradientHessianPrep{
8484 BS<: BatchSizeSettings ,
8585 S<: AbstractVector{<:NTuple} ,
8686 R<: AbstractVector{<:NTuple} ,
87+ SE<: NTuple ,
8788 E2<: HVPPrep ,
8889 E1<: GradientPrep ,
8990} <: HessianPrep{SIG}
9091 _sig:: Val{SIG}
9192 batch_size_settings:: BS
9293 batched_seeds:: S
9394 batched_results:: R
95+ seed_example:: SE
9496 hvp_prep:: E2
9597 gradient_prep:: E1
9698end
@@ -119,10 +121,17 @@ function _prepare_hessian_aux(
119121 ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
120122 ]
121123 batched_results = [ntuple (b -> similar (x), Val (B)) for _ in batched_seeds]
122- hvp_prep = prepare_hvp_nokwarg (strict, f, backend, x, batched_seeds[1 ], contexts... )
124+ seed_example = ntuple (b -> basis (x), Val (B))
125+ hvp_prep = prepare_hvp_nokwarg (strict, f, backend, x, seed_example, contexts... )
123126 gradient_prep = prepare_gradient_nokwarg (strict, f, inner (backend), x, contexts... )
124127 return HVPGradientHessianPrep (
125- _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep
128+ _sig,
129+ batch_size_settings,
130+ batched_seeds,
131+ batched_results,
132+ seed_example,
133+ hvp_prep,
134+ gradient_prep,
126135 )
127136end
128137
@@ -150,11 +159,11 @@ function hessian(
150159 contexts:: Vararg{Context,C} ,
151160) where {F,SIG,B,aligned,C}
152161 check_prep (f, prep, backend, x, contexts... )
153- (; batch_size_settings, batched_seeds, hvp_prep) = prep
162+ (; batch_size_settings, batched_seeds, seed_example, hvp_prep) = prep
154163 (; A, B_last) = batch_size_settings
155164
156165 hvp_prep_same = prepare_hvp_same_point (
157- f, hvp_prep, backend, x, batched_seeds[ 1 ] , contexts...
166+ f, hvp_prep, backend, x, seed_example , contexts...
158167 )
159168
160169 hess = mapreduce (hcat, eachindex (batched_seeds)) do a
@@ -178,11 +187,11 @@ function hessian!(
178187 contexts:: Vararg{Context,C} ,
179188) where {F,SIG,B,C}
180189 check_prep (f, prep, backend, x, contexts... )
181- (; batch_size_settings, batched_seeds, batched_results, hvp_prep) = prep
190+ (; batch_size_settings, batched_seeds, batched_results, seed_example, hvp_prep) = prep
182191 (; N) = batch_size_settings
183192
184193 hvp_prep_same = prepare_hvp_same_point (
185- f, hvp_prep, backend, x, batched_seeds[ 1 ] , contexts...
194+ f, hvp_prep, backend, x, seed_example , contexts...
186195 )
187196
188197 for a in eachindex (batched_seeds, batched_results)
0 commit comments