Commit 2edefcb

Add annealing to main training file (#72)
1 parent 225dd20 commit 2edefcb

File tree

3 files changed: +62 additions, -16 deletions


app/vjepa/train.py

Lines changed: 21 additions & 6 deletions
@@ -129,6 +129,11 @@ def main(args, resume_preempt=False):
 
     # -- OPTIMIZATION
     cfgs_opt = args.get("optimization")
+    is_anneal = cfgs_opt.get("is_anneal", False)
+    anneal_ckpt = cfgs_opt.get("anneal_ckpt", None)
+    if is_anneal and anneal_ckpt is None:
+        raise ValueError("Must specify anneal_ckpt if is_anneal is True")
+    resume_anneal = cfgs_opt.get("resume_anneal", False) or (is_anneal and resume_preempt)
     ipe = cfgs_opt.get("ipe", None)
     ipe_scale = cfgs_opt.get("ipe_scale", 1.0)
     wd = float(cfgs_opt.get("weight_decay"))
@@ -169,7 +174,14 @@ def main(args, resume_preempt=False):
     latest_path = os.path.join(folder, latest_file)
     load_path = None
     if load_model:
-        load_path = os.path.join(folder, r_file) if r_file is not None else latest_path
+        if is_anneal:
+            if os.path.exists(latest_path) and resume_anneal:
+                load_path = latest_path
+            else:
+                load_path = anneal_ckpt
+                resume_anneal = False
+        else:
+            load_path = r_file if r_file is not None else latest_path
         if not os.path.exists(load_path):
             load_path = None
             load_model = False
@@ -261,6 +273,7 @@ def main(args, resume_preempt=False):
 
     # -- init optimizer and scheduler
     optimizer, scaler, scheduler, wd_scheduler = init_opt(
+        is_anneal=is_anneal,
         encoder=encoder,
         predictor=predictor,
         wd=wd,
@@ -305,12 +318,14 @@ def main(args, resume_preempt=False):
             target_encoder=target_encoder,
             opt=optimizer,
             scaler=scaler,
+            is_anneal=is_anneal and not resume_anneal,
         )
-        for _ in range(start_epoch * ipe):
-            scheduler.step()
-            wd_scheduler.step()
-            next(momentum_scheduler)
-            mask_collator.step()
+        if not is_anneal or resume_anneal:
+            for _ in range(start_epoch * ipe):
+                scheduler.step()
+                wd_scheduler.step()
+                next(momentum_scheduler)
+                mask_collator.step()
 
     def save_checkpoint(epoch, path):
         if rank != 0:
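
A minimal sketch of the "optimization" config keys this change reads, assuming a plain Python dict is passed as args to main(); only the three annealing keys are introduced by this commit, and the remaining keys, values, and checkpoint path are illustrative, not taken from the commit.

# Illustrative "optimization" section of the training config.
args = {
    "optimization": {
        "is_anneal": True,          # switch init_opt to the linear-decay (annealing) schedule
        "anneal_ckpt": "/checkpoints/pretrain-latest.pth.tar",  # hypothetical checkpoint to anneal from
        "resume_anneal": False,     # True resumes an interrupted annealing run from latest_path
        "ipe": 300,                 # iterations per epoch (illustrative value)
        "ipe_scale": 1.0,
        "weight_decay": 0.04,       # illustrative value
        # ... remaining pre-existing optimization keys unchanged
    },
}

With this layout, an annealing run loads anneal_ckpt on the first launch and, if interrupted, picks up from the locally saved latest checkpoint when resume_anneal (or a preemption resume) is in effect.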

app/vjepa/utils.py

Lines changed: 22 additions & 10 deletions
@@ -13,7 +13,7 @@
 import src.models.predictor as vit_pred
 import src.models.vision_transformer as video_vit
 from src.utils.checkpoint_loader import robust_checkpoint_loader
-from src.utils.schedulers import CosineWDSchedule, WarmupCosineSchedule
+from src.utils.schedulers import CosineWDSchedule, LinearDecaySchedule, WarmupCosineSchedule
 from src.utils.wrappers import MultiSeqWrapper, PredictorMultiSeqWrapper
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -94,11 +94,14 @@ def load_checkpoint(
     target_encoder,
     opt,
     scaler,
+    is_anneal=False,
 ):
     logger.info(f"Loading checkpoint from {r_path}")
     checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))
 
-    epoch = checkpoint["epoch"]
+    epoch = 0
+    if not is_anneal:
+        epoch = checkpoint["epoch"]
 
     # -- loading encoder
     pretrained_dict = checkpoint["encoder"]
@@ -205,6 +208,7 @@ def count_parameters(model):
 
 
 def init_opt(
+    is_anneal,
     encoder,
     predictor,
     iterations_per_epoch,
@@ -237,14 +241,22 @@
     ]
 
     optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
-    scheduler = WarmupCosineSchedule(
-        optimizer,
-        warmup_steps=int(warmup * iterations_per_epoch),
-        start_lr=start_lr,
-        ref_lr=ref_lr,
-        final_lr=final_lr,
-        T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
-    )
+    if not is_anneal:
+        scheduler = WarmupCosineSchedule(
+            optimizer,
+            warmup_steps=int(warmup * iterations_per_epoch),
+            start_lr=start_lr,
+            ref_lr=ref_lr,
+            final_lr=final_lr,
+            T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
+        )
+    else:
+        scheduler = LinearDecaySchedule(
+            optimizer,
+            ref_lr=ref_lr,
+            final_lr=final_lr,
+            T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
+        )
     wd_scheduler = CosineWDSchedule(
         optimizer,
         ref_wd=wd,
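
A standalone sketch of the epoch handling above: when annealing from a pretrained checkpoint, the epoch counter stored in the checkpoint is ignored so the annealing run starts at epoch 0, while normal resumes keep it. resolve_start_epoch is a hypothetical helper written here for illustration only.

def resolve_start_epoch(checkpoint: dict, is_anneal: bool) -> int:
    # Annealing starts a fresh run from the pretrained weights, so the stored
    # epoch is only honored when not annealing (i.e., a normal resume).
    return 0 if is_anneal else checkpoint["epoch"]

print(resolve_start_epoch({"epoch": 200}, is_anneal=True))   # -> 0
print(resolve_start_epoch({"epoch": 200}, is_anneal=False))  # -> 200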

src/utils/schedulers.py

Lines changed: 19 additions & 0 deletions
@@ -91,3 +91,22 @@ def step(self):
             if ("WD_exclude" not in group) or not group["WD_exclude"]:
                 group["weight_decay"] = new_wd
         return new_wd
+
+
+class LinearDecaySchedule(object):
+
+    def __init__(self, optimizer, ref_lr, T_max, last_epoch=-1, final_lr=0.0):
+        self.optimizer = optimizer
+        self.ref_lr = ref_lr
+        self.final_lr = final_lr
+        self.T_max = T_max
+        self._step = 0.0
+
+    def step(self):
+        self._step += 1
+        progress = float(self._step) / float(max(1, self.T_max))
+        new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr)
+        for group in self.optimizer.param_groups:
+            group["lr"] = new_lr
+
+        return new_lr
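
A minimal usage sketch of the new LinearDecaySchedule on a throwaway optimizer, showing the learning rate ramping linearly from ref_lr down to final_lr over T_max steps; the model and hyperparameter values are illustrative.

import torch

from src.utils.schedulers import LinearDecaySchedule

# Throwaway model/optimizer just to exercise the schedule.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)

scheduler = LinearDecaySchedule(optimizer, ref_lr=1.0e-4, final_lr=1.0e-6, T_max=10)

for step in range(1, 11):
    lr = scheduler.step()
    print(f"step {step:2d}: lr = {lr:.2e}")
# step  1: lr = 9.01e-05
# ...
# step 10: lr = 1.00e-06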
