154154 @jet __f(z)
155155 end
156156
157- broken_backends = VERSION ≥ v" 1.11-" ? Any[AutoEnzyme()] : []
158-
159157 @testset " Conv" begin
160158 c = Conv((3 , 3 ), 3 => 3 ; init_bias= Lux. ones32)
161159
@@ -165,35 +163,31 @@ end
165163 x = randn(rng, Float32, 3 , 3 , 3 , 1 ) |> aType
166164
167165 @jet wn(x, ps, st)
168- @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 ,
169- broken_backends)
166+ @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
170167
171168 wn = WeightNorm(c, (:weight,))
172169 display(wn)
173170 ps, st = Lux. setup(rng, wn) |> dev
174171 x = randn(rng, Float32, 3 , 3 , 3 , 1 ) |> aType
175172
176173 @jet wn(x, ps, st)
177- @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 ,
178- broken_backends)
174+ @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
179175
180176 wn = WeightNorm(c, (:weight, :bias), (2 , 2 ))
181177 display(wn)
182178 ps, st = Lux. setup(rng, wn) |> dev
183179 x = randn(rng, Float32, 3 , 3 , 3 , 1 ) |> aType
184180
185181 @jet wn(x, ps, st)
186- @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 ,
187- broken_backends)
182+ @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
188183
189184 wn = WeightNorm(c, (:weight,), (2 ,))
190185 display(wn)
191186 ps, st = Lux. setup(rng, wn) |> dev
192187 x = randn(rng, Float32, 3 , 3 , 3 , 1 ) |> aType
193188
194189 @jet wn(x, ps, st)
195- @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 ,
196- broken_backends)
190+ @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
197191 end
198192
199193 @testset " Dense" begin
@@ -205,35 +199,31 @@ end
205199 x = randn(rng, Float32, 3 , 1 ) |> aType
206200
207201 @jet wn(x, ps, st)
208- @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 ,
209- broken_backends)
202+ @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
210203
211204 wn = WeightNorm(d, (:weight,))
212205 display(wn)
213206 ps, st = Lux. setup(rng, wn) |> dev
214207 x = randn(rng, Float32, 3 , 1 ) |> aType
215208
216209 @jet wn(x, ps, st)
217- @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 ,
218- broken_backends)
210+ @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
219211
220212 wn = WeightNorm(d, (:weight, :bias), (2 , 2 ))
221213 display(wn)
222214 ps, st = Lux. setup(rng, wn) |> dev
223215 x = randn(rng, Float32, 3 , 1 ) |> aType
224216
225217 @jet wn(x, ps, st)
226- @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 ,
227- broken_backends)
218+ @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
228219
229220 wn = WeightNorm(d, (:weight,), (2 ,))
230221 display(wn)
231222 ps, st = Lux. setup(rng, wn) |> dev
232223 x = randn(rng, Float32, 3 , 1 ) |> aType
233224
234225 @jet wn(x, ps, st)
235- @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 ,
236- broken_backends)
226+ @test_gradients(sumabs2first, wn, x, ps, st; atol= 1.0f-3 , rtol= 1.0f-3 )
237227 end
238228
239229 # See https://github.com/LuxDL/Lux.jl/issues/95
0 commit comments