
Commit 80f3384

upgrade to explicit
1 parent 7aac855 commit 80f3384

File tree: 1 file changed (+9, -8 lines)


vision/convmixer_cifar10/convmixer.jl

Lines changed: 9 additions & 8 deletions
@@ -1,6 +1,7 @@
 using Flux, MLDatasets
-using Flux: onehotbatch, onecold, DataLoader, Optimiser
+using Flux: onehotbatch, onecold, DataLoader, flatten, OptimiserChain
 using BSON:@save,@load
+Flux._old_to_new(rule::ClipNorm) = Flux.Optimisers.ClipNorm(rule.thresh) # wrong in Flux 0.13.9
 
 
 function ConvMixer(in_channels, kernel_size, patch_size, dim, depth, N_classes)
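This hunk swaps the legacy Optimiser chain type for OptimiserChain from Optimisers.jl (and imports flatten explicitly), and the added method patches a translation bug: Flux.setup converts legacy Flux.Optimise rules into Optimisers.jl rules through the internal _old_to_new function, whose stock ClipNorm method was wrong in Flux 0.13.9, as the comment notes. A minimal sketch of what the override does (hedged: _old_to_new is a Flux internal, and the exact 0.13.9 bug is not spelled out beyond the comment):

using Flux

# Pin the legacy-to-new translation so ClipNorm's threshold carries over as-is.
Flux._old_to_new(rule::ClipNorm) = Flux.Optimisers.ClipNorm(rule.thresh)

rule = Flux._old_to_new(ClipNorm(1f0))   # an Optimisers.jl ClipNorm with threshold 1f0

Overriding an internal like this is a monkey patch: reasonable in a script pinned to Flux 0.13.9, but worth deleting once a fixed Flux release is in use.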
@@ -77,7 +78,7 @@ function train(n_epochs=100)
 
     #params: warning, the training can be long with these params
     train_loader, test_loader = get_data(128)
-    η = 3e-4
+    η = 3f-4
     in_channel = 3
     patch_size = 2
     kernel_size = 7
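The substantive change here is the learning-rate literal: in Julia, 3e-4 is a Float64 while 3f-4 is a Float32. Since Flux models default to Float32 parameters, a Float32 η keeps the optimiser state from silently promoting to Float64. A quick REPL check (not part of the script) confirms the literal types:

julia> typeof(3e-4)
Float64

julia> typeof(3f-4)
Float32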
@@ -103,18 +104,18 @@ function train(n_epochs=100)
 
     model = ConvMixer(in_channel, kernel_size, patch_size, dim, depth, 10) |> device
 
-    ps = params(model)
-    opt = Optimiser(
+    opt = OptimiserChain(
         WeightDecay(1f-3),
-        ClipNorm(1.0),
-        ADAM(η)
+        ClipNorm(1f0),
+        Adam(η),
     )
+    state = Flux.setup(opt, model)
 
     for epoch in 1:n_epochs
         for (x,y) in train_loader
             x,y = x|>device, y|>device
-            gr = gradient(()->Flux.logitcrossentropy(model(x), y, agg=sum), ps)
-            Flux.Optimise.update!(opt, ps, gr)
+            grads = gradient(m->Flux.logitcrossentropy(m(x), y, agg=sum), model)
+            Flux.Optimise.update!(state, model, grads[1])
         end
 
         #logging
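This hunk is the heart of the upgrade: implicit parameters (params(model) plus a zero-argument gradient closure) give way to explicit gradients taken with respect to the model itself, with optimiser state created once by Flux.setup and threaded through update!. A minimal self-contained sketch of the same pattern (hedged: toy Dense model and random data rather than this script's ConvMixer; assumes a recent Flux where Adam, WeightDecay, and OptimiserChain come from Optimisers.jl):

using Flux

model = Dense(4 => 2)                                  # stand-in for the ConvMixer
opt   = OptimiserChain(WeightDecay(1f-3), Adam(3f-4))
state = Flux.setup(opt, model)                         # per-parameter optimiser state

x = randn(Float32, 4, 8)
y = Flux.onehotbatch(rand(1:2, 8), 1:2)

# Differentiate with respect to the model; grads[1] mirrors the model's structure.
grads = Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), model)
Flux.update!(state, model, grads[1])                   # mutates model and state in place

Note the trailing [1]: gradient returns a tuple with one entry per argument, so grads[1] is the NamedTuple of gradients for model that update! expects.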
