Skip to content

Commit 8c4405c

Browse files
test: update tests for enzyme (#1552)
* test: update tests for enzyme

* Update src/helpers/training.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 584ca77 commit 8c4405c

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

src/helpers/training.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ function maybe_wrap_adtype(
358358
Enzyme.jl (`AutoEnzyme`)."))
359359
end
360360

361-
function generate_wrappers(::F, m, ps, st, data, ::False) where {F}
361+
function generate_wrappers(::F, m, ps, st, data, ::False, ::StaticBool) where {F}
362362
@warn "Detected function wrapper generation with function being updated between calls. \
363363
This will generate type-unstable code. A possible reason for this is \
364364
`TrainState` was compiled (first call to `compute_gradients`) with function \
@@ -367,16 +367,31 @@ function generate_wrappers(::F, m, ps, st, data, ::False) where {F}
367367
return Ref{Any}(), Ref{NamedTuple}()
368368
end
369369

370+
function generate_wrappers(
371+
objective_function::F, m, ps, st, data, ::True, ::False
372+
) where {F}
373+
_, stₙ, statsₙ = objective_function(m, ps, st, data)
374+
return Ref{typeof(stₙ)}(stₙ), Ref{NamedTuple}() # State type is not preserved
375+
end
376+
370377
# Run the code when trying to compile the function for the first time.
371-
function generate_wrappers(objective_function::F, m, ps, st, data, ::True) where {F}
378+
function generate_wrappers(objective_function::F, m, ps, st, data, ::True, ::True) where {F}
372379
_, stₙ, statsₙ = objective_function(m, ps, st, data)
373380
return Ref{typeof(stₙ)}(stₙ), Ref{typeof(statsₙ)}(statsₙ)
374381
end
375382

376383
function wrap_objective_function(
377384
objective_function::F, m, ps, st, data, first_try::StaticBool
378385
) where {F}
379-
st_updated, stats = generate_wrappers(objective_function, m, ps, st, data, first_try)
386+
st_updated, stats = generate_wrappers(
387+
objective_function,
388+
m,
389+
ps,
390+
st,
391+
data,
392+
first_try,
393+
static(LuxCore.preserves_state_type(m)),
394+
)
380395

381396
wrapped_objective_function = @closure (model, ps, st, data) -> begin
382397
loss, st_, stats_ = objective_function(model, ps, st, data)

test/helpers/training_tests.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,7 @@ end
260260

261261
tstate = Training.TrainState(model, ps, st, opt)
262262

263-
_, _, _, tstate_new = @inferred Training.compute_gradients(
264-
AutoEnzyme(), mse, (x, x), tstate
265-
)
263+
_, _, _, tstate_new = Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate)
266264

267265
@test tstate_new.states !== tstate.states
268266

@@ -271,14 +269,12 @@ end
271269

272270
tstate = Training.TrainState(model, ps, st, opt)
273271

274-
_, _, _, tstate_new = @inferred Training.compute_gradients(
275-
AutoEnzyme(), mse, (x, x), tstate
276-
)
272+
_, _, _, tstate_new = Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate)
277273

278274
@test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa
279275
Any
280276

281-
_, _, _, tstate_new2 = @inferred Training.compute_gradients(
277+
_, _, _, tstate_new2 = Training.compute_gradients(
282278
AutoEnzyme(), mse2, (x, x), tstate_new
283279
)
284280
@test hasfield(typeof(tstate_new2.cache.extras), :forward)
@@ -293,9 +289,7 @@ end
293289

294290
tstate = Training.TrainState(model, ps, st, opt)
295291

296-
_, _, _, tstate_new = @inferred Training.compute_gradients(
297-
AutoEnzyme(), mse, (x, x), tstate
298-
)
292+
_, _, _, tstate_new = Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate)
299293

300294
@test tstate_new.states !== tstate.states
301295

@@ -304,14 +298,12 @@ end
304298

305299
tstate = Training.TrainState(model, ps, st, opt)
306300

307-
_, _, _, tstate_new = @inferred Training.compute_gradients(
308-
AutoEnzyme(), mse, (x, x), tstate
309-
)
301+
_, _, _, tstate_new = Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate)
310302

311303
@test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa
312304
Any
313305

314-
_, _, _, tstate_new2 = @inferred Training.compute_gradients(
306+
_, _, _, tstate_new2 = Training.compute_gradients(
315307
AutoEnzyme(), mse2, (x, x), tstate_new
316308
)
317309
@test hasfield(typeof(tstate_new2.cache.extras), :forward)

0 commit comments

Comments
 (0)