Commit db49bbb

tweaks
1 parent 8c9132d commit db49bbb

File tree: 1 file changed (+19 −15 lines)


vision/convmixer_cifar10/convmixer.jl

Lines changed: 19 additions & 15 deletions
@@ -1,8 +1,12 @@
 using Flux, MLDatasets
 using Flux: onehotbatch, onecold, DataLoader, flatten, OptimiserChain
 using BSON:@save,@load
-Flux._old_to_new(rule::ClipNorm) = Flux.Optimisers.ClipNorm(rule.thresh) # wrong in Flux 0.13.9
+ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
 
+# wrong in Flux 0.13.9
+Flux._old_to_new(rule::ClipNorm) = Flux.Optimisers.ClipNorm(rule.thresh)
+
+# Also, quick test of train(epochs=10, images=128) shows increasing loss, not sure why.
 
 function ConvMixer(in_channels, kernel_size, patch_size, dim, depth, N_classes)
     f = Chain(
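
Note on the retained workaround: in Flux 0.13.9, Flux.setup converts old-style
Flux.Optimise rules through the internal Flux._old_to_new, and the built-in
conversion for ClipNorm was wrong, so the script pins a corrected method before
any setup call. A minimal sketch of the path it patches (placeholder model,
assuming Flux 0.13.9):

    using Flux
    model = Chain(Dense(10 => 2))            # placeholder model
    # setup translates the old-style ClipNorm via Flux._old_to_new,
    # which now hits the corrected method defined above
    state = Flux.setup(ClipNorm(1f0), model)
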
@@ -22,19 +26,19 @@ function ConvMixer(in_channels, kernel_size, patch_size, dim, depth, N_classes)
     return f
 end
 
-function get_data(batchsize; dataset = MLDatasets.CIFAR10, idxs = nothing)
-    """
-    idxs=nothing gives the full dataset, otherwise (for testing purposes) only the 1:idxs elements of the train set are given.
-    """
-    ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
+"""
+By default gives the full dataset, keyword images gives (for testing purposes)
+only the 1:images elements of the train set.
+"""
+function get_data(batchsize; dataset = MLDatasets.CIFAR10, images = :)
 
     # Loading Dataset
-    if idxs===nothing
+    if images === (:)
         xtrain, ytrain = dataset(:train)[:]
         xtest, ytest = dataset(:test)[:]
     else
-        xtrain, ytrain = dataset(:train)[1:idxs]
-        xtest, ytest = dataset(:test)[1:Int(idxs/10)]
+        xtrain, ytrain = dataset(:train)[1:images]
+        xtest, ytest = dataset(:test)[1:(images÷10)]
     end
 
     # Reshape Data to comply to Julia's (width, height, channels, batch_size) convention in case there are only 1 channel (eg MNIST)
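
The images = : default uses Colon() as an "everything" sentinel in place of the
old idxs = nothing, and images÷10 keeps the test subset at a tenth of the train
subset while avoiding the InexactError that Int(idxs/10) throws when the count
is not a multiple of 10. Hypothetical usage (sizes chosen for illustration):

    train_loader, test_loader = get_data(128)              # full CIFAR-10
    train_loader, test_loader = get_data(128; images=500)  # 500 train / 50 test images
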
@@ -74,10 +78,10 @@ function create_loss_function(dataloader, device)
 end
 
 
-function train(n_epochs=100)
+function train(; epochs=100, images=:)
 
     #params: warning, the training can be long with these params
-    train_loader, test_loader = get_data(128)
+    train_loader, test_loader = get_data(128; images)
     η = 3f-4
     in_channel = 3
     patch_size = 2
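
train is now keyword-only, and get_data(128; images) uses Julia's keyword
shorthand (bare images means images=images). For example (the second call is
the quick test mentioned in the comment at the top of the file):

    train()                         # full run: 100 epochs on all of CIFAR-10
    train(epochs=10, images=128)    # smoke test on 128 train / 12 test images
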
@@ -88,8 +92,8 @@ function train(n_epochs=100)
     use_cuda = true
 
     #logging the losses
-    train_save = zeros(n_epochs, 2)
-    test_save = zeros(n_epochs, 2)
+    train_save = zeros(epochs, 2)
+    test_save = zeros(epochs, 2)
 
     if use_cuda
         device = gpu
@@ -111,11 +115,11 @@ function train(n_epochs=100)
     )
     state = Flux.setup(opt, model)
 
-    for epoch in 1:n_epochs
+    for epoch in 1:epochs
         for (x,y) in train_loader
             x,y = x|>device, y|>device
             grads = gradient(m->Flux.logitcrossentropy(m(x), y, agg=sum), model)
-            Flux.Optimise.update!(state, model, grads[1])
+            Flux.update!(state, model, grads[1])
         end
 
         #logging
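
Note on the last change: Flux.Optimise.update! belongs to the deprecated
implicit-parameters API, while the state returned by Flux.setup pairs with the
explicit Flux.update! from Optimisers.jl. A minimal sketch of the explicit step
(model, x, y are placeholders):

    state = Flux.setup(Adam(3f-4), model)
    grads = gradient(m -> Flux.logitcrossentropy(m(x), y), model)
    Flux.update!(state, model, grads[1])   # updates model and state in place
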
