
Commit 672308f

feat: migrate imagenet example to reactant
1 parent 371e1bb commit 672308f

File tree

4 files changed: 329 additions, 254 deletions


examples/ImageNet/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+tiny-imagenet-200*

examples/ImageNet/Project.toml

Lines changed: 6 additions & 13 deletions
@@ -1,29 +1,29 @@
 [deps]
 Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
-Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
 Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8"
 ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
 ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
+ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
-LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
-MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
-MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
-NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
+OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
+[sources]
+Lux = {path = "../.."}
+
 [compat]
 Boltz = "1"
-Comonicon = "1"
 DataAugmentation = "0.3"
 Dates = "1.10"
 FileIO = "1.16"
@@ -32,17 +32,10 @@ ImageIO = "0.6"
 ImageMagick = "1"
 JLD2 = "0.5.1, 0.6"
 Lux = "1"
-LuxCUDA = "0.3.3"
 MLDataDevices = "1.17"
-MLUtils = "0.4.4"
-MPI = "0.20.21"
-NCCL = "0.1.2"
 OneHotArrays = "0.2.5"
 Optimisers = "0.4.6"
 ParameterSchedulers = "0.4.2"
 Random = "1.10"
 Setfield = "1.1.1"
 Zygote = "0.7"
-
-[sources]
-Lux = {path = "../../"}
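
Note on the [sources] table: Pkg (Julia 1.11 and later) uses it to resolve Lux from the local repository checkout instead of the registry. A minimal sketch of setting up the example environment, assuming commands are run from the repository root:

using Pkg

## The [sources] entry makes Lux resolve to the local checkout at ../..
## relative to this Project.toml.
Pkg.activate("examples/ImageNet")
Pkg.instantiate()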

examples/ImageNet/main.jl

Lines changed: 0 additions & 241 deletions
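
Note that this portion of main.jl shows only deletions; the Reactant-based replacement code lives elsewhere in the commit. As a hedged sketch of what such a setup typically looks like (reactant_device is provided by MLDataDevices once Reactant is loaded; the exact names in the migrated script are an assumption):

using Boltz, Lux, MLDataDevices, Random, Reactant

## Reactant traces the model and compiles it via XLA, replacing the
## CUDA-specific LuxCUDA path deleted below.
const xdev = reactant_device()  ## hypothetical name for the migrated script
const cdev = cpu_device()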
@@ -34,26 +34,6 @@
 # --base-path="/home/avik-pal/data/ImageNet/"
 # ```
 
-# ## Package Imports
-
-using Boltz, Lux, MLDataDevices
-## import Metalhead # Install and load this package to use the Metalhead models with Lux
-
-using Dates, Random
-using DataAugmentation,
-    FileIO, MLUtils, OneHotArrays, Optimisers, ParameterSchedulers, Setfield
-using Comonicon, Format
-using JLD2
-using Zygote
-
-using LuxCUDA
-## using AMDGPU # Install and load AMDGPU to train models on AMD GPUs with ROCm
-using MPI: MPI
-## Enables distributed training in Lux. NCCL is needed for CUDA GPUs
-using NCCL: NCCL
-
-const gdev = gpu_device()
-const cdev = cpu_device()
 
 # ## Setup Distributed Training
 
@@ -82,230 +62,9 @@ end
 const is_distributed = total_workers > 1
 const should_log = !is_distributed || local_rank == 0
 
-# ## Data Loading for ImageNet
-
-## We need the data to be in a specific format. See the
-## [README.md](@__REPO_ROOT_URL__/examples/ImageNet/README.md) for more details.
-
-const IMAGENET_CORRUPTED_FILES = [
-    "n01739381_1309.JPEG",
-    "n02077923_14822.JPEG",
-    "n02447366_23489.JPEG",
-    "n02492035_15739.JPEG",
-    "n02747177_10752.JPEG",
-    "n03018349_4028.JPEG",
-    "n03062245_4620.JPEG",
-    "n03347037_9675.JPEG",
-    "n03467068_12171.JPEG",
-    "n03529860_11437.JPEG",
-    "n03544143_17228.JPEG",
-    "n03633091_5218.JPEG",
-    "n03710637_5125.JPEG",
-    "n03961711_5286.JPEG",
-    "n04033995_2932.JPEG",
-    "n04258138_17003.JPEG",
-    "n04264628_27969.JPEG",
-    "n04336792_7448.JPEG",
-    "n04371774_5854.JPEG",
-    "n04596742_4225.JPEG",
-    "n07583066_647.JPEG",
-    "n13037406_4650.JPEG",
-    "n02105855_2933.JPEG",
-    "ILSVRC2012_val_00019877.JPEG",
-]
-
-function load_imagenet1k(base_path::String, split::Symbol)
-    @assert split in (:train, :val)
-    full_path = joinpath(base_path, string(split))
-    synsets = sort(readdir(full_path))
-    @assert length(synsets) == 1000 "There should be 1000 subdirectories in $(full_path)."
-
-    image_files = String[]
-    labels = Int[]
-    for (i, synset) in enumerate(synsets)
-        filenames = readdir(joinpath(full_path, synset))
-        filter!(x -> x ∉ IMAGENET_CORRUPTED_FILES, filenames)
-        paths = joinpath.((full_path,), (synset,), filenames)
-        append!(image_files, paths)
-        append!(labels, repeat([i - 1], length(paths)))
-    end
-
-    return image_files, labels
-end
-
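For reference, a hypothetical invocation of the deleted loader (the base path is illustrative): it returns parallel vectors of image paths and 0-based class labels, skipping the corrupted files listed above.

## Hypothetical usage; "/data/ImageNet" stands in for the real --base-path.
train_files, train_labels = load_imagenet1k("/data/ImageNet", :train)
@assert length(train_files) == length(train_labels)
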
-default_image_size(::Type{Vision.VisionTransformer}, ::Nothing) = 256
-default_image_size(::Type{Vision.VisionTransformer}, size::Int) = size
-default_image_size(_, ::Nothing) = 224
-default_image_size(_, size::Int) = size
-
-struct MakeColoredImage <: DataAugmentation.Transform end
-
-function DataAugmentation.apply(
-    ::MakeColoredImage, item::DataAugmentation.AbstractArrayItem; randstate=nothing
-)
-    data = itemdata(item)
-    (ndims(data) == 2 || size(data, 3) == 1) && (data = cat(data, data, data; dims=Val(3)))
-    return DataAugmentation.setdata(item, data)
-end
-
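A minimal sketch of what the deleted transform does (the ArrayItem wrapper here is illustrative): single-channel data is replicated into three channels so that grayscale ImageNet samples match the 3-channel input the models expect, while RGB arrays pass through unchanged.

using DataAugmentation

gray = DataAugmentation.ArrayItem(rand(Float32, 224, 224))
out = itemdata(DataAugmentation.apply(MakeColoredImage(), gray))
@assert size(out, 3) == 3  ## the 2D array gains a 3-wide channel dimension
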
-struct FileDataset
-    files
-    labels
-    augment
-end
-
-Base.length(dataset::FileDataset) = length(dataset.files)
-
-function Base.getindex(dataset::FileDataset, i::Int)
-    img = Image(FileIO.load(dataset.files[i]))
-    aug_img = itemdata(DataAugmentation.apply(dataset.augment, img))
-    return aug_img, OneHotArrays.onehot(dataset.labels[i], 0:999)
-end
-
-function construct_dataloaders(;
-    base_path::String, train_batchsize, val_batchsize, image_size::Int
-)
-    sensible_println("=> creating dataloaders.")
-
-    train_augment =
-        ScaleFixed((256, 256)) |>
-        Maybe(FlipX(), 0.5) |>
-        RandomResizeCrop((image_size, image_size)) |>
-        PinOrigin() |>
-        ImageToTensor() |>
-        MakeColoredImage() |>
-        ToEltype(Float32) |>
-        Normalize((0.485f0, 0.456f0, 0.406f0), (0.229f0, 0.224f0, 0.225f0))
-    train_files, train_labels = load_imagenet1k(base_path, :train)
-
-    train_dataset = FileDataset(train_files, train_labels, train_augment)
-
-    val_augment =
-        ScaleFixed((image_size, image_size)) |>
-        PinOrigin() |>
-        ImageToTensor() |>
-        MakeColoredImage() |>
-        ToEltype(Float32) |>
-        Normalize((0.485f0, 0.456f0, 0.406f0), (0.229f0, 0.224f0, 0.225f0))
-    val_files, val_labels = load_imagenet1k(base_path, :val)
-
-    val_dataset = FileDataset(val_files, val_labels, val_augment)
-
-    if is_distributed
-        train_dataset = DistributedUtils.DistributedDataContainer(
-            distributed_backend, train_dataset
-        )
-        val_dataset = DistributedUtils.DistributedDataContainer(
-            distributed_backend, val_dataset
-        )
-    end
-
-    train_dataloader = DataLoader(
-        train_dataset;
-        batchsize=train_batchsize ÷ total_workers,
-        partial=false,
-        collate=true,
-        shuffle=true,
-        parallel=true,
-    )
-    val_dataloader = DataLoader(
-        val_dataset;
-        batchsize=val_batchsize ÷ total_workers,
-        partial=true,
-        collate=true,
-        shuffle=false,
-        parallel=true,
-    )
-
-    return gdev(train_dataloader), gdev(val_dataloader)
-end
-
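As a hedged single-process smoke test of the deleted pipeline (paths and batch size are illustrative, and val_augment refers to the composition defined inside construct_dataloaders above): MLUtils.DataLoader with collate=true stacks samples along a trailing batch dimension.

using MLUtils

val_files, val_labels = load_imagenet1k("/data/ImageNet", :val)
dataset = FileDataset(val_files, val_labels, val_augment)
x, y = first(DataLoader(dataset; batchsize=8, collate=true))
## x is an image_size x image_size x 3 x 8 Float32 batch; y is a 1000 x 8 one-hot matrix
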
-# ## Model Construction
-
-function construct_model(;
-    rng::AbstractRNG, model_name::String, model_args, pretrained::Bool=false
-)
-    model = getproperty(Vision, Symbol(model_name))(model_args...; pretrained)
-    ps, st = Lux.setup(rng, model) |> gdev
-
-    sensible_println("=> model `$(model_name)` created.")
-    pretrained && sensible_println("==> using pre-trained model")
-    sensible_println("==> number of trainable parameters: $(Lux.parameterlength(ps))")
-    sensible_println("==> number of states: $(Lux.statelength(st))")
-
-    if is_distributed
-        ps = DistributedUtils.synchronize!!(distributed_backend, ps)
-        st = DistributedUtils.synchronize!!(distributed_backend, st)
-        sensible_println("==> synced model parameters and states across all ranks")
-    end
-
-    return model, ps, st
-end
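
A hypothetical call mirroring the deleted helper's interface (the model name and arguments are illustrative; Vision.ResNet is provided by Boltz):

using Random

model, ps, st = construct_model(;
    rng=Random.default_rng(), model_name="ResNet", model_args=(18,), pretrained=false
)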
 
 # ## Optimizer Configuration
 
-function construct_optimizer_and_scheduler(;
-    kind::String,
-    learning_rate::AbstractFloat,
-    nesterov::Bool,
-    momentum::AbstractFloat,
-    weight_decay::AbstractFloat,
-    scheduler_kind::String,
-    cycle_length::Int,
-    damp_factor::AbstractFloat,
-    lr_step_decay::AbstractFloat,
-    lr_step::Vector{Int},
-)
-    sensible_println("=> creating optimizer.")
-
-    kind = Symbol(kind)
-    optimizer = if kind == :adam
-        Adam(learning_rate)
-    elseif kind == :sgd
-        if nesterov
-            Nesterov(learning_rate, momentum)
-        elseif iszero(momentum)
-            Descent(learning_rate)
-        else
-            Momentum(learning_rate, momentum)
-        end
-    else
-        throw(ArgumentError("Unknown value for `optimizer` = $kind. Supported options are: \
-                             `adam` and `sgd`."))
-    end
-
-    optimizer = if iszero(weight_decay)
-        optimizer
-    else
-        OptimiserChain(optimizer, WeightDecay(weight_decay))
-    end
-
-    sensible_println("=> creating scheduler.")
-
-    scheduler_kind = Symbol(scheduler_kind)
-    scheduler = if scheduler_kind == :cosine
-        l0 = learning_rate
-        l1 = learning_rate / 100
-        ComposedSchedule(
-            CosAnneal(l0, l1, cycle_length), Step(l0, damp_factor, cycle_length)
-        )
-    elseif scheduler_kind == :constant
-        Constant(learning_rate)
-    elseif scheduler_kind == :step
-        Step(learning_rate, lr_step_decay, lr_step)
-    else
-        throw(ArgumentError("Unknown value for `lr_scheduler` = $(scheduler_kind). \
-                             Supported options are: `constant`, `step` and `cosine`."))
-    end
-
-    optimizer = if is_distributed
-        DistributedUtils.DistributedOptimizer(distributed_backend, optimizer)
-    else
-        optimizer
-    end
-
-    return optimizer, scheduler
-end
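
A hedged sketch of how the returned pair would be consumed per epoch: ParameterSchedulers schedules are callable, and Optimisers.adjust! pushes the new learning rate into the optimiser state (the argument values and the loop are illustrative):

opt, sched = construct_optimizer_and_scheduler(;
    kind="sgd", learning_rate=0.1, nesterov=true, momentum=0.9,
    weight_decay=1.0e-4, scheduler_kind="cosine", cycle_length=50,
    damp_factor=1.2, lr_step_decay=0.1, lr_step=[30, 60],
)

opt_state = Optimisers.setup(opt, ps)  ## ps from construct_model
for epoch in 1:90
    Optimisers.adjust!(opt_state, sched(epoch))  ## set this epoch's learning rate
    ## ... train for one epoch with opt_state ...
end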
 
 # ## Utility Functions
 