@@ -39,16 +39,18 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu
3939 y_zyg = fused_dense_bias_activation(activation, w, x, bias)
4040 @test y_simple ≈ y_zyg atol = atol rtol = rtol
4141
42- _, ∂w_true, ∂x_true, ∂b_true = Zygote. gradient(
43- sum ∘ dense_simple, activation, w, x, bias
44- )
45- _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote. gradient(
46- sum ∘ fused_dense_bias_activation, activation, w, x, bias
47- )
48- @test ∂w_true ≈ ∂w_zyg atol = atol rtol = rtol
49- @test ∂x_true ≈ ∂x_zyg atol = atol rtol = rtol
50- if bias != = nothing
51- @test ∂b_true ≈ ∂b_zyg atol = atol rtol = rtol
42+ if LuxTestUtils. ZYGOTE_TESTING_ENABLED[]
43+ _, ∂w_true, ∂x_true, ∂b_true = Zygote. gradient(
44+ sum ∘ dense_simple, activation, w, x, bias
45+ )
46+ _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote. gradient(
47+ sum ∘ fused_dense_bias_activation, activation, w, x, bias
48+ )
49+ @test ∂w_true ≈ ∂w_zyg atol = atol rtol = rtol
50+ @test ∂x_true ≈ ∂x_zyg atol = atol rtol = rtol
51+ if bias != = nothing
52+ @test ∂b_true ≈ ∂b_zyg atol = atol rtol = rtol
53+ end
5254 end
5355end
5456
@@ -201,35 +203,39 @@ end
201203 b_enz,
202204 )
203205
204- _, pb_f = Zygote. pullback(fused_dense_bias_activation, act, weight, x, b)
205- _, dweight_zyg, dx_zyg, db_zyg = pb_f(dy)
206+ if LuxTestUtils. ZYGOTE_TESTING_ENABLED[]
207+ _, pb_f = Zygote. pullback(fused_dense_bias_activation, act, weight, x, b)
208+ _, dweight_zyg, dx_zyg, db_zyg = pb_f(dy)
206209
207- @test dweight ≈ dweight_zyg atol = 1.0e-3 rtol = 1.0e-3
208- @test dx ≈ dx_zyg atol = 1.0e-3 rtol = 1.0e-3
209- if hasbias
210- @test db ≈ db_zyg atol = 1.0e-3 rtol = 1.0e-3
210+ @test dweight ≈ dweight_zyg atol = 1.0e-3 rtol = 1.0e-3
211+ @test dx ≈ dx_zyg atol = 1.0e-3 rtol = 1.0e-3
212+ if hasbias
213+ @test db ≈ db_zyg atol = 1.0e-3 rtol = 1.0e-3
214+ end
211215 end
212216
213217 (act === identity && hasbias) || continue
214218
215- dweight .= 0
216- dx .= 0
217- db .= 0
218- Enzyme. autodiff(
219- Reverse,
220- matmuladd!,
221- Duplicated(y, copy(dy)),
222- Duplicated(weight, dweight),
223- Duplicated(x, dx),
224- b_enz,
225- )
226-
227- _, pb_f = Zygote. pullback(LuxLib. Impl. matmuladd, weight, x, b)
228- dweight_zyg, dx_zyg, db_zyg = pb_f(dy)
229-
230- @test dweight ≈ dweight_zyg atol = 1.0e-3 rtol = 1.0e-3
231- @test dx ≈ dx_zyg atol = 1.0e-3 rtol = 1.0e-3
232- @test db ≈ db_zyg atol = 1.0e-3 rtol = 1.0e-3
219+ if LuxTestUtils. ZYGOTE_TESTING_ENABLED[]
220+ dweight .= 0
221+ dx .= 0
222+ db .= 0
223+ Enzyme. autodiff(
224+ Reverse,
225+ matmuladd!,
226+ Duplicated(y, copy(dy)),
227+ Duplicated(weight, dweight),
228+ Duplicated(x, dx),
229+ b_enz,
230+ )
231+
232+ _, pb_f = Zygote. pullback(LuxLib. Impl. matmuladd, weight, x, b)
233+ dweight_zyg, dx_zyg, db_zyg = pb_f(dy)
234+
235+ @test dweight ≈ dweight_zyg atol = 1.0e-3 rtol = 1.0e-3
236+ @test dx ≈ dx_zyg atol = 1.0e-3 rtol = 1.0e-3
237+ @test db ≈ db_zyg atol = 1.0e-3 rtol = 1.0e-3
238+ end
233239 end
234240 end
235241end
0 commit comments