27 changes: 21 additions & 6 deletions app/vjepa/train.py
@@ -129,6 +129,11 @@ def main(args, resume_preempt=False):

# -- OPTIMIZATION
cfgs_opt = args.get("optimization")
is_anneal = cfgs_opt.get("is_anneal", False)
anneal_ckpt = cfgs_opt.get("anneal_ckpt", None)
if is_anneal and anneal_ckpt is None:
raise ValueError("Must specify anneal_ckpt if is_anneal is True")
resume_anneal = cfgs_opt.get("resume_anneal", False) or (is_anneal and resume_preempt)
ipe = cfgs_opt.get("ipe", None)
ipe_scale = cfgs_opt.get("ipe_scale", 1.0)
wd = float(cfgs_opt.get("weight_decay"))
@@ -169,7 +174,14 @@ def main(args, resume_preempt=False):
latest_path = os.path.join(folder, latest_file)
load_path = None
if load_model:
load_path = os.path.join(folder, r_file) if r_file is not None else latest_path
if is_anneal:
if os.path.exists(latest_path) and resume_anneal:
load_path = latest_path
else:
load_path = anneal_ckpt
resume_anneal = False
else:
load_path = r_file if r_file is not None else latest_path
if not os.path.exists(load_path):
load_path = None
load_model = False
@@ -261,6 +273,7 @@ def main(args, resume_preempt=False):

# -- init optimizer and scheduler
optimizer, scaler, scheduler, wd_scheduler = init_opt(
is_anneal=is_anneal,
encoder=encoder,
predictor=predictor,
wd=wd,
@@ -305,12 +318,14 @@ def main(args, resume_preempt=False):
target_encoder=target_encoder,
opt=optimizer,
scaler=scaler,
is_anneal=is_anneal and not resume_anneal,
)
for _ in range(start_epoch * ipe):
scheduler.step()
wd_scheduler.step()
next(momentum_scheduler)
mask_collator.step()
if not is_anneal or resume_anneal:
for _ in range(start_epoch * ipe):
scheduler.step()
wd_scheduler.step()
next(momentum_scheduler)
mask_collator.step()

def save_checkpoint(epoch, path):
if rank != 0:
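The annealing switch is driven by three new keys read from the `optimization` section of the run config. Below is a minimal sketch of that section, assuming a plain dict-like config as consumed by `args.get("optimization")`; the checkpoint path and the surrounding keys are illustrative placeholders, only the three annealing keys come from this diff.

    # Hypothetical "optimization" section as read in main() above; only the three
    # annealing keys are introduced by this change.
    cfgs_opt = {
        "is_anneal": True,                        # switch main()/init_opt() into the anneal path
        "anneal_ckpt": "/path/to/checkpoint.pt",  # weights to anneal from; required when is_anneal is True
        "resume_anneal": False,                   # True (or a preemption resume) continues an interrupted anneal
        # ... existing keys such as ipe, ipe_scale, weight_decay stay as before
    }

With `is_anneal` set and `resume_anneal` unset, training loads `anneal_ckpt`, resets the epoch counter to zero, and skips fast-forwarding the LR/WD/momentum schedulers and mask collator; a preempted anneal run instead resumes from the latest checkpoint in the run folder.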
32 changes: 22 additions & 10 deletions app/vjepa/utils.py
@@ -13,7 +13,7 @@
import src.models.predictor as vit_pred
import src.models.vision_transformer as video_vit
from src.utils.checkpoint_loader import robust_checkpoint_loader
from src.utils.schedulers import CosineWDSchedule, WarmupCosineSchedule
from src.utils.schedulers import CosineWDSchedule, LinearDecaySchedule, WarmupCosineSchedule
from src.utils.wrappers import MultiSeqWrapper, PredictorMultiSeqWrapper

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -94,11 +94,14 @@ def load_checkpoint(
target_encoder,
opt,
scaler,
is_anneal=False,
):
logger.info(f"Loading checkpoint from {r_path}")
checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))

epoch = checkpoint["epoch"]
epoch = 0
if not is_anneal:
epoch = checkpoint["epoch"]

# -- loading encoder
pretrained_dict = checkpoint["encoder"]
@@ -205,6 +208,7 @@ def count_parameters(model):


def init_opt(
is_anneal,
encoder,
predictor,
iterations_per_epoch,
@@ -237,14 +241,22 @@
]

optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
scheduler = WarmupCosineSchedule(
optimizer,
warmup_steps=int(warmup * iterations_per_epoch),
start_lr=start_lr,
ref_lr=ref_lr,
final_lr=final_lr,
T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
)
if not is_anneal:
scheduler = WarmupCosineSchedule(
optimizer,
warmup_steps=int(warmup * iterations_per_epoch),
start_lr=start_lr,
ref_lr=ref_lr,
final_lr=final_lr,
T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
)
else:
scheduler = LinearDecaySchedule(
optimizer,
ref_lr=ref_lr,
final_lr=final_lr,
T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
)
wd_scheduler = CosineWDSchedule(
optimizer,
ref_wd=wd,
19 changes: 19 additions & 0 deletions src/utils/schedulers.py
@@ -91,3 +91,22 @@ def step(self):
if ("WD_exclude" not in group) or not group["WD_exclude"]:
group["weight_decay"] = new_wd
return new_wd


class LinearDecaySchedule(object):

def __init__(self, optimizer, ref_lr, T_max, last_epoch=-1, final_lr=0.0):
self.optimizer = optimizer
self.ref_lr = ref_lr
self.final_lr = final_lr
self.T_max = T_max
self._step = 0.0

def step(self):
self._step += 1
progress = float(self._step) / float(max(1, self.T_max))
new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr)
for group in self.optimizer.param_groups:
group["lr"] = new_lr

return new_lr
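
For reference, a minimal self-contained sketch of how `LinearDecaySchedule` behaves, using a dummy parameter and illustrative hyperparameters: the learning rate moves linearly from `ref_lr` to `final_lr` over `T_max` calls to `step()`, with no warmup phase.

    import torch

    from src.utils.schedulers import LinearDecaySchedule

    # Dummy parameter and optimizer, only to exercise the schedule; values are illustrative.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=1e-3)

    sched = LinearDecaySchedule(optimizer, ref_lr=1e-3, T_max=100, final_lr=1e-6)

    for step in range(1, 101):
        lr = sched.step()
    # lr after step 1   ≈ 9.9e-4  (one step of linear decay away from ref_lr)
    # lr after step 50  ≈ 5.0e-4  (halfway between ref_lr and final_lr)
    # lr after step 100 = 1e-6    (reaches final_lr exactly at T_max)

In the anneal path, `init_opt` builds this schedule instead of `WarmupCosineSchedule`, so an anneal run starts directly at `ref_lr` and decays linearly to `final_lr`; since progress is not clamped at 1, additional `step()` calls beyond `T_max` keep extrapolating along the same line.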