
Refactor DCGAN training script to use Flux.setup/update! unified API #411

@josemanuel22


Motivation and description

Description:
The current train.jl script manually manages optimizer state and uses the low-level Optimisers.jl API (Optimisers.setup/Optimisers.update! or explicit params(model) loops). We should refactor it to leverage Flux's unified Flux.setup and Flux.update! API introduced in Flux 0.13+, for a cleaner and more idiomatic training loop.

Current Behavior:

  • Optimizers are initialized and updated using the explicit Optimisers.jl API:
    opt = ADAM(lr)
    ps  = params(model)
    st  = Optimisers.setup(opt, ps)
    ...
    st, _ = Optimisers.update!(st, ps, grads)
  • Training loops manually handle gradient collection and state updates.

Desired Behavior:

  • Use Flux's unified optimizer API:
    opt_state = Flux.setup(Flux.Adam(lr), model)
    
    # inside training loop
    loss, grads = Flux.withgradient(model) do m
      loss_fn(m(x))
    end
    Flux.update!(opt_state, model, grads[1])
  • Simplify train_discriminator! and train_generator! to use Flux.setup/Flux.update! instead of explicit parameter/state management; a possible shape for the refactored functions is sketched after this list.
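
For example, train_discriminator! could collapse to something like the following. This is a minimal sketch only: the function signature, argument names, and the logit binary cross-entropy loss on real/fake batches are illustrative assumptions, not the script's actual code.

    using Flux

    # Sketch of a refactored discriminator step using the explicit API.
    # disc_state is the result of Flux.setup(Adam(lr), disc), built once
    # outside the training loop; names and loss are placeholders.
    function train_discriminator!(disc, disc_state, gen, x_real, z)
        loss, grads = Flux.withgradient(disc) do d
            real_loss = Flux.logitbinarycrossentropy(d(x_real), 1f0)
            fake_loss = Flux.logitbinarycrossentropy(d(gen(z)), 0f0)
            real_loss + fake_loss
        end
        Flux.update!(disc_state, disc, grads[1])  # mutates disc and disc_state in place
        return loss
    end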

Proposed Changes:

  1. Replace all explicit Optimisers.setup and Optimisers.update! calls with Flux.setup and Flux.update! on the Chain models (see the sketch after this list).
  2. Remove manual params(model) and state-tracking variables where no longer needed.
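
Concretely, items 1 and 2 amount to something like the sketch below, where the toy Chains, learning rate, batch size, and generator loss are stand-ins for whatever train.jl actually builds:

    using Flux

    # Hypothetical stand-ins for the DCGAN models in train.jl.
    generator     = Chain(Dense(100 => 256, relu), Dense(256 => 784, tanh))
    discriminator = Chain(Dense(784 => 256, relu), Dense(256 => 1))

    lr = 2f-4
    gen_state  = Flux.setup(Adam(lr), generator)    # replaces Optimisers.setup(opt, params(generator))
    disc_state = Flux.setup(Adam(lr), discriminator)

    # Generator step: no params(...) and no manual state threading.
    z = randn(Float32, 100, 32)
    gen_loss, grads = Flux.withgradient(generator) do g
        Flux.logitbinarycrossentropy(discriminator(g(z)), 1f0)
    end
    Flux.update!(gen_state, generator, grads[1])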

Acceptance Criteria:

  • Code compiles and runs without deprecation warnings under Flux 0.14+.
  • Discriminator and generator training functions use Flux.setup and Flux.update! exclusively.
  • Existing functionality and performance are preserved (the smoke test sketched below is one quick way to check a single training step).
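
As a sanity check, running the script with julia --depwarn=error train.jl turns deprecation warnings into errors, and a small smoke test along these lines (model sizes and loss are placeholders, not the real DCGAN architecture) can confirm that one explicit-API step runs and actually updates parameters:

    using Test, Flux

    @testset "explicit Flux.setup/Flux.update! step" begin
        disc = Chain(Dense(4 => 1))
        disc_state = Flux.setup(Adam(1f-3), disc)
        x = randn(Float32, 4, 8)

        w_before = copy(disc[1].weight)
        loss, grads = Flux.withgradient(d -> Flux.logitbinarycrossentropy(d(x), 1f0), disc)
        Flux.update!(disc_state, disc, grads[1])

        @test isfinite(loss)
        @test disc[1].weight != w_before   # parameters were actually updated
    end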
