Commit 52a7b89

Merge pull request #328 from rossinerbe/patch-1

Update 60-minute-blitz.jl

2 parents: c2b66f2 + cee5701

1 file changed: +30, −33 lines

tutorials/60-minute-blitz/60-minute-blitz.jl

@@ -273,26 +273,29 @@ Flux.train!(loss, params(m), [(data,labels)], opt)
 using Statistics
 using Flux, Flux.Optimise
-using Images: channelview
-using Metalhead
-using Metalhead: trainimgs, valimgs
+using MLDatasets: CIFAR10
 using Images.ImageCore
-using Flux: onehotbatch, onecold, flatten
+using Flux: onehotbatch, onecold
 using Base.Iterators: partition
-# using CUDA
+using CUDA

 # The image will give us an idea of what we are dealing with.
 # ![title](https://pytorch.org/tutorials/_images/cifar10.png)

-Metalhead.download(CIFAR10)
-X = trainimgs(CIFAR10)
-labels = onehotbatch([X[i].ground_truth.class for i in 1:50000],1:10)
+train_x, train_y = CIFAR10.traindata(Float32)
+labels = onehotbatch(train_y, 0:9)
+
+# The train_x contains 50000 images converted to 32 X 32 X 3 arrays with the third
+# dimension being the 3 channels (R,G,B). Let's take a look at a random image from
+# the train_x. For this, we need to permute the dimensions to 3 X 32 X 32 and use
+# `colorview` to convert it back to an image.

 # Let's take a look at a random image from the dataset

-image(x) = x.img # handy for use later
-ground_truth(x) = x.ground_truth
-image.(X[rand(1:end, 10)])
+using Plots
+image(x) = colorview(RGB, permutedims(x, (3, 2, 1)))
+plot(image(train_x[:,:,:,rand(1:end)]))

 # The images are simply 32 X 32 matrices of numbers in 3 channels (R,G,B). We can now
 # arrange them in batches of say, 1000 and keep a validation set to track our progress.
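Note on the `permutedims` call introduced above: MLDatasets stores each CIFAR-10 image as a 32 × 32 × 3 array (width × height × channel), while `colorview(RGB, ...)` expects the channel dimension first. A minimal plain-Base sketch of the dimension shuffle (no Flux or Images needed; the `img` array is a stand-in for one slice of `train_x`):

```julia
# A single CIFAR-10 image as MLDatasets delivers it: width x height x channel.
img = rand(Float32, 32, 32, 3)          # stand-in for train_x[:, :, :, i]

# Move the channel dimension to the front, as `colorview` requires.
chw = permutedims(img, (3, 2, 1))
println(size(chw))                      # (3, 32, 32)
```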
@@ -302,19 +305,14 @@ image.(X[rand(1:end, 10)])
 # and train only on them. It is shown to help with escaping
 # [saddle points](https://en.wikipedia.org/wiki/Saddle_point).

-# Defining a `getarray` function would help in converting the matrices to `Float` type.
-
-getarray(X) = float.(permutedims(channelview(X), (3, 2, 1)))
-imgs = [getarray(X[i].img) for i in 1:50000]

 # The first 49k images (in batches of 1000) will be our training set, and the rest is
 # for validation. `partition` handily breaks down the set we give it in consecutive parts
-# (1000 in this case). `cat` is a shorthand for concatenating multi-dimensional arrays along
-# any dimension.
+# (1000 in this case).

-train = ([(cat(imgs[i]..., dims = 4), labels[:,i]) for i in partition(1:49000, 1000)]) |> gpu
+train = ([(train_x[:,:,:,i], labels[:,i]) for i in partition(1:49000, 1000)]) |> gpu
 valset = 49001:50000
-valX = cat(imgs[valset]..., dims = 4) |> gpu
+valX = train_x[:,:,:,valset] |> gpu
 valY = labels[:, valset] |> gpu

 # ## Defining the Classifier
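The `cat(imgs[i]..., dims = 4)` splat is dropped above because `train_x` is already one 4-D array, so slicing it with the ranges from `partition` yields batches directly. A quick sketch of how `partition` carves up the index range (Base Julia only):

```julia
using Base.Iterators: partition

# Split the 49000 training indices into consecutive blocks of 1000,
# exactly as the diff does when forming minibatches.
batches = collect(partition(1:49000, 1000))
println(length(batches))        # 49
println(first(batches))         # 1:1000
println(last(batches))          # 48001:49000
```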
@@ -331,7 +329,7 @@ m = Chain(
 MaxPool((2,2)),
 Conv((5,5), 16=>8, relu),
 MaxPool((2,2)),
-flatten,
+x -> reshape(x, :, size(x, 4)),
 Dense(200, 120),
 Dense(120, 84),
 Dense(84, 10),
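The anonymous function replacing `flatten` collapses the spatial and channel dimensions of the conv output into one feature dimension while keeping the batch dimension (dim 4) last, which is what the following `Dense(200, 120)` layer expects. A sketch with a stand-in array of the shape produced by the conv/pool stack:

```julia
# Conv-stack output for a batch of 2: 5 x 5 spatial, 8 channels.
x = rand(Float32, 5, 5, 8, 2)

# Flatten everything except the batch dimension.
flat = reshape(x, :, size(x, 4))
println(size(flat))             # (200, 2), matching Dense(200, 120)
```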
@@ -345,15 +343,15 @@ m = Chain(
 # preventing us from overshooting our desired destination.
 #-

-using Flux: Momentum
+using Flux: crossentropy, Momentum

-loss(x, y) = Flux.Losses.crossentropy(m(x), y)
+loss(x, y) = sum(crossentropy(m(x), y))
 opt = Momentum(0.01)

 # We can start writing our train loop where we will keep track of some basic accuracy
 # numbers about our model. We can define an `accuracy` function for it like so.

-accuracy(x, y) = mean(onecold(m(x), 1:10) .== onecold(y, 1:10))
+accuracy(x, y) = mean(onecold(m(x), 0:9) .== onecold(y, 0:9))

 # ## Training
 # -----------
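The label range in `onecold` changes from `1:10` to `0:9` above because MLDatasets' CIFAR-10 labels are zero-based. Flux's `onecold` picks, for each column, the label at the position of the largest score; a plain-Base stand-in (the helper name is illustrative, not Flux API):

```julia
# Sketch of onecold for a single score column: return the label from
# `labels` at the index of the largest score.
onecold_sketch(col, labels) = labels[argmax(col)]

scores = [0.1, 0.7, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
println(onecold_sketch(scores, 0:9))    # 1 (second entry of 0:9)
```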
@@ -398,25 +396,24 @@ end
 # Okay, first step. Let us perform the exact same preprocessing on this set, as we did
 # on our training set.

-valset = valimgs(CIFAR10)
-valimg = [getarray(valset[i].img) for i in 1:10000]
-labels = onehotbatch([valset[i].ground_truth.class for i in 1:10000],1:10)
-test = gpu.([(cat(valimg[i]..., dims = 4), labels[:,i]) for i in partition(1:10000, 1000)])
+test_x, test_y = CIFAR10.testdata(Float32)
+test_labels = onehotbatch(test_y, 0:9)
+
+test = gpu.([(test_x[:,:,:,i], test_labels[:,i]) for i in partition(1:10000, 1000)])

-# Next, display some of the images from the test set.
+# Next, display an image from the test set.

-ids = rand(1:10000, 10)
-image.(valset[ids])
+plot(image(test_x[:,:,:,rand(1:end)]))

 # The outputs are energies for the 10 classes. Higher the energy for a class, the more the
 # network thinks that the image is of the particular class. Every column corresponds to the
 # output of one image, with the 10 floats in the column being the energies.

 # Let's see how the model fared.

-rand_test = getarray.(image.(valset[ids]))
-rand_test = cat(rand_test..., dims = 4) |> gpu
-rand_truth = ground_truth.(valset[ids])
+ids = rand(1:10000, 5)
+rand_test = test_x[:,:,:,ids] |> gpu
+rand_truth = test_y[ids]
 m(rand_test)

 # This looks similar to how we would expect the results to be. At this point, it's a good
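The test labels go through the same `onehotbatch(test_y, 0:9)` encoding as the training labels: each integer label becomes a column with a single 1 at that label's slot. A plain-Base sketch of the per-label encoding (the helper name is illustrative; Flux's `onehotbatch` builds a compact `OneHotMatrix` instead of a `Bool` vector):

```julia
# Encode one zero-based label over the label range 0:9.
onehot_sketch(y, labels) = [l == y for l in labels]

v = onehot_sketch(3, 0:9)
println(v)                      # true only at position 4 (label 3)
```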
