@@ -1,6 +1,7 @@
 using Flux, MLDatasets
-using Flux: onehotbatch, onecold, DataLoader, Optimiser
+using Flux: onehotbatch, onecold, DataLoader, flatten, OptimiserChain
 using BSON: @save, @load
+Flux._old_to_new(rule::ClipNorm) = Flux.Optimisers.ClipNorm(rule.thresh) # wrong in Flux 0.13.9
 
 
 function ConvMixer(in_channels, kernel_size, patch_size, dim, depth, N_classes)
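The added `Flux._old_to_new` method is a workaround: `Flux.setup` uses this internal function to translate legacy `Flux.Optimise` rules into their Optimisers.jl equivalents, and the comment notes that the built-in translation for `ClipNorm` is wrong in Flux 0.13.9. A minimal sketch of the mapping the override performs, assuming the legacy rule from Flux 0.13.x (the `old_rule`/`new_rule` names are illustrative only):

    using Flux
    old_rule = Flux.Optimise.ClipNorm(1f0)                # legacy rule; its threshold is stored in the `thresh` field
    new_rule = Flux.Optimisers.ClipNorm(old_rule.thresh)  # equivalent explicit-style rule from Optimisers.jl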
@@ -77,7 +78,7 @@ function train(n_epochs=100)
 
     # params: warning, the training can be long with these params
     train_loader, test_loader = get_data(128)
-    η = 3e-4
+    η = 3f-4
     in_channel = 3
     patch_size = 2
     kernel_size = 7
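The learning rate switches from `3e-4`, a `Float64` literal, to `3f-4`, a `Float32` literal, presumably to keep the optimiser hyperparameters in the same precision as the `Float32` model weights. For example:

    η = 3f-4          # Float32 literal, 0.0003f0
    typeof(3e-4)      # Float64
    typeof(η)         # Float32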
@@ -103,18 +104,18 @@ function train(n_epochs=100)
 
     model = ConvMixer(in_channel, kernel_size, patch_size, dim, depth, 10) |> device
 
-    ps = params(model)
-    opt = Optimiser(
+    opt = OptimiserChain(
         WeightDecay(1f-3),
-        ClipNorm(1.0),
-        ADAM(η)
+        ClipNorm(1f0),
+        Adam(η),
     )
+    state = Flux.setup(opt, model)
 
     for epoch in 1:n_epochs
         for (x, y) in train_loader
             x, y = x |> device, y |> device
-            gr = gradient(() -> Flux.logitcrossentropy(model(x), y, agg=sum), ps)
-            Flux.Optimise.update!(opt, ps, gr)
+            grads = gradient(m -> Flux.logitcrossentropy(m(x), y, agg=sum), model)
+            Flux.Optimise.update!(state, model, grads[1])
         end
 
         # logging
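Taken together, the changes replace the implicit `params`/`Optimiser` training loop with the explicit Optimisers.jl-style API: `Flux.setup` builds an optimiser state tree, the gradient is taken with respect to the model itself, and `update!` applies it. A condensed, self-contained sketch of that pattern, using a stand-in `Dense` model and random data instead of the ConvMixer and CIFAR-10 loaders from the script (on recent Flux versions the rule names below resolve directly to Optimisers.jl rules, so no `_old_to_new` override is needed):

    using Flux
    using Flux: OptimiserChain, WeightDecay, ClipNorm, Adam, onehotbatch

    model = Dense(10 => 2)                               # stand-in for the ConvMixer
    x = randn(Float32, 10, 32)
    y = onehotbatch(rand(1:2, 32), 1:2)

    opt = OptimiserChain(WeightDecay(1f-3), ClipNorm(1f0), Adam(3f-4))
    state = Flux.setup(opt, model)                       # per-parameter optimiser state tree

    grads = gradient(m -> Flux.logitcrossentropy(m(x), y, agg=sum), model)
    Flux.update!(state, model, grads[1])                 # apply the update to the model parameters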