@@ -89,3 +89,113 @@ function Training.compute_gradients_impl(
8989 ts,
9090 )
9191end
92+
93+ # Type Piracy for ForwardDiff GPU Array Support
94+ # This is a workaround for ForwardDiff.jl not supporting GPU arrays post v1.0
95+ # See: https://github.com/JuliaDiff/ForwardDiff.jl/pull/760
96+
97+ using GPUArraysCore: AnyGPUArray
98+
99+ # Helper struct for broadcasting partials extraction
100+ struct PartialsFn{T,D<: ForwardDiff.Dual }
101+ dual:: D
102+ end
103+
104+ PartialsFn{T}(dual:: ForwardDiff.Dual ) where {T} = PartialsFn{T,typeof(dual)}(dual)
105+
106+ (f:: PartialsFn{T} )(i) where {T} = ForwardDiff. partials(T, f. dual, i)
107+
108+ # Macro to define ForwardDiff overloads for array types that don't support scalar indexing
109+ macro define_forwarddiff_gpu_overloads(ArrayType)
110+ return quote
111+ # Overloaded seed! methods
112+ function ForwardDiff. seed!(
113+ duals:: $ (esc(ArrayType)){ForwardDiff. Dual{T,V,N}},
114+ x,
115+ seed:: ForwardDiff.Partials{N,V} = zero(ForwardDiff. Partials{N,V}),
116+ ) where {T,V,N}
117+ idxs = collect(ForwardDiff. structural_eachindex(duals, x))
118+ duals[idxs] .= ForwardDiff. Dual{T,V,N}. (view(x, idxs), Ref(seed))
119+ return duals
120+ end
121+
122+ function ForwardDiff. seed!(
123+ duals:: $ (esc(ArrayType)){ForwardDiff. Dual{T,V,N}},
124+ x,
125+ seeds:: NTuple{N,ForwardDiff.Partials{N,V}} ,
126+ ) where {T,V,N}
127+ idxs = collect(Iterators. take(ForwardDiff. structural_eachindex(duals, x), N))
128+ duals[idxs] .=
129+ ForwardDiff. Dual{
130+ T,V,N
131+ }. (view(x, idxs), getindex.(Ref(seeds), 1 : length(idxs)))
132+ return duals
133+ end
134+
135+ function ForwardDiff. seed!(
136+ duals:: $ (esc(ArrayType)){ForwardDiff. Dual{T,V,N}},
137+ x,
138+ index,
139+ seed:: ForwardDiff.Partials{N,V} = zero(ForwardDiff. Partials{N,V}),
140+ ) where {T,V,N}
141+ idxs = collect(
142+ Iterators. drop(ForwardDiff. structural_eachindex(duals, x), index - 1 )
143+ )
144+ duals[idxs] .= ForwardDiff. Dual{T,V,N}. (view(x, idxs), Ref(seed))
145+ return duals
146+ end
147+
148+ function ForwardDiff. seed!(
149+ duals:: $ (esc(ArrayType)){ForwardDiff. Dual{T,V,N}},
150+ x,
151+ index,
152+ seeds:: NTuple{N,ForwardDiff.Partials{N,V}} ,
153+ chunksize= N,
154+ ) where {T,V,N}
155+ idxs = collect(
156+ Iterators. take(
157+ Iterators. drop(ForwardDiff. structural_eachindex(duals, x), index - 1 ),
158+ chunksize,
159+ ),
160+ )
161+ duals[idxs] .=
162+ ForwardDiff. Dual{
163+ T,V,N
164+ }. (view(x, idxs), getindex.(Ref(seeds), 1 : length(idxs)))
165+ return duals
166+ end
167+
168+ # Overloaded extract_gradient! methods
169+ function ForwardDiff. extract_gradient!(
170+ :: Type{T} , result:: $ (esc(ArrayType)), dual:: ForwardDiff.Dual
171+ ) where {T}
172+ fn = PartialsFn{T}(dual)
173+ idxs = collect(
174+ Iterators. take(
175+ ForwardDiff. structural_eachindex(result), ForwardDiff. npartials(dual)
176+ ),
177+ )
178+ result[idxs] .= fn.(1 : length(idxs))
179+ return result
180+ end
181+
182+ function ForwardDiff. extract_gradient_chunk!(
183+ :: Type{T} , result:: $ (esc(ArrayType)), dual, index, chunksize
184+ ) where {T}
185+ fn = PartialsFn{T}(dual)
186+ idxs = collect(
187+ Iterators. take(
188+ Iterators. drop(ForwardDiff. structural_eachindex(result), index - 1 ),
189+ chunksize,
190+ ),
191+ )
192+ result[idxs] .= fn.(1 : length(idxs))
193+ return result
194+ end
195+ end
196+ end
197+
198+ @static if pkgversion(ForwardDiff) ≥ v" 1.0.1"
199+ # Apply overloads for GPU arrays
200+ @define_forwarddiff_gpu_overloads AnyGPUArray
201+ end
0 commit comments