Skip to content

Commit 22d9305

Browse files
authored
fix: forwarddiff support for gpu arrays (#1605)
1 parent 57c4b80 commit 22d9305

File tree

4 files changed

+128
-3
lines changed

4 files changed

+128
-3
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.28.0"
4+
version = "1.29.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -101,7 +101,7 @@ Enzyme = "0.13.81"
101101
EnzymeCore = "0.8.15"
102102
FastClosures = "0.3.2"
103103
Flux = "0.16.3"
104-
ForwardDiff = "0.10.36, =1"
104+
ForwardDiff = "0.10.36, 1"
105105
FunctionWrappers = "1.1.3"
106106
Functors = "0.5"
107107
GPUArrays = "11"

ext/LuxComponentArraysExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module LuxComponentArraysExt
22

33
using ComponentArrays: ComponentArrays, ComponentArray
44
using Lux: Lux, DistributedUtils
5+
using ForwardDiff: ForwardDiff
56

67
# Distributed Functionality
78
function DistributedUtils.synchronize!!(
@@ -11,4 +12,9 @@ function DistributedUtils.synchronize!!(
1112
return ComponentArray(ps_synced, ComponentArrays.getaxes(ps))
1213
end
1314

15+
@static if pkgversion(ForwardDiff) ≥ v"1.0.1"
16+
# Apply overloads for GPU arrays
17+
Lux.@define_forwarddiff_gpu_overloads ComponentArray
18+
end
19+
1420
end

src/helpers/forwarddiff_training.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,113 @@ function Training.compute_gradients_impl(
8989
ts,
9090
)
9191
end
92+
93+
# Type Piracy for ForwardDiff GPU Array Support
94+
# This is a workaround for ForwardDiff.jl not supporting GPU arrays post v1.0
95+
# See: https://github.com/JuliaDiff/ForwardDiff.jl/pull/760
96+
97+
using GPUArraysCore: AnyGPUArray
98+
99+
# Helper struct for broadcasting partials extraction
100+
struct PartialsFn{T,D<:ForwardDiff.Dual}
101+
dual::D
102+
end
103+
104+
PartialsFn{T}(dual::ForwardDiff.Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
105+
106+
(f::PartialsFn{T})(i) where {T} = ForwardDiff.partials(T, f.dual, i)
107+
108+
# Macro to define ForwardDiff overloads for array types that don't support scalar indexing
109+
macro define_forwarddiff_gpu_overloads(ArrayType)
110+
return quote
111+
# Overloaded seed! methods
112+
function ForwardDiff.seed!(
113+
duals::$(esc(ArrayType)){ForwardDiff.Dual{T,V,N}},
114+
x,
115+
seed::ForwardDiff.Partials{N,V}=zero(ForwardDiff.Partials{N,V}),
116+
) where {T,V,N}
117+
idxs = collect(ForwardDiff.structural_eachindex(duals, x))
118+
duals[idxs] .= ForwardDiff.Dual{T,V,N}.(view(x, idxs), Ref(seed))
119+
return duals
120+
end
121+
122+
function ForwardDiff.seed!(
123+
duals::$(esc(ArrayType)){ForwardDiff.Dual{T,V,N}},
124+
x,
125+
seeds::NTuple{N,ForwardDiff.Partials{N,V}},
126+
) where {T,V,N}
127+
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(duals, x), N))
128+
duals[idxs] .=
129+
ForwardDiff.Dual{
130+
T,V,N
131+
}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
132+
return duals
133+
end
134+
135+
function ForwardDiff.seed!(
136+
duals::$(esc(ArrayType)){ForwardDiff.Dual{T,V,N}},
137+
x,
138+
index,
139+
seed::ForwardDiff.Partials{N,V}=zero(ForwardDiff.Partials{N,V}),
140+
) where {T,V,N}
141+
idxs = collect(
142+
Iterators.drop(ForwardDiff.structural_eachindex(duals, x), index - 1)
143+
)
144+
duals[idxs] .= ForwardDiff.Dual{T,V,N}.(view(x, idxs), Ref(seed))
145+
return duals
146+
end
147+
148+
function ForwardDiff.seed!(
149+
duals::$(esc(ArrayType)){ForwardDiff.Dual{T,V,N}},
150+
x,
151+
index,
152+
seeds::NTuple{N,ForwardDiff.Partials{N,V}},
153+
chunksize=N,
154+
) where {T,V,N}
155+
idxs = collect(
156+
Iterators.take(
157+
Iterators.drop(ForwardDiff.structural_eachindex(duals, x), index - 1),
158+
chunksize,
159+
),
160+
)
161+
duals[idxs] .=
162+
ForwardDiff.Dual{
163+
T,V,N
164+
}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
165+
return duals
166+
end
167+
168+
# Overloaded extract_gradient! methods
169+
function ForwardDiff.extract_gradient!(
170+
::Type{T}, result::$(esc(ArrayType)), dual::ForwardDiff.Dual
171+
) where {T}
172+
fn = PartialsFn{T}(dual)
173+
idxs = collect(
174+
Iterators.take(
175+
ForwardDiff.structural_eachindex(result), ForwardDiff.npartials(dual)
176+
),
177+
)
178+
result[idxs] .= fn.(1:length(idxs))
179+
return result
180+
end
181+
182+
function ForwardDiff.extract_gradient_chunk!(
183+
::Type{T}, result::$(esc(ArrayType)), dual, index, chunksize
184+
) where {T}
185+
fn = PartialsFn{T}(dual)
186+
idxs = collect(
187+
Iterators.take(
188+
Iterators.drop(ForwardDiff.structural_eachindex(result), index - 1),
189+
chunksize,
190+
),
191+
)
192+
result[idxs] .= fn.(1:length(idxs))
193+
return result
194+
end
195+
end
196+
end
197+
198+
@static if pkgversion(ForwardDiff) ≥ v"1.0.1"
199+
# Apply overloads for GPU arrays
200+
@define_forwarddiff_gpu_overloads AnyGPUArray
201+
end

test/qa_tests.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,21 @@
99
exclude=[
1010
ForwardDiff.jacobian,
1111
ForwardDiff.gradient,
12+
ForwardDiff.extract_gradient_chunk!,
1213
Lux.AutoDiffInternalImpl.batched_jacobian,
1314
Lux.AutoDiffInternalImpl.jacobian_vector_product,
1415
Lux.AutoDiffInternalImpl.jacobian_vector_product_impl,
1516
],
1617
)
17-
Aqua.test_piracies(Lux; treat_as_own=[Lux.outputsize])
18+
Aqua.test_piracies(
19+
Lux;
20+
treat_as_own=[
21+
Lux.outputsize,
22+
ForwardDiff.extract_gradient_chunk!,
23+
ForwardDiff.extract_gradient!,
24+
ForwardDiff.seed!,
25+
],
26+
)
1827
Aqua.test_unbound_args(Lux; broken=true)
1928
end
2029

0 commit comments

Comments
 (0)