Support for Transfer-Learning/Layer-Freezing #150

@JoshuaBillson

Motivation and description

A common practice in machine learning is to take a pre-trained model and fine-tune it on a particular dataset. This typically involves freezing the weights in some layers while fitting the output layer(s) on the new data.

Unfortunately, this functionality appears to be incompatible with the current implementation of the ToDevice callback, based on the following code:

function on(::EpochBegin, ::Phase, cb::ToDevice, learner)
    model!(learner, cb.movemodelfn(learner.model))
end

function model!(learner, model)
    learner.model = model
    learner.params = setupoptimstate(model, learner.optimizer)
end

setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser) = Flux.params(model)

setupoptimstate(model, optim) = Optimisers.setup(optim, model)

This essentially means that learner.params is set to the parameters of the full model at the start of each epoch. Thus, even if we try to freeze the layers manually with Flux.freeze!(learner.params.layers[1:end-1]), this will be undone by ToDevice.
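The reset described above can be reproduced without FluxTraining. The following is a minimal sketch that assumes only Optimisers.jl and uses a made-up two-layer model built from named tuples (a stand-in for a Flux Chain, not part of the original report). It calls Optimisers.freeze!, which as far as I know is the function that Flux.freeze! refers to. Calling Optimisers.setup a second time mimics what model! does at each EpochBegin.

```julia
using Optimisers

# Hypothetical stand-in for a two-layer model; any nested structure of arrays works.
model = (layers = ((weight = ones(2, 2),), (weight = ones(2, 2),)),)
rule = Optimisers.Descent(0.1)

# Build the optimiser state and freeze the first layer.
state = Optimisers.setup(rule, model)
Optimisers.freeze!(state.layers[1])

# One update step with all-ones gradients: only the unfrozen layer should move.
grads = (layers = ((weight = ones(2, 2),), (weight = ones(2, 2),)),)
state, model = Optimisers.update!(state, model, grads)
first_layer_unchanged = all(model.layers[1].weight .== 1.0)
second_layer_changed = all(model.layers[2].weight .< 1.0)

# Re-running setup (as model! does every epoch) produces fresh, unfrozen state.
newstate = Optimisers.setup(rule, model)
still_frozen = newstate.layers[1].weight.frozen
```

Here first_layer_unchanged and second_layer_changed show that freezing works for a single step, while still_frozen shows that the frozen flag does not survive a second call to setup.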

Possible Implementation

One solution that would work with Flux's new explicit optimizers would be to create a callback to freeze layers after ToDevice is executed. An example is given below:

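# Callback that re-freezes part of the optimiser state at the start of each training epoch.
# `accessor` maps the optimiser state to the sub-tree that should be frozen.
mutable struct LayerFreezing{F} <: FluxTraining.Callback
    accessor::F
end

# The callback mutates the frozen flags stored in learner.params, so it needs write access.
function FluxTraining.stateaccess(::LayerFreezing)
    return (; params = FluxTraining.Write())
end

function FluxTraining.on(
    event::FluxTraining.EpochBegin,
    phase::FluxTraining.AbstractTrainingPhase,
    freezer::LayerFreezing,
    learner)
    Flux.freeze!(freezer.accessor(learner.params))
end

# Run after ToDevice so the freeze is applied to the freshly rebuilt optimiser state.
FluxTraining.runafter(::LayerFreezing) = (FluxTraining.ToDevice,)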
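To make the role of the accessor concrete, here is a small sketch of what such a function might look like and what it does to the optimiser state. It again assumes only Optimisers.jl and a hypothetical three-layer model made of named tuples; the accessor shown is an illustration mirroring the layers[1:end-1] expression mentioned earlier, not something defined by FluxTraining.

```julia
using Optimisers

# Hypothetical three-layer model; the last layer plays the role of the output layer.
model = (layers = ((weight = ones(2, 2),), (weight = ones(2, 2),), (weight = ones(2, 1),)),)
state = Optimisers.setup(Optimisers.Descent(0.1), model)

# Illustrative accessor: select every layer's state except the last one.
accessor = s -> s.layers[1:end-1]

# This is the call the callback would make on learner.params.
Optimisers.freeze!(accessor(state))
frozen_flags = [layer.weight.frozen for layer in state.layers]

# thaw! undoes the freeze for the whole tree, e.g. for a later full fine-tuning stage.
Optimisers.thaw!(state)
thawed_flags = [layer.weight.frozen for layer in state.layers]
```

With the callback from the issue, the same accessor would presumably be passed as LayerFreezing(accessor) in the learner's callback list.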

However, perhaps we should consider whether it's necessary for ToDevice to move the model to the GPU at the start of every epoch. Maybe we could extend the Callback interface to allow for some one-time setup code to run before the first epoch is executed?

Labels: enhancement, help wanted