11using Flux, MLDatasets
22using Flux: onehotbatch, onecold, DataLoader, flatten, OptimiserChain
33using BSON: @save ,@load
4- Flux . _old_to_new (rule :: ClipNorm ) = Flux . Optimisers . ClipNorm (rule . thresh) # wrong in Flux 0.13.9
4+ ENV [ " DATADEPS_ALWAYS_ACCEPT " ] = " true "
55
6+ # Workaround: Flux 0.13.9's `_old_to_new` conversion for `ClipNorm` is wrong, so redefine it below.
7+ Flux. _old_to_new (rule:: ClipNorm ) = Flux. Optimisers. ClipNorm (rule. thresh)
8+
9+ # Also, a quick test with train(epochs=10, images=128) shows an increasing loss; the cause is not yet clear.
610
711function ConvMixer (in_channels, kernel_size, patch_size, dim, depth, N_classes)
812 f = Chain (
@@ -22,19 +26,19 @@ function ConvMixer(in_channels, kernel_size, patch_size, dim, depth, N_classes)
2226 return f
2327end
2428
25- function get_data (batchsize; dataset = MLDatasets . CIFAR10, idxs = nothing )
26- """
27- idxs=nothing gives the full dataset, otherwise (for testing purposes) only the 1:idxs elements of the train set are given .
28- """
29- ENV [ " DATADEPS_ALWAYS_ACCEPT " ] = " true "
29+ """
30+ By default gives the full dataset, keyword images gives (for testing purposes)
31+ only the 1:images elements of the train set.
32+ """
33+ function get_data (batchsize; dataset = MLDatasets . CIFAR10, images = :)
3034
3135 # Loading Dataset
32- if idxs === nothing
36+ if images === (:)
3337 xtrain, ytrain = dataset (:train )[:]
3438 xtest, ytest = dataset (:test )[:]
3539 else
36- xtrain, ytrain = dataset (:train )[1 : idxs ]
37- xtest, ytest = dataset (:test )[1 : Int (idxs / 10 )]
40+ xtrain, ytrain = dataset (:train )[1 : images ]
41+ xtest, ytest = dataset (:test )[1 : (images ÷ 10 )]
3842 end
3943
4044 # Reshape Data to comply to Julia's (width, height, channels, batch_size) convention in case there are only 1 channel (eg MNIST)
@@ -74,10 +78,10 @@ function create_loss_function(dataloader, device)
7478end
7579
7680
77- function train (n_epochs = 100 )
81+ function train (; epochs = 100 , images = : )
7882
7983 # params: warning, the training can be long with these params
80- train_loader, test_loader = get_data (128 )
84+ train_loader, test_loader = get_data (128 ; images )
8185 η = 3f-4
8286 in_channel = 3
8387 patch_size = 2
@@ -88,8 +92,8 @@ function train(n_epochs=100)
8892 use_cuda = true
8993
9094 # logging the losses
91- train_save = zeros (n_epochs , 2 )
92- test_save = zeros (n_epochs , 2 )
95+ train_save = zeros (epochs , 2 )
96+ test_save = zeros (epochs , 2 )
9397
9498 if use_cuda
9599 device = gpu
@@ -111,11 +115,11 @@ function train(n_epochs=100)
111115 )
112116 state = Flux. setup (opt, model)
113117
114- for epoch in 1 : n_epochs
118+ for epoch in 1 : epochs
115119 for (x,y) in train_loader
116120 x,y = x|> device, y|> device
117121 grads = gradient (m-> Flux. logitcrossentropy (m (x), y, agg= sum), model)
118- Flux. Optimise . update! (state, model, grads[1 ])
122+ Flux. update! (state, model, grads[1 ])
119123 end
120124
121125 # logging
0 commit comments