|
34 | 34 | # --base-path="/home/avik-pal/data/ImageNet/" |
35 | 35 | # ``` |
36 | 36 |
|
37 | | -# ## Package Imports |
38 | | - |
39 | | -using Boltz, Lux, MLDataDevices |
40 | | -## import Metalhead # Install and load this package to use the Metalhead models with Lux |
41 | | - |
42 | | -using Dates, Random |
43 | | -using DataAugmentation, |
44 | | - FileIO, MLUtils, OneHotArrays, Optimisers, ParameterSchedulers, Setfield |
45 | | -using Comonicon, Format |
46 | | -using JLD2 |
47 | | -using Zygote |
48 | | - |
49 | | -using LuxCUDA |
50 | | -## using AMDGPU # Install and load AMDGPU to train models on AMD GPUs with ROCm |
51 | | -using MPI: MPI |
52 | | -## Enables distributed training in Lux. NCCL is needed for CUDA GPUs |
53 | | -using NCCL: NCCL |
54 | | - |
55 | | -const gdev = gpu_device() |
56 | | -const cdev = cpu_device() |
57 | 37 |
|
58 | 38 | # ## Setup Distributed Training |
59 | 39 |
|
|
82 | 62 | const is_distributed = total_workers > 1 |
83 | 63 | const should_log = !is_distributed || local_rank == 0 |
84 | 64 |
|
85 | | -# ## Data Loading for ImageNet |
86 | | - |
87 | | -## We need the data to be in a specific format. See the |
88 | | -## [README.md](@__REPO_ROOT_URL__/examples/ImageNet/README.md) for more details. |
89 | | - |
90 | | -const IMAGENET_CORRUPTED_FILES = [ |
91 | | - "n01739381_1309.JPEG", |
92 | | - "n02077923_14822.JPEG", |
93 | | - "n02447366_23489.JPEG", |
94 | | - "n02492035_15739.JPEG", |
95 | | - "n02747177_10752.JPEG", |
96 | | - "n03018349_4028.JPEG", |
97 | | - "n03062245_4620.JPEG", |
98 | | - "n03347037_9675.JPEG", |
99 | | - "n03467068_12171.JPEG", |
100 | | - "n03529860_11437.JPEG", |
101 | | - "n03544143_17228.JPEG", |
102 | | - "n03633091_5218.JPEG", |
103 | | - "n03710637_5125.JPEG", |
104 | | - "n03961711_5286.JPEG", |
105 | | - "n04033995_2932.JPEG", |
106 | | - "n04258138_17003.JPEG", |
107 | | - "n04264628_27969.JPEG", |
108 | | - "n04336792_7448.JPEG", |
109 | | - "n04371774_5854.JPEG", |
110 | | - "n04596742_4225.JPEG", |
111 | | - "n07583066_647.JPEG", |
112 | | - "n13037406_4650.JPEG", |
113 | | - "n02105855_2933.JPEG", |
114 | | - "ILSVRC2012_val_00019877.JPEG", |
115 | | -] |
116 | | - |
117 | | -function load_imagenet1k(base_path::String, split::Symbol) |
118 | | - @assert split in (:train, :val) |
119 | | - full_path = joinpath(base_path, string(split)) |
120 | | - synsets = sort(readdir(full_path)) |
121 | | - @assert length(synsets) == 1000 "There should be 1000 subdirectories in $(full_path)." |
122 | | - |
123 | | - image_files = String[] |
124 | | - labels = Int[] |
125 | | - for (i, synset) in enumerate(synsets) |
126 | | - filenames = readdir(joinpath(full_path, synset)) |
127 | | - filter!(x -> x ∉ IMAGENET_CORRUPTED_FILES, filenames) |
128 | | - paths = joinpath.((full_path,), (synset,), filenames) |
129 | | - append!(image_files, paths) |
130 | | - append!(labels, repeat([i - 1], length(paths))) |
131 | | - end |
132 | | - |
133 | | - return image_files, labels |
134 | | -end |
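For reference, a minimal sketch of calling this loader; the base path is the placeholder from the header above, and the exact counts depend on your local copy:

```julia
train_files, train_labels = load_imagenet1k("/home/avik-pal/data/ImageNet/", :train)
# Parallel vectors: one absolute JPEG path and one 0-based class index per image
@assert length(train_files) == length(train_labels)
first(train_labels)  # 0, i.e. the first synset directory in sorted order
```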
135 | | - |
136 | | -default_image_size(::Type{Vision.VisionTransformer}, ::Nothing) = 256 |
137 | | -default_image_size(::Type{Vision.VisionTransformer}, size::Int) = size |
138 | | -default_image_size(_, ::Nothing) = 224 |
139 | | -default_image_size(_, size::Int) = size |
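These methods pick the input resolution by dispatch: an explicit `size` always wins, Vision Transformers default to 256, and everything else falls back to 224. A quick sketch (using `Vision.ResNet` as a representative non-ViT model):

```julia
default_image_size(Vision.VisionTransformer, nothing)  # 256
default_image_size(Vision.VisionTransformer, 384)      # 384: explicit size wins
default_image_size(Vision.ResNet, nothing)             # 224: generic fallback
```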
140 | | - |
141 | | -struct MakeColoredImage <: DataAugmentation.Transform end |
142 | | - |
143 | | -function DataAugmentation.apply( |
144 | | - ::MakeColoredImage, item::DataAugmentation.AbstractArrayItem; randstate=nothing |
145 | | -) |
146 | | - data = itemdata(item) |
147 | | - (ndims(data) == 2 || size(data, 3) == 1) && (data = cat(data, data, data; dims=Val(3))) |
148 | | - return DataAugmentation.setdata(item, data) |
149 | | -end |
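This transform replicates a single channel so that grayscale JPEGs (ImageNet contains a few) come out as 3-channel arrays. A minimal sketch, assuming `DataAugmentation.ArrayItem` wraps a raw array as an `AbstractArrayItem`:

```julia
gray = DataAugmentation.ArrayItem(rand(Float32, 224, 224))
rgb  = DataAugmentation.apply(MakeColoredImage(), gray)
size(itemdata(rgb))  # (224, 224, 3): the single channel is repeated three times
```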
150 | | - |
151 | | -struct FileDataset |
152 | | - files |
153 | | - labels |
154 | | - augment |
155 | | -end |
156 | | - |
157 | | -Base.length(dataset::FileDataset) = length(dataset.files) |
158 | | - |
159 | | -function Base.getindex(dataset::FileDataset, i::Int) |
160 | | - img = Image(FileIO.load(dataset.files[i])) |
161 | | - aug_img = itemdata(DataAugmentation.apply(dataset.augment, img)) |
162 | | - return aug_img, OneHotArrays.onehot(dataset.labels[i], 0:999) |
163 | | -end |
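A `FileDataset` is a lazy map-style dataset: images are read from disk and augmented only when indexed. A hypothetical sketch, with `train_files`, `train_labels`, and `train_augment` as constructed in `construct_dataloaders` below:

```julia
ds = FileDataset(train_files, train_labels, train_augment)
length(ds)      # number of (non-corrupted) images
img, y = ds[1]  # augmented image array and a one-hot label over 0:999
```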
164 | | - |
165 | | -function construct_dataloaders(; |
166 | | - base_path::String, train_batchsize, val_batchsize, image_size::Int |
167 | | -) |
168 | | - sensible_println("=> creating dataloaders.") |
169 | | - |
170 | | - train_augment = |
171 | | - ScaleFixed((256, 256)) |> |
172 | | - Maybe(FlipX(), 0.5) |> |
173 | | - RandomResizeCrop((image_size, image_size)) |> |
174 | | - PinOrigin() |> |
175 | | - ImageToTensor() |> |
176 | | - MakeColoredImage() |> |
177 | | - ToEltype(Float32) |> |
178 | | - Normalize((0.485f0, 0.456f0, 0.406f0), (0.229f0, 0.224f0, 0.225f0)) |
179 | | - train_files, train_labels = load_imagenet1k(base_path, :train) |
180 | | - |
181 | | - train_dataset = FileDataset(train_files, train_labels, train_augment) |
182 | | - |
183 | | - val_augment = |
184 | | - ScaleFixed((image_size, image_size)) |> |
185 | | - PinOrigin() |> |
186 | | - ImageToTensor() |> |
187 | | - MakeColoredImage() |> |
188 | | - ToEltype(Float32) |> |
189 | | - Normalize((0.485f0, 0.456f0, 0.406f0), (0.229f0, 0.224f0, 0.225f0)) |
190 | | - val_files, val_labels = load_imagenet1k(base_path, :val) |
191 | | - |
192 | | - val_dataset = FileDataset(val_files, val_labels, val_augment) |
193 | | - |
194 | | - if is_distributed |
195 | | - train_dataset = DistributedUtils.DistributedDataContainer( |
196 | | - distributed_backend, train_dataset |
197 | | - ) |
198 | | - val_dataset = DistributedUtils.DistributedDataContainer( |
199 | | - distributed_backend, val_dataset |
200 | | - ) |
201 | | - end |
202 | | - |
203 | | - train_dataloader = DataLoader( |
204 | | - train_dataset; |
205 | | - batchsize=train_batchsize ÷ total_workers, |
206 | | - partial=false, |
207 | | - collate=true, |
208 | | - shuffle=true, |
209 | | - parallel=true, |
210 | | - ) |
211 | | - val_dataloader = DataLoader( |
212 | | - val_dataset; |
213 | | - batchsize=val_batchsize ÷ total_workers, |
214 | | - partial=true, |
215 | | - collate=true, |
216 | | - shuffle=false, |
217 | | - parallel=true, |
218 | | - ) |
219 | | - |
220 | | - return gdev(train_dataloader), gdev(val_dataloader) |
221 | | -end |
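A sketch of a typical call; the batch sizes here are illustrative, not values prescribed by this example:

```julia
train_loader, val_loader = construct_dataloaders(;
    base_path="/home/avik-pal/data/ImageNet/",
    train_batchsize=256,
    val_batchsize=256,
    image_size=224,
)
# Each iteration yields a batch already moved to `gdev`: an image array of size
# (224, 224, 3, 256 ÷ total_workers) and a one-hot label array of size (1000, ...)
```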
222 | | - |
223 | | -# ## Model Construction |
224 | | - |
225 | | -function construct_model(; |
226 | | - rng::AbstractRNG, model_name::String, model_args, pretrained::Bool=false |
227 | | -) |
228 | | - model = getproperty(Vision, Symbol(model_name))(model_args...; pretrained) |
229 | | - ps, st = Lux.setup(rng, model) |> gdev |
230 | | - |
231 | | - sensible_println("=> model `$(model_name)` created.") |
232 | | -    pretrained && sensible_println("==> using pre-trained model")
233 | | - sensible_println("==> number of trainable parameters: $(Lux.parameterlength(ps))") |
234 | | - sensible_println("==> number of states: $(Lux.statelength(st))") |
235 | | - |
236 | | - if is_distributed |
237 | | - ps = DistributedUtils.synchronize!!(distributed_backend, ps) |
238 | | - st = DistributedUtils.synchronize!!(distributed_backend, st) |
239 | | - sensible_println("==> synced model parameters and states across all ranks") |
240 | | - end |
241 | | - |
242 | | - return model, ps, st |
243 | | -end |
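A hypothetical invocation; `"VGG"` with depth 19 is just an illustrative choice of a `Boltz.Vision` model:

```julia
model, ps, st = construct_model(;
    rng=Random.default_rng(), model_name="VGG", model_args=(19,), pretrained=false
)
```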
244 | 65 |
|
245 | 66 | # ## Optimizer Configuration |
246 | 67 |
|
247 | | -function construct_optimizer_and_scheduler(; |
248 | | - kind::String, |
249 | | - learning_rate::AbstractFloat, |
250 | | - nesterov::Bool, |
251 | | - momentum::AbstractFloat, |
252 | | - weight_decay::AbstractFloat, |
253 | | - scheduler_kind::String, |
254 | | - cycle_length::Int, |
255 | | - damp_factor::AbstractFloat, |
256 | | - lr_step_decay::AbstractFloat, |
257 | | - lr_step::Vector{Int}, |
258 | | -) |
259 | | - sensible_println("=> creating optimizer.") |
260 | | - |
261 | | - kind = Symbol(kind) |
262 | | - optimizer = if kind == :adam |
263 | | - Adam(learning_rate) |
264 | | - elseif kind == :sgd |
265 | | - if nesterov |
266 | | - Nesterov(learning_rate, momentum) |
267 | | - elseif iszero(momentum) |
268 | | - Descent(learning_rate) |
269 | | - else |
270 | | - Momentum(learning_rate, momentum) |
271 | | - end |
272 | | - else |
273 | | - throw(ArgumentError("Unknown value for `optimizer` = $kind. Supported options are: \ |
274 | | - `adam` and `sgd`.")) |
275 | | - end |
276 | | - |
277 | | - optimizer = if iszero(weight_decay) |
278 | | - optimizer |
279 | | - else |
280 | | - OptimiserChain(optimizer, WeightDecay(weight_decay)) |
281 | | - end |
282 | | - |
283 | | - sensible_println("=> creating scheduler.") |
284 | | - |
285 | | - scheduler_kind = Symbol(scheduler_kind) |
286 | | - scheduler = if scheduler_kind == :cosine |
287 | | - l0 = learning_rate |
288 | | - l1 = learning_rate / 100 |
289 | | - ComposedSchedule( |
290 | | - CosAnneal(l0, l1, cycle_length), Step(l0, damp_factor, cycle_length) |
291 | | - ) |
292 | | - elseif scheduler_kind == :constant |
293 | | - Constant(learning_rate) |
294 | | - elseif scheduler_kind == :step |
295 | | - Step(learning_rate, lr_step_decay, lr_step) |
296 | | - else |
297 | | - throw(ArgumentError("Unknown value for `lr_scheduler` = $(scheduler_kind). \ |
298 | | - Supported options are: `constant`, `step` and `cosine`.")) |
299 | | - end |
300 | | - |
301 | | - optimizer = if is_distributed |
302 | | - DistributedUtils.DistributedOptimizer(distributed_backend, optimizer) |
303 | | - else |
304 | | - optimizer |
305 | | - end |
306 | | - |
307 | | - return optimizer, scheduler |
308 | | -end |
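A sketch of one plausible configuration: SGD with momentum, weight decay, and a step schedule that decays the learning rate by 10x at epochs 30, 60, and 80. All hyperparameter values here are illustrative:

```julia
opt, sched = construct_optimizer_and_scheduler(;
    kind="sgd",
    learning_rate=0.1,
    nesterov=false,
    momentum=0.9,
    weight_decay=1.0e-4,
    scheduler_kind="step",
    cycle_length=90,   # only used by the cosine scheduler, but required
    damp_factor=1.0,   # likewise only relevant for the cosine scheduler
    lr_step_decay=0.1,
    lr_step=[30, 60, 80],
)
sched(1)  # ParameterSchedulers schedules are callable: learning rate at epoch 1
```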
309 | 68 |
|
310 | 69 | # ## Utility Functions |
311 | 70 |
|
|