# 001_aim_v1_600M.yaml
# AIM v1 (600M) pre-training configuration.
# (Original file: 271 lines, 8.75 KB.)
# Shared constants, referenced elsewhere via ${constants.*} interpolation.
constants:
  # ========== VISION ENCODER ARGS ==========
  vision_embed_dim: 1536
  vision_num_blocks: 24
  vision_num_heads: 12 # Assuming head_dim=128
  # ========== MLP HEAD ARGS ================
  head_embed_dim: 2048
  head_num_blocks: 12
  # ========== DATA ==========
  img_size: 224
  batch_size_per_gpu: 256 # global_batch_size = batch_size_per_gpu * WORLD_SIZE
  patch_size: 14
  num_patches: 256 # num_patches = (img_size / patch_size) ** 2
  seed: 0
  prefix_range: [1, 255] # We sample the prefix_len uniformly from this range
  # ========== TRANSFORMS ==========
  train_transform:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.RandomResizedCrop
        size: ${constants.img_size}
        scale: [0.4, 1.0]
        interpolation: 3 # bicubic
      - _target_: torchvision.transforms.RandomHorizontalFlip
      - _target_: torchvision.transforms.RandomApply
        p: 0.8
        transforms:
          - _target_: torchvision.transforms.ColorJitter
            brightness: 0.4
            contrast: 0.4
            saturation: 0.2
            hue: 0.1
      - _target_: torchvision.transforms.ToTensor
      - _target_: torchvision.transforms.Normalize
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
  eval_transform:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.Resize
        size: 256 # crop ratio = 0.875
        interpolation: 3 # bicubic
      - _target_: torchvision.transforms.CenterCrop
        size: ${constants.img_size}
      - _target_: torchvision.transforms.ToTensor
      - _target_: torchvision.transforms.Normalize
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
  # ========== GENERATORS ==========
  train_generators:
    # Samples a `prefix_len` uniformly from `prefix_range`. The `prefix_len` is different per sample.
    - _target_: l3m.helpers.vision.mask_generator.RandomRasterMasking
      write_key: encoder_prefix_mask
      num_patches: ${constants.num_patches}
      prefix_range: ${constants.prefix_range}
  eval_generators:
    - _target_: l3m.helpers.vision.mask_generator.RandomRasterMasking
      write_key: encoder_prefix_mask
      num_patches: ${constants.num_patches}
      force_full_attn: true # no prefix sampling at eval: attend over all patches
# ==============================================================
# DATA
# ==============================================================
data:
  data_set:
    - imagenet
  data_path: null # NOTE(review): empty in original; split roots are set explicitly below — confirm consumer tolerates null
  nb_classes: 1000
  # ========== TRAIN ==========
  train:
    dataset:
      _target_: l3m.data.dataset_with_generator.DatasetWithGenerator
      generators: ${constants.train_generators}
      base_dataset:
        _target_: l3m.data.vision.datasets.image_folder.CustomImageFolder
        root: /mnt/data/imagenet/train
        transform: ${constants.train_transform}
    dataloader:
      _target_: torch.utils.data.DataLoader
      _partial_: true
      batch_size: ${constants.batch_size_per_gpu}
      num_workers: 10
      pin_memory: true
      drop_last: true
  # ========== VALIDATION ==========
  validation:
    dataset:
      imagenet:
        _target_: l3m.data.dataset_with_generator.DatasetWithGenerator
        generators: ${constants.eval_generators}
        base_dataset:
          _target_: l3m.data.vision.datasets.image_folder.CustomImageFolder
          root: /mnt/data/imagenet/val
          transform: ${constants.eval_transform}
    dataloader:
      _target_: torch.utils.data.DataLoader
      _partial_: true
      batch_size: ${constants.batch_size_per_gpu}
      num_workers: 10
      pin_memory: true
      drop_last: true
# ==============================================================
# OPTIM
# ==============================================================
optim:
  # `batch_size_per_gpu` is chunked into `gradient_accumulation_steps` chunks, and gradients are accumulated over these chunks.
  # NOTE: `gradient_accumulation_steps` can be changed **independently** of the batch-size since we are chunking.
  gradient_accumulation_steps: 1
  grad_clip: 1.0
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    # NOTE(review): plain PyYAML parses `1e-3` as a *string*; OmegaConf's loader resolves it
    # as a float — confirm the loader in use, or write 0.001 to be parser-agnostic.
    lr: 1e-3
    betas:
      - 0.9
      - 0.95
    eps: 1e-08
    weight_decay: 0.05
    fused: true # Can be significantly faster
  # Parameter-name glob patterns excluded from weight decay.
  wd_exclude:
    - '*bias*'
    - '*pos_embed*'
    - '*norm*'
    - '*gamma*'
  scheduler:
    # Linear warmup over the first 2.5% of training, then cosine decay to 0.
    _target_: fvcore.common.param_scheduler.CompositeParamScheduler
    schedulers:
      - _target_: fvcore.common.param_scheduler.LinearParamScheduler
        start_value: 1e-6
        end_value: ${optim.optimizer.lr}
      - _target_: fvcore.common.param_scheduler.CosineParamScheduler
        start_value: ${optim.optimizer.lr}
        end_value: 0.0
    interval_scaling:
      - rescaled
      - rescaled
    lengths:
      - 0.025
      - 0.975
# ==============================================================
# MODEL
# ==============================================================
model:
  checkpoint: null # NOTE(review): empty in original; presumably an optional pretrained-checkpoint path — confirm
  meta_model:
    _target_: l3m.model.meta_models.MetaModel
    preprocessor:
      _target_: l3m.model.utils.MultiBlock
      blocks:
        - _target_: l3m.model.preprocessors.vision.ViTPreprocessor
          read_key: image
          write_key: image_tokens
          patchifier:
            _target_: l3m.model.preprocessors.vision.PatchEmbed
            img_size: ${constants.img_size}
            patch_size: ${constants.patch_size}
            in_chans: 3
            embed_dim: ${constants.vision_embed_dim}
            norm_layer:
              _target_: l3m.model.layers.normalization.LayerNormFP32
              _partial_: true
              eps: 1e-5
          pos_embed_type: sincos
          drop_patches: false
          cls_token: false
        # Build flex-attn `BlockMask` based on the `prefix_len` generated by `RandomRasterMasking`
        - _target_: l3m.model.preprocessors.mask_builders.FlexAttnPrefixAttentionMaskBuilder
          read_key: encoder_prefix_mask
          write_key: encoder_prefix_attn_mask
    trunk:
      _target_: l3m.model.trunks.transformer.Transformer
      read_key: image_tokens
      write_key: image_tokens
      self_attn_mask_read_key: encoder_prefix_attn_mask
      embed_dim: ${constants.vision_embed_dim}
      num_blocks: ${constants.vision_num_blocks}
      mlp_ratio: 4
      norm_layer:
        _target_: l3m.model.layers.normalization.LayerNormFP32
        _partial_: true
        eps: 1e-5
      attn_target:
        _target_: l3m.model.layers.attention.GenericAttention
        _partial_: true
        dim: ${constants.vision_embed_dim}
        num_heads: ${constants.vision_num_heads}
        qkv_bias: false
        use_flex_attention: true # Needed to pre-compile flex-attention
      weight_init_style: xavier_uniform
      post_trunk_norm: true
      use_bias: false
    postprocessor:
      _target_: torch.nn.Identity
    head:
      _target_: l3m.model.heads.simple_decoder.SimpleDecoder
      read_key: image_tokens
      write_key: image_preds
      attn_target:
        _target_: l3m.model.layers.misc.GenericIdentity # No Attention - MLP only
        _partial_: true
      input_dim: ${constants.vision_embed_dim}
      embed_dim: ${constants.head_embed_dim}
      output_dim: 588 # patch_size**2 * in_chans = 14 * 14 * 3 pixel values per patch
      num_layers: ${constants.head_num_blocks}
      mlp_ratio: 4
      norm_layer:
        _target_: l3m.model.layers.normalization.LayerNormFP32
        _partial_: true
        eps: 1e-5
      use_bias: false
      patch_size: ${constants.patch_size}
      image_size: ${constants.img_size}
# ==============================================================
# LOSS
# ==============================================================
loss:
  _target_: l3m.loss.mae_loss.RasterPixelLoss
  read_key: image_preds
  norm_pix_loss: true
  mask_read_key: encoder_prefix_mask
  patch_size: ${constants.patch_size}
# Same loss configuration, applied to the validation split.
# (Deliberately duplicated rather than anchored — YAML anchors break some Hydra tooling.)
val_loss:
  _target_: l3m.loss.mae_loss.RasterPixelLoss
  read_key: image_preds
  norm_pix_loss: true
  mask_read_key: encoder_prefix_mask
  patch_size: ${constants.patch_size}
# ==============================================================
# EXPERIMENT
# ==============================================================
# NOTE(review): indentation was lost in this copy; nesting below (in particular
# `fsdp` under `experiment`, `wandb` at top level) is reconstructed — confirm
# against the original file.
experiment:
  start_iteration: 0
  # Underscore-separated ints (125_000) are YAML 1.1/PyYAML-only; written plainly
  # here so a strict YAML 1.2 parser reads the same integers.
  total_iterations: 125000
  ckpt_save_freq: 25000
  test_frequency: 25000
  torch_compile: true
  dtype: bfloat16
  output_dir: null # NOTE(review): empty in original; presumably populated by the launcher — confirm
  device: cuda
  find_unused_parameters: false
  seed: ${constants.seed}
  dist_eval: true
  distributed: false # will be automatically enabled
  world_size: 1 # will be automatically updated
  dist_url: env://
  eval: false
  resume: null # Should be populated with a ckpt path on job resubmission
  amp_enabled: true
  no_sync_gradient_accumulation: false
  fsdp:
    sharding_strategy: NO_SHARD # FSDP wrapper without parameter sharding (DDP-equivalent)
    param_dtype: bf16
    reduce_dtype: fp32
    buffer_dtype: fp32
    fsdp_activation_checkpointing: true
    nccl_timeout_mins: 120
    activation_checkpoint_mode: full
    fsdp_layers_to_wrap:
      - Block
    activation_checkpoint_template:
      - trunk.blocks
    fsdp_ignored_modules:
      - ''
wandb:
  use_wandb: true
  watch_freq: 500
  project: l3m
  tags: [aimv1]