Update description of trainable in "advanced.md" (#2289)
* Update advanced.md

* Update docs/src/models/advanced.md

Co-authored-by: Kyle Daruwalla <[email protected]>

---------

Co-authored-by: Kyle Daruwalla <[email protected]>
mcabbott and darsnack authored Jul 16, 2023
1 parent 8c23af3 commit e439ae7
Showing 1 changed file with 18 additions and 14 deletions.
docs/src/models/advanced.md: 18 additions & 14 deletions
@@ -36,34 +36,38 @@ For an intro to Flux and automatic differentiation, see this [tutorial](https://

Taking reference from our example `Affine` layer from the [basics](@ref man-basics).
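
For reference, here is a minimal sketch of that layer, roughly as it is defined on the basics page (the field names `W` and `b` are what the examples below assume):

```julia
struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(randn(out, in), zeros(out))

# Make the layer callable: y = W*x .+ b
(m::Affine)(x) = m.W * x .+ m.b
```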

-By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, it is possible to mark the fields of our layers that are trainable in two ways.

-The first way of achieving this is through overloading the `trainable` function.
+By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function:

```julia-repl
-julia> @functor Affine
+julia> Flux.@functor Affine

-julia> a = Affine(rand(3,3), rand(3))
-Affine{Array{Float64,2},Array{Float64,1}}([0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955])
+julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9])
+Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0])

julia> Flux.params(a) # default behavior
-Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955]])
+Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]])

-julia> Flux.trainable(a::Affine) = (a.W,)
+julia> Flux.trainable(a::Affine) = (; a.W) # returns a NamedTuple using the field's name

julia> Flux.params(a)
-Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297]])
+Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]])
```

-Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`.
+Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`, and only these fields will be seen by `Flux.setup` and `Flux.update!` for training. But all fields will be seen by `gpu` and similar functions, for example:

+```julia-repl
+julia> a |> f16
+Affine(Float16[1.0 2.0; 3.0 4.0; 5.0 6.0], Float16[7.0, 8.0, 9.0])
+```
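
To see the training side of this restriction, a short sketch (assuming the `trainable` overload above; `Adam` and `setup` are the Optimisers.jl-backed API re-exported by Flux):

```julia
using Flux

opt_state = Flux.setup(Adam(0.01), a)
# `opt_state` holds optimiser state for the trainable field `W` only;
# the ignored field `b` gets an empty placeholder, so `Flux.update!`
# leaves it unchanged.
```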

-Another way of achieving this is through the `@functor` macro directly. Here, we can mark the fields we are interested in by grouping them in the second argument:
+Note that there is no need to overload `trainable` to hide fields which do not contain trainable parameters. (For example, activation functions, or Boolean flags.) These are always ignored by `params` and by training:

-```julia
-Flux.@functor Affine (W,)
-```
+```julia-repl
+julia> Flux.params(Affine(true, [10, 11, 12.0]))
+Params([])
+```

-However, doing this requires the `struct` to have a corresponding constructor that accepts those parameters.
+It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired).
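
For concreteness, a sketch of what that restricted form demands, assuming the two-field `Affine` above (per the note above, a one-argument constructor accepting just `W` is needed):

```julia
using Flux

Flux.@functor Affine (W,)

# With only `W` listed, reconstruction (e.g. inside `fmap`, `gpu`, `f16`)
# needs a constructor taking just that field, so one must be defined:
Affine(W) = Affine(W, zeros(eltype(W), size(W, 1)))
```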

## Freezing Layer Parameters
