Motivation and description
Description:
The current `train.jl` script manually manages optimizer state and uses the low-level Optimisers.jl API (`Optimisers.setup`/`Optimisers.update!` or explicit `params(model)` loops). We should refactor it to use Flux's unified `Flux.setup`/`Flux.update!` API, introduced in Flux 0.13+, for a cleaner and more idiomatic training loop.
Current Behavior:
- Optimizers are initialized and updated using the explicit Optimisers.jl API:

```julia
opt = ADAM(lr)
ps = params(model)
st = Optimisers.setup(opt, ps)
# ...
st, _ = Optimisers.update!(st, ps, grads)
```
- Training loops manually handle gradient collection and state updates (a self-contained sketch of this pattern follows below).
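For concreteness, here is a minimal, runnable sketch of that explicit pattern. The model, learning rate, data, and `loss_fn` below are illustrative stand-ins, not the actual definitions in `train.jl`:

```julia
using Flux, Optimisers

# Toy stand-ins for whatever train.jl actually defines.
model = Chain(Dense(2 => 16, relu), Dense(16 => 1))
lr = 1f-3
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
loss_fn(ŷ, ytrue) = Flux.mse(ŷ, ytrue)

# Explicit Optimisers.jl state management: the caller owns the state tree
# and must thread it (and the updated model) through every step.
st = Optimisers.setup(Optimisers.Adam(lr), model)
grads = Flux.gradient(m -> loss_fn(m(x), y), model)
st, model = Optimisers.update!(st, model, grads[1])
```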
Desired Behavior:
- Use Flux's unified optimizer API:

```julia
opt_state = Flux.setup(Adam(lr), model)

# inside the training loop
loss, grads = Flux.withgradient(model) do m
    loss_fn(m(x))
end
Flux.update!(opt_state, model, grads[1])
```
- Simplify `train_discriminator!` and `train_generator!` to use `Flux.setup`/`Flux.update!` instead of explicit parameter/state management (a hypothetical sketch follows).
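As an illustration of that simplification, here is a hypothetical sketch of the discriminator helper under the unified API. The signature, loss, and optimiser settings are assumptions for the example, not the actual code in `train.jl`:

```julia
using Flux

# Hypothetical refactor: argument names and the BCE loss are illustrative only.
function train_discriminator!(disc, disc_opt_state, gen, real_batch, noise)
    fake_batch = gen(noise)  # generated samples; no gradient through the generator here
    loss, grads = Flux.withgradient(disc) do d
        real_loss = Flux.Losses.logitbinarycrossentropy(d(real_batch), 1f0)
        fake_loss = Flux.Losses.logitbinarycrossentropy(d(fake_batch), 0f0)
        real_loss + fake_loss
    end
    Flux.update!(disc_opt_state, disc, grads[1])  # mutates optimizer state and model
    return loss
end

# One-time setup outside the loop (learning rate is a placeholder):
# disc_opt_state = Flux.setup(Adam(2f-4), disc)
```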
Proposed Changes:
- Replace all explicit `Optimisers.setup` and `Optimisers.update!` calls with `Flux.setup` and `Flux.update!` on the `Chain` models (a generator-side sketch follows this list).
- Remove manual `params(model)` and state-tracking variables where no longer needed.
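To complement the discriminator sketch above, the generator side under the same API could look roughly like this; again the names (`gen`, `gen_opt_state`, `disc`, `noise`) are placeholders, not the repo's actual signatures:

```julia
using Flux

# Hypothetical refactor of the generator update under Flux.setup / Flux.update!.
function train_generator!(gen, gen_opt_state, disc, noise)
    loss, grads = Flux.withgradient(gen) do g
        # The generator is rewarded when the discriminator labels its samples as real.
        Flux.Losses.logitbinarycrossentropy(disc(g(noise)), 1f0)
    end
    Flux.update!(gen_opt_state, gen, grads[1])
    return loss
end
```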
Acceptance Criteria:
- Code compiles and runs without deprecation warnings under Flux 0.14+.
- Discriminator and generator training functions use `Flux.setup` and `Flux.update!` exclusively.
- Existing functionality and performance are preserved.
References:
- Flux unified training API: https://fluxml.ai/Flux.jl/stable/guide/training/training/