Commit 63d3434

feat: more nested AD rules (#1151)
* feat: softmax and logsoftmax jvp rules
* feat: add pooling rules
* test: logsoftmax and softmax forwarddiff rules
* fix: patch meanpool
* test: more tests fixed
1 parent 3c3a432 commit 63d3434

7 files changed (+158 lines, −29 lines)

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ ComponentArrays = "0.15.18"
 ConcreteStructs = "0.2.3"
 DispatchDoctor = "0.4.12"
 Enzyme = "0.13.16"
-EnzymeCore = "0.8.6"
+EnzymeCore = "0.8.8"
 FastClosures = "0.3.2"
 Flux = "0.15, 0.16"
 ForwardDiff = "0.10.36"

lib/LuxLib/Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.3.11"
+version = "1.4.0"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -66,7 +66,7 @@ Compat = "4.16"
 CpuId = "0.3"
 DispatchDoctor = "0.4.12"
 Enzyme = "0.13.16"
-EnzymeCore = "0.8.6"
+EnzymeCore = "0.8.8"
 FastClosures = "0.3.2"
 ForwardDiff = "0.10.36"
 Hwloc = "3.2"

lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl

Lines changed: 12 additions & 2 deletions
@@ -1,12 +1,22 @@
 module LuxLibCUDAExt
 
-using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr
+using CUDA: CUDA, CUBLAS, CuArray, StridedCuMatrix, StridedCuVector, CuPtr
+using ForwardDiff: ForwardDiff
 using LinearAlgebra: LinearAlgebra, Transpose, Adjoint
-using LuxLib: LuxLib, Optional
+using LuxLib: LuxLib, Impl, Optional
 using LuxLib.Utils: ofeltype_array
 using NNlib: NNlib
 using Static: True, False
 
+# Hacky Type Piracy for ForwardDiff rules
+for op in (:logsoftmax, :softmax)
+    dual_op = Symbol(op, :_dual)
+    @eval function NNlib.$(op)(
+            x::CuArray{<:ForwardDiff.Dual{Tag, T, P}}; dims=1) where {Tag, T, P}
+        return Impl.$(dual_op)(x; dims)
+    end
+end
+
 # Low level functions
 include("cublaslt.jl")
 

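The CuArray methods above are thin forwarders: they route CuArray-of-Dual inputs to the same Impl.softmax_dual / Impl.logsoftmax_dual implementations used by the generic AbstractArray rules in lib/LuxLib/src/impl/forward_diff.jl (next file). A minimal usage sketch of what this enables, seeding a single ForwardDiff partial as a JVP direction, is shown below; it assumes a functional CUDA device with LuxLib loaded so the overload is active, and the variable names are illustrative rather than part of the commit.

# Hypothetical sketch, not from the commit: single-partial JVP of softmax on GPU.
using CUDA, ForwardDiff, NNlib, LuxLib

x = CUDA.randn(Float32, 4, 3)        # primal input
u = CUDA.randn(Float32, 4, 3)        # tangent (JVP direction)
xd = ForwardDiff.Dual.(x, u)         # seed one partial per element
yd = NNlib.softmax(xd; dims=1)       # dispatches to the Dual overload above
jvp = ForwardDiff.partials.(yd, 1)   # pushed-forward tangent of softmax
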
lib/LuxLib/src/impl/forward_diff.jl

Lines changed: 80 additions & 1 deletion
@@ -1,4 +1,4 @@
-for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter]
+for op in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter)
     patched_op = op !== :depthwiseconv ? eval(op) : getfield(NNlib, op)
 
     @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
@@ -48,3 +48,82 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter]
         return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials)
     end
 end
+
+for op in (:logsoftmax, :softmax)
+    dual_op = Symbol(op, :_dual)
+    @eval function NNlib.$(op)(
+            x::AbstractArray{<:ForwardDiff.Dual{Tag, T, P}}; dims=1) where {Tag, T, P}
+        return Impl.$(dual_op)(x; dims)
+    end
+end
+
+function softmax_dual(
+        x::AbstractArray{<:ForwardDiff.Dual{Tag, T, P}}; dims=1) where {Tag, T, P}
+    value_fn(x) = ForwardDiff.value(Tag, x)
+    partial_fn(x, i) = ForwardDiff.partials(Tag, x, i)
+
+    x_data = value_fn.(x)
+
+    y = NNlib.softmax(x_data; dims)
+    dysᵢ = ntuple(P) do i
+        v = partial_fn.(x, i)
+        return y .* (v .- sum(y .* v; dims))
+    end
+
+    partials = ForwardDiff.Partials.(tuple.(dysᵢ...))
+    return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials)
+end
+
+function logsoftmax_dual(
+        x::AbstractArray{<:ForwardDiff.Dual{Tag, T, P}}; dims=1) where {Tag, T, P}
+    value_fn(x) = ForwardDiff.value(Tag, x)
+    partial_fn(x, i) = ForwardDiff.partials(Tag, x, i)
+
+    x_data = value_fn.(x)
+
+    y = NNlib.softmax(x_data; dims)
+    dysᵢ = ntuple(P) do i
+        v = partial_fn.(x, i)
+        return v .- sum(y .* v; dims)
+    end
+
+    partials = ForwardDiff.Partials.(tuple.(dysᵢ...))
+    return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials)
+end
+
+@eval function NNlib.meanpool(
+        x::AbstractArray{<:ForwardDiff.Dual{Tag, T, P}}, pdims::NNlib.PoolDims;
+        kwargs...) where {Tag, T, P}
+    value_fn(x) = ForwardDiff.value(Tag, x)
+    partial_fn(x, i) = ForwardDiff.partials(Tag, x, i)
+
+    y = NNlib.meanpool(value_fn.(x), pdims; kwargs...)
+    dysᵢ = ntuple(P) do i
+        return NNlib.meanpool(partial_fn.(x, i), pdims; kwargs...)
+    end
+
+    partials = ForwardDiff.Partials.(tuple.(dysᵢ...))
+    return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials)
+end
+
+function NNlib.∇meanpool(
+        dy::AbstractArray{<:ForwardDiff.Dual{Tag, T1, P}},
+        y::AbstractArray{<:ForwardDiff.Dual{Tag, T1, P}},
+        x::AbstractArray{<:ForwardDiff.Dual{Tag, T2, P}},
+        pdims::NNlib.PoolDims; kwargs...) where {Tag, T1, T2, P}
+    value_fn(x) = ForwardDiff.value(Tag, x)
+    partial_fn(x, i) = ForwardDiff.partials(Tag, x, i)
+
+    dy_data, y_data, x_data = value_fn.(dy), value_fn.(y), value_fn.(x)
+
+    dx = NNlib.∇meanpool(dy_data, y_data, x_data, pdims; kwargs...)
+    dysᵢ = ntuple(P) do i
+        ∇y₁ = NNlib.∇meanpool(partial_fn.(dy, i), y_data, x_data, pdims; kwargs...)
+        ∇y₂ = NNlib.∇meanpool(dy_data, partial_fn.(y, i), x_data, pdims; kwargs...)
+        @. ∇y₁ = (∇y₁ + ∇y₂) * partial_fn(x, i)
+        return ∇y₁
+    end
+
+    partials = ForwardDiff.Partials.(tuple.(dysᵢ...))
+    return ForwardDiff.Dual{Tag, eltype(dx), P}.(dx, partials)
+end

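The pushforward rules above follow directly from the softmax Jacobian: for y = softmax(x; dims) and a tangent v, the directional derivative is y .* (v .- sum(y .* v; dims)), and for logsoftmax it is v .- sum(y .* v; dims), reusing the primal softmax. Below is a small CPU sanity sketch (assuming only ForwardDiff and NNlib; the names x, u, jvp_ref, jvp_fd are illustrative, not part of the commit) that compares the analytic pushforward against a dense ForwardDiff Jacobian.

# Illustrative CPU check, not part of the commit.
using ForwardDiff, NNlib

x = randn(Float32, 4, 3)
u = randn(Float32, 4, 3)

y = NNlib.softmax(x; dims=1)
jvp_ref = y .* (u .- sum(y .* u; dims=1))         # analytic pushforward along u

J = ForwardDiff.jacobian(z -> NNlib.softmax(z; dims=1), x)
jvp_fd = reshape(J * vec(u), size(x))             # same JVP via a dense Jacobian

isapprox(jvp_ref, jvp_fd; atol=1e-5, rtol=1e-5)   # expected to hold up to tolerance
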
lib/LuxLib/test/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ BenchmarkTools = "1.5"
 ChainRulesCore = "1.24"
 ComponentArrays = "0.15.18"
 Enzyme = "0.13.16"
-EnzymeCore = "0.8.6"
+EnzymeCore = "0.8.8"
 ExplicitImports = "1.9.0"
 ForwardDiff = "0.10.36"
 Hwloc = "3.2"

lib/LuxLib/test/others/forwarddiff_tests.jl

Lines changed: 54 additions & 4 deletions
@@ -26,15 +26,20 @@
 
     function test_jvp_computation(f::F, x, u, ongpu, nested=false) where {F}
         jvp₁ = jvp_forwarddiff(f, x, u)
+
         if !(x isa ComponentArray && ongpu)
             # ComponentArray + ForwardDiff on GPU don't play nice
-            jvp₂ = jvp_forwarddiff_concrete(f, x, u)
-            @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5)
+            @testset "JVP ForwardDiff Concrete" begin
+                jvp₂ = jvp_forwarddiff_concrete(f, x, u)
+                @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5)
+            end
         end
 
         if !nested
-            jvp₃ = jvp_zygote(f, x, u)
-            @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
+            @testset "JVP Zygote" begin
+                jvp₃ = jvp_zygote(f, x, u)
+                @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
+            end
         end
     end
 
@@ -89,6 +94,51 @@
                     true)
             end
         end
+
+        @testset for op in (logsoftmax, softmax)
+            @testset for (input_dim, dim) in zip(
+                (
+                    (2, 3), (2, 3), (2, 3, 4, 5),
+                    (2, 3, 4, 5), (2, 3, 4, 5), (2, 3, 4, 5)
+                ),
+                (1, 2, 1, 2, 3, 4)
+            )
+                x = randn(Float32, input_dim) |> aType
+                u = randn(Float32, input_dim) |> aType
+
+                test_jvp_computation(x -> op(x; dims=dim), x, u, ongpu)
+                test_jvp_computation(
+                    x -> op(x; dims=dim), ComponentArray(; x), u, ongpu)
+
+                test_jvp_computation(
+                    x -> only(Zygote.gradient(x -> sum(op(x; dims=dim)), x)),
+                    x, u, ongpu, true
+                )
+            end
+        end
+
+        @testset for op in (meanpool,)
+            @testset for (input_dim, kernel_size, stride, pad) in (
+                ((8, 3, 2), (4,), (2,), (0,)),
+                ((8, 3, 2), (4,), (3,), (0,)),
+                ((8, 3, 2), (4,), (3,), (1,)),
+                ((8, 8, 3, 2), (4, 4), (2, 2), (0, 0)),
+                ((8, 8, 3, 2), (4, 4), (3, 3), (0, 0)),
+                ((8, 8, 3, 2), (4, 4), (3, 3), (1, 1))
+            )
+                x = randn(Float32, input_dim) |> aType
+                u = randn(Float32, input_dim) |> aType
+
+                test_jvp_computation(
+                    x -> op(x, kernel_size; stride, pad), x, u, ongpu)
+
+                test_jvp_computation(
+                    x -> only(Zygote.gradient(
+                        x -> sum(op(x, kernel_size; stride, pad)), x)),
+                    x, u, ongpu, true
+                )
+            end
+        end
     end
 end

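The helpers jvp_forwarddiff, jvp_forwarddiff_concrete and jvp_zygote used above are defined earlier in this test file and are not part of the diff. For orientation, here is a minimal sketch of the Dual-seeding pattern such a ForwardDiff-based JVP helper typically follows (the helper name and details are illustrative assumptions, not the actual test code).

using ForwardDiff

function jvp_forwarddiff_sketch(f::F, x::AbstractArray, u::AbstractArray) where {F}
    Tag = typeof(ForwardDiff.Tag(f, eltype(x)))
    xd = ForwardDiff.Dual{Tag}.(x, u)       # one partial per element, seeded with u
    return ForwardDiff.partials.(f(xd), 1)  # JVP: tangent u pushed forward through f
end
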
test/layers/normalize_tests.jl

Lines changed: 8 additions & 18 deletions
@@ -154,8 +154,6 @@ end
         @jet __f(z)
     end
 
-    broken_backends = VERSION ≥ v"1.11-" ? Any[AutoEnzyme()] : []
-
     @testset "Conv" begin
         c = Conv((3, 3), 3 => 3; init_bias=Lux.ones32)
 
@@ -165,35 +163,31 @@ end
         x = randn(rng, Float32, 3, 3, 3, 1) |> aType
 
         @jet wn(x, ps, st)
-        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-            broken_backends)
+        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(c, (:weight,))
         display(wn)
         ps, st = Lux.setup(rng, wn) |> dev
         x = randn(rng, Float32, 3, 3, 3, 1) |> aType
 
         @jet wn(x, ps, st)
-        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-            broken_backends)
+        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(c, (:weight, :bias), (2, 2))
         display(wn)
         ps, st = Lux.setup(rng, wn) |> dev
         x = randn(rng, Float32, 3, 3, 3, 1) |> aType
 
         @jet wn(x, ps, st)
-        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-            broken_backends)
+        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(c, (:weight,), (2,))
         display(wn)
         ps, st = Lux.setup(rng, wn) |> dev
         x = randn(rng, Float32, 3, 3, 3, 1) |> aType
 
         @jet wn(x, ps, st)
-        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-            broken_backends)
+        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
     end
 
     @testset "Dense" begin
@@ -205,35 +199,31 @@ end
         x = randn(rng, Float32, 3, 1) |> aType
 
         @jet wn(x, ps, st)
-        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-            broken_backends)
+        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(d, (:weight,))
         display(wn)
         ps, st = Lux.setup(rng, wn) |> dev
         x = randn(rng, Float32, 3, 1) |> aType
 
         @jet wn(x, ps, st)
-        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-            broken_backends)
+        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(d, (:weight, :bias), (2, 2))
         display(wn)
         ps, st = Lux.setup(rng, wn) |> dev
         x = randn(rng, Float32, 3, 1) |> aType
 
         @jet wn(x, ps, st)
-        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-            broken_backends)
+        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(d, (:weight,), (2,))
         display(wn)
         ps, st = Lux.setup(rng, wn) |> dev
         x = randn(rng, Float32, 3, 1) |> aType
 
         @jet wn(x, ps, st)
-        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-            broken_backends)
+        @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
     end
 
     # See https://github.com/LuxDL/Lux.jl/issues/95
