-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
733 lines (627 loc) · 27.5 KB
/
train.py
File metadata and controls
733 lines (627 loc) · 27.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
import argparse
import copy
from copy import deepcopy
import logging
import os
from pathlib import Path
from collections import OrderedDict
import json
import yaml
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
# Import all model registries
from models.sit import SiT_models as _sit_models
from models.sit_dfm import SiT_models as _dfm_models
from models.sit_elit import SiT_models as _elit_models
from loss import SILoss, DFMSILoss
from utils import load_encoders
from dataset import CustomDataset
from diffusers.models import AutoencoderKL
import wandb
import math
from torchvision.utils import make_grid
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
# Module-level logger obtained from accelerate's logging helper.
logger = get_logger(__name__)

# Combined model registry: one name -> constructor mapping built from the three
# per-family registries. They are merged with dict.update, so on a name
# collision the registry merged later (DFM, then ELIT) replaces the earlier one.
ALL_MODELS = {}
ALL_MODELS.update(_sit_models)
ALL_MODELS.update(_dfm_models)
ALL_MODELS.update(_elit_models)

# Per-channel mean / std passed to Normalize for 'clip' encoders in
# preprocess_raw_image (the other encoder types use the timm ImageNet constants).
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
def get_model_type(model_name):
    """Return the model family tag implied by the prefix of *model_name*.

    Names beginning with ``DFM-`` map to ``'dfm'`` and names beginning with
    ``ELIT-`` map to ``'elit'``; every other name is treated as ``'sit'``.
    The match is case-sensitive.
    """
    prefix_table = (("DFM-", "dfm"), ("ELIT-", "elit"))
    matches = (tag for prefix, tag in prefix_table if model_name.startswith(prefix))
    return next(matches, "sit")
def preprocess_raw_image(x, enc_type):
    """Prepare a raw image batch for the encoder family named by *enc_type*.

    The family is chosen by substring match on *enc_type*, tested in this
    order: 'clip', 'mocov3'/'mae', 'dinov2', 'dinov1', 'jepa'. Every matched
    family divides the input by 255 and normalizes it (CLIP statistics for
    'clip', ImageNet statistics otherwise). 'clip', 'dinov2' and 'jepa'
    additionally resize bicubically to ``224 * (resolution // 256)``, where
    ``resolution`` is the size of the last input axis; 'clip' resizes before
    normalizing, the other two after. If nothing matches, *x* is returned
    untouched.
    """
    resolution = x.shape[-1]
    target_size = 224 * (resolution // 256)

    def _resize(img):
        return torch.nn.functional.interpolate(img, target_size, mode='bicubic')

    def _imagenet_normalize(img):
        return Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(img)

    if 'clip' in enc_type:
        # CLIP is the only family resized *before* normalization.
        resized = _resize(x / 255.)
        return Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(resized)
    if 'mocov3' in enc_type or 'mae' in enc_type:
        return _imagenet_normalize(x / 255.)
    if 'dinov2' in enc_type:
        return _resize(_imagenet_normalize(x / 255.))
    if 'dinov1' in enc_type:
        return _imagenet_normalize(x / 255.)
    if 'jepa' in enc_type:
        return _resize(_imagenet_normalize(x / 255.))
    # Unrecognized encoder type: pass the batch through unchanged.
    return x
def array2grid(x):
    """Tile an image batch into a single uint8 numpy image.

    The batch is clamped to [0, 1] and arranged by ``make_grid`` with
    ``round(sqrt(batch_size))`` images per row. The grid is then scaled by
    255, offset by 0.5, clamped to [0, 255], permuted with ``(1, 2, 0)``,
    moved to the CPU as uint8 and converted to a numpy array.
    """
    images_per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x.clamp(0, 1), nrow=images_per_row, value_range=(0, 1))
    grid = grid.mul(255).add_(0.5).clamp_(0, 255)
    return grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Draw a latent sample from packed Gaussian moments.

    Args:
        moments: tensor that is split into two equal chunks along dim 1; the
            first chunk is used as the mean and the second as the standard
            deviation.
        latents_scale: multiplier applied to the drawn sample.
        latents_bias: offset added after scaling.

    Returns:
        ``(mean + std * eps) * latents_scale + latents_bias`` with ``eps``
        drawn from a standard normal of the same shape as the mean. Runs with
        gradient tracking disabled.
    """
    mean, std = torch.chunk(moments, 2, dim=1)
    z = mean + std * torch.randn_like(mean)
    return z * latents_scale + latents_bias
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``; ``decay=0`` therefore copies
    the current parameters into the EMA model.

    Args:
        ema_model: model holding the averaged parameters (updated in place).
        model: model supplying the current parameters; its parameter names may
            carry leading ``module.`` wrapper prefixes.
        decay: EMA decay factor.

    Raises:
        KeyError: if a parameter of *model* has no counterpart in *ema_model*.
    """
    wrapper_prefix = "module."
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # Strip only *leading* wrapper prefixes. A plain str.replace would also
        # eat "module." inside names such as "submodule.weight" and then fail
        # the lookup below.
        while name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def create_logger(logging_dir):
    """
    Configure root logging and return this module's logger.

    Calls ``logging.basicConfig`` at INFO level with a timestamped format and
    two handlers: a stream handler and a file handler writing to
    ``<logging_dir>/log.txt``.
    """
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(f"{logging_dir}/log.txt")
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[console_handler, file_handler],
    )
    return logging.getLogger(__name__)
def requires_grad(model, flag=True):
    """
    Assign ``flag`` to the ``requires_grad`` attribute of every parameter
    yielded by ``model.parameters()``.
    """
    for parameter in model.parameters():
        parameter.requires_grad = flag
def sample_sit(model, xT, ys, args, vae, latents_scale, latents_bias, accelerator):
    """Generate images with the Euler sampler for a SiT or ELIT model.

    Runs ``euler_sampler`` from noise ``xT`` with labels ``ys`` (50 steps,
    cfg scale 4.0, guidance applied over [0, 1], no Heun correction, path type
    taken from ``args.path_type``), undoes the latent scale/bias, decodes with
    ``vae`` and maps the decoder output through ``(x + 1) / 2``.

    ``accelerator`` is accepted for signature parity with ``sample_dfm`` and
    is not used here.

    Returns:
        dict with a single key ``"samples"`` holding the decoded images.
    """
    from samplers import euler_sampler

    sampler_options = dict(
        num_steps=50,
        cfg_scale=4.0,
        guidance_low=0.,
        guidance_high=1.,
        path_type=args.path_type,
        heun=False,
    )
    with torch.no_grad():
        latents = euler_sampler(model, xT, ys, **sampler_options).to(torch.float32)
        decoded = vae.decode((latents - latents_bias) / latents_scale).sample
        images = (decoded + 1) / 2.
    return {"samples": images}
def sample_dfm(model, multiscale_xT, ys, args, vae, latents_scale, latents_bias,
               accelerator, laplacian_decomposer, multiscale_gt_xs):
    """Generate images with the multiscale DFM Euler sampler.

    Runs ``dfm_euler_sampler`` on the per-stage noise ``multiscale_xT`` and
    then, for each of ``args.stages_count`` stages, decodes the generated
    latents and the ground-truth latents (``multiscale_gt_xs``) with ``vae``,
    gathers them across processes and wraps the resulting grids as
    ``wandb.Image`` objects. Finally the generated stages are recomposed with
    ``laplacian_decomposer.recompose`` and decoded.

    Returns:
        dict with ``"samples"`` (decoded recomposed images, not gathered) and
        ``"extra_log"`` (the per-stage ``generated_stage_<i>`` /
        ``gt_stage_<i>`` images).
    """
    from dfm_utils.samplers_dfm import dfm_euler_sampler

    def _decode_unit_range(latents):
        # Undo latent scale/bias, decode, then map the output via (x + 1) / 2.
        latents = latents.to(torch.float32)
        pixels = vae.decode((latents - latents_bias) / latents_scale).sample
        return (pixels + 1) / 2.

    with torch.no_grad():
        multiscale_samples = dfm_euler_sampler(
            model,
            multiscale_xT,
            ys,
            cfg_scale=4.0,
            guidance_low=0.,
            guidance_high=1.,
            path_type=args.path_type,
            heun=False,
            num_steps_per_scale=args.num_steps_per_scale,
            stage_thresholds=args.stage_sampling_thresholds,
        )

        log_dict = {}
        stage_indices = range(args.stages_count)

        # Per-stage generated outputs first ...
        for stage_idx in stage_indices:
            generated = accelerator.gather(_decode_unit_range(multiscale_samples[stage_idx]))
            log_dict[f"generated_stage_{stage_idx}"] = wandb.Image(array2grid(generated))

        # ... then the matching ground-truth stages.
        for stage_idx in stage_indices:
            reference = accelerator.gather(_decode_unit_range(multiscale_gt_xs[stage_idx]))
            log_dict[f"gt_stage_{stage_idx}"] = wandb.Image(array2grid(reference))

        # Recompose all stages into one latent and decode it.
        samples = _decode_unit_range(laplacian_decomposer.recompose(multiscale_samples))

    return {"samples": samples, "extra_log": log_dict}
def main(args):
    """Train a SiT / DFM-SiT / ELIT-SiT model as described by *args*.

    Steps, in order: set up the Accelerator and (on the main process) the
    experiment directory and file logger; optionally load REPA encoders;
    build the model, its EMA copy, the VAE, the loss and the AdamW optimizer;
    optionally resume from a checkpoint (``args.resume_step == -1`` picks the
    highest-numbered ``*.pt`` in the checkpoint directory); then run the
    training loop with periodic checkpointing, sampling and metric logging
    until ``args.max_train_steps`` optimizer steps or ``args.epochs`` epochs.

    Args:
        args: namespace produced by ``parse_args``.
    """
    model_type = get_model_type(args.model)

    # Set accelerator
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # NOTE: `logger`, `save_dir` and `checkpoint_dir` are only bound on the main
    # process; every later use is guarded by `accelerator.is_main_process`.
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        save_dir = os.path.join(args.output_dir, args.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        args_dict = vars(args)
        json_dir = os.path.join(save_dir, "args.json")
        with open(json_dir, 'w') as f:
            json.dump(args_dict, f, indent=4)
        checkpoint_dir = f"{save_dir}/checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(save_dir)
        logger.info(f"Experiment directory created at {save_dir}")
        logger.info(f"Model type: {model_type}")

    device = accelerator.device
    if torch.backends.mps.is_available():
        accelerator.native_amp = False
    if args.seed is not None:
        # Offset by process index so each process gets a distinct seed.
        set_seed(args.seed + accelerator.process_index)

    # Create model:
    assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.resolution // 8

    # Determine whether REPA is enabled
    enable_repa = (args.proj_coeff > 0 and args.enc_type.lower() not in ('none', ''))

    # Load encoders only when REPA is enabled
    encoders, encoder_types, architectures = [], [], []
    z_dims = None
    if enable_repa:
        encoders, encoder_types, architectures = load_encoders(
            args.enc_type, device, args.resolution
        )
        z_dims = [encoder.embed_dim for encoder in encoders]
        if accelerator.is_main_process:
            logger.info(f"REPA enabled: enc_type={args.enc_type}, proj_coeff={args.proj_coeff}, z_dims={z_dims}")
    else:
        if accelerator.is_main_process:
            logger.info("REPA disabled (no encoder loaded, no projectors created)")

    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}

    # Build model with type-specific kwargs
    model_kwargs = dict(
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg=(args.cfg_prob > 0),
        enable_repa=enable_repa,
        z_dims=z_dims,
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    )
    if model_type == 'elit':
        model_kwargs.update(dict(
            enable_elit=True,
            elit_max_mask_prob=args.elit_max_mask_prob,
            elit_min_mask_prob=args.elit_min_mask_prob,
            group_size=args.elit_group_size,
        ))

    model = ALL_MODELS[args.model](**model_kwargs)
    model = model.to(device)
    ema = deepcopy(model).to(device)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    requires_grad(ema, False)

    latents_scale = torch.tensor(
        [0.18215, 0.18215, 0.18215, 0.18215]
    ).view(1, 4, 1, 1).to(device)
    latents_bias = torch.tensor(
        [0., 0., 0., 0.]
    ).view(1, 4, 1, 1).to(device)

    # DFM-specific: Laplacian decomposer
    laplacian_decomposer = None
    if model_type == 'dfm':
        from dfm_utils.laplacian_decomposer import LaplacianDecomposer2D
        laplacian_decomposer = LaplacianDecomposer2D(
            stages_count=args.stages_count
        )

    # Create loss function
    if model_type == 'dfm':
        loss_fn = DFMSILoss(
            prediction=args.prediction,
            path_type=args.path_type,
            accelerator=accelerator,
            latents_scale=latents_scale,
            latents_bias=latents_bias,
            weighting=args.weighting,
            num_stages=args.stages_count,
            stage_weights=args.stage_weights,
        )
    else:
        loss_fn = SILoss(
            prediction=args.prediction,
            path_type=args.path_type,
            accelerator=accelerator,
            latents_scale=latents_scale,
            latents_bias=latents_bias,
            weighting=args.weighting
        )
    if accelerator.is_main_process:
        logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer:
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Setup data:
    train_dataset = CustomDataset(args.data_dir)
    # args.batch_size is the global batch size; split it evenly across processes.
    local_batch_size = int(args.batch_size // accelerator.num_processes)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=local_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")

    # Prepare models for training:
    update_ema(ema, model, decay=0)  # decay=0 copies the model weights into the EMA
    model.train()
    ema.eval()

    # Resume:
    global_step = 0
    resume_step = args.resume_step

    # Auto-resume: find latest checkpoint if resume_step is -1
    if resume_step == -1:
        checkpoint_dir_path = f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints'
        resume_step = 0
        if os.path.exists(checkpoint_dir_path):
            # Only consider files named "<digits>.pt"; anything else in the
            # directory is skipped instead of crashing int().
            stems = [
                fname[:-len('.pt')]
                for fname in os.listdir(checkpoint_dir_path)
                if fname.endswith('.pt')
            ]
            steps = [int(stem) for stem in stems if stem.isdecimal()]
            if steps:
                resume_step = max(steps)
                if accelerator.is_main_process:
                    logger.info(f"Auto-resume: Found latest checkpoint at step {resume_step}")

    if resume_step > 0:
        ckpt_name = str(resume_step).zfill(7) +'.pt'
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects (the
        # checkpoint stores the args namespace); only load trusted checkpoints.
        ckpt = torch.load(
            f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
            map_location='cpu',
            weights_only=False,
        )
        model.load_state_dict(ckpt['model'])
        ema.load_state_dict(ckpt['ema'])
        optimizer.load_state_dict(ckpt['opt'])
        global_step = ckpt['steps']
        if accelerator.is_main_process:
            logger.info(f"Resumed training from checkpoint: {ckpt_name} (step {global_step})")

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    if accelerator.is_main_process:
        tracker_config = vars(copy.deepcopy(args))
        accelerator.init_trackers(
            project_name="REPA",
            config=tracker_config,
            init_kwargs={
                "wandb": {"name": f"{args.exp_name}"}
            },
        )

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    # Setup sampling data: one fixed batch of ground-truth latents, labels and
    # noise, reused at every sampling step.
    sample_batch_size = 64 // accelerator.num_processes
    gt_raw_images, gt_xs, _ = next(iter(train_dataloader))
    assert gt_raw_images.shape[-1] == args.resolution
    gt_xs = gt_xs[:sample_batch_size]
    gt_xs = sample_posterior(
        gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
    )
    # Draw labels from the configured class count rather than a hard-coded 1000.
    ys = torch.randint(args.num_classes, size=(sample_batch_size,), device=device)
    n = ys.size(0)

    # Create sampling noise (model-type specific)
    if model_type == 'dfm':
        multiscale_gt_xs = laplacian_decomposer.decompose(gt_xs)
        multiscale_xT = {}
        for stage in range(args.stages_count):
            # Side length halves for each step away from the last stage.
            stage_latent_size = latent_size // (2 ** (args.stages_count - 1 - stage))
            multiscale_xT[stage] = torch.randn((n, 4, stage_latent_size, stage_latent_size), device=device)
    else:
        xT = torch.randn((n, 4, latent_size, latent_size), device=device)

    for epoch in range(args.epochs):
        model.train()
        for raw_image, x, y in train_dataloader:
            raw_image = raw_image.to(device)
            x = x.squeeze(dim=1).to(device)
            y = y.to(device)
            if args.legacy:
                # Legacy path: drop labels here by replacing them with the
                # extra class index `num_classes`.
                drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
                labels = torch.where(drop_ids, args.num_classes, y)
            else:
                labels = y

            with torch.no_grad():
                x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
                # Compute encoder features only when REPA is enabled
                zs = None
                if enable_repa:
                    zs = []
                    with accelerator.autocast():
                        for encoder, encoder_type, arch in zip(encoders, encoder_types, architectures):
                            raw_image_ = preprocess_raw_image(raw_image, encoder_type)
                            z = encoder.forward_features(raw_image_)
                            if 'mocov3' in encoder_type:
                                z = z[:, 1:]  # drop the first token
                            if 'dinov2' in encoder_type:
                                z = z['x_norm_patchtokens']
                            zs.append(z)

            with accelerator.accumulate(model):
                loss_kwargs = dict(y=labels)
                loss, proj_loss = loss_fn(model, x, loss_kwargs, zs=zs)
                loss_mean = loss.mean()
                proj_loss_mean = proj_loss.mean() if isinstance(proj_loss, torch.Tensor) else proj_loss
                loss = loss_mean + proj_loss_mean * args.proj_coeff

                ## optimization
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = model.parameters()
                    grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if accelerator.sync_gradients:
                    update_ema(ema, model)

            # Everything below belongs to a completed optimizer step. On a
            # mid-accumulation micro-batch no step was taken and `grad_norm`
            # may not exist yet, so skip step counting, checkpointing,
            # sampling and logging.
            if not accelerator.sync_gradients:
                continue

            progress_bar.update(1)
            global_step += 1

            if global_step % args.checkpointing_steps == 0 and global_step > 0:
                if accelerator.is_main_process:
                    checkpoint = {
                        # unwrap_model works whether or not prepare() wrapped
                        # the model (a bare model has no `.module`).
                        "model": accelerator.unwrap_model(model).state_dict(),
                        "ema": ema.state_dict(),
                        "opt": optimizer.state_dict(),
                        "args": args,
                        "steps": global_step,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")

            if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
                # Sampling and decoding never need gradients.
                with torch.no_grad():
                    if model_type == 'dfm':
                        sample_result = sample_dfm(
                            model, multiscale_xT, ys, args, vae,
                            latents_scale, latents_bias, accelerator,
                            laplacian_decomposer, multiscale_gt_xs
                        )
                    else:
                        sample_result = sample_sit(
                            model, xT, ys, args, vae,
                            latents_scale, latents_bias, accelerator
                        )
                    out_samples = accelerator.gather(sample_result["samples"].to(torch.float32))
                    gt_samples_decoded = vae.decode((gt_xs - latents_bias) / latents_scale).sample
                    gt_samples_decoded = (gt_samples_decoded + 1) / 2.
                    gt_samples = accelerator.gather(gt_samples_decoded.to(torch.float32))

                if model_type == 'dfm':
                    log_dict = sample_result.get("extra_log", {})
                    log_dict["samples_recomposed"] = wandb.Image(array2grid(out_samples))
                    log_dict["gt_samples_recomposed"] = wandb.Image(array2grid(gt_samples))
                else:
                    log_dict = {
                        "samples": wandb.Image(array2grid(out_samples)),
                        "gt_samples": wandb.Image(array2grid(gt_samples))
                    }
                # Log images at the same step as the scalar metrics below.
                accelerator.log(log_dict, step=global_step)
                # NOTE(review): the message says "EMA", but the samples above
                # come from `model`, not `ema` — confirm which is intended.
                logging.info("Generating EMA samples done.")

            logs = {
                "loss": accelerator.gather(loss_mean).mean().detach().item(),
                "proj_loss": (accelerator.gather(proj_loss_mean).mean().detach().item()
                              if isinstance(proj_loss_mean, torch.Tensor) else proj_loss_mean),
                "grad_norm": accelerator.gather(grad_norm).mean().detach().item()
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    model.eval()
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()
# Default value for every training option. build_parser() reads its argparse
# defaults from this dict, and parse_args() layers YAML config values on top of
# it before the command line is parsed.
DEFAULTS = dict(
    # logging
    output_dir="exps",
    exp_name=None,  # required: parse_args() errors out if still None
    logging_dir="logs",
    report_to="wandb",
    sampling_steps=10000,
    resume_step=0,  # main() treats -1 as "resume from the latest checkpoint"
    # model
    model=None,  # required: parse_args() errors out if still None
    num_classes=1000,
    encoder_depth=8,
    fused_attn=True,
    qk_norm=False,
    # dataset
    data_dir="../data/imagenet256",
    resolution=256,
    batch_size=256,  # global batch size; main() divides it by the process count
    # precision
    allow_tf32=False,
    mixed_precision="fp16",
    # optimization
    epochs=1400,
    max_train_steps=400000,
    checkpointing_steps=10000,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_weight_decay=0.0,
    adam_epsilon=1e-08,
    max_grad_norm=1.0,
    # seed
    seed=0,
    # cpu
    num_workers=4,
    # loss
    path_type="linear",
    prediction="v",
    cfg_prob=0.1,
    enc_type="none",  # 'none' (or '') disables REPA in main()
    proj_coeff=0.0,  # REPA also requires proj_coeff > 0
    weighting="uniform",
    legacy=False,
    # DFM
    stages_count=2,
    stage_weights=[0.9, 0.1],
    stage_sampling_thresholds=[0.1],
    num_steps_per_scale=[40, 10],
    # ELIT
    elit_max_mask_prob=0.0,
    elit_min_mask_prob=None,
    elit_group_size=4,
)
def _float_or_none(value):
    """argparse ``type``: parse a float, mapping 'none'/'null' (any case) to None."""
    if isinstance(value, str) and value.strip().lower() in ("none", "null"):
        return None
    return float(value)


def build_parser():
    """Build argument parser with all defaults from DEFAULTS dict.

    Every option except ``--config`` takes its default from ``DEFAULTS``.
    Boolean options use ``BooleanOptionalAction`` so each one has a
    ``--no-...`` form and can be switched off from the command line even when
    a YAML config turns it on. The ELIT mask-probability options accept the
    literal ``none`` in addition to a float.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    d = DEFAULTS
    parser = argparse.ArgumentParser(
        description="Unified Training for SiT, DFM-SiT, and ELIT-SiT.\n"
                    "Supports YAML config files: --config path/to/config.yaml\n"
                    "CLI args always override YAML values.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Config file
    parser.add_argument("--config", type=str, default=None,
                        help="Path to YAML config file. CLI args override YAML values.")

    # logging
    parser.add_argument("--output-dir", type=str, default=d["output_dir"])
    parser.add_argument("--exp-name", type=str, default=d["exp_name"])
    parser.add_argument("--logging-dir", type=str, default=d["logging_dir"])
    parser.add_argument("--report-to", type=str, default=d["report_to"])
    parser.add_argument("--sampling-steps", type=int, default=d["sampling_steps"])
    parser.add_argument("--resume-step", type=int, default=d["resume_step"])

    # model
    parser.add_argument("--model", type=str, default=d["model"],
                        choices=list(ALL_MODELS.keys()),
                        help="Model name. Prefix determines type: DFM-SiT-*/ELIT-SiT-*/SiT-*")
    parser.add_argument("--num-classes", type=int, default=d["num_classes"])
    parser.add_argument("--encoder-depth", type=int, default=d["encoder_depth"])
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=d["fused_attn"])
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=d["qk_norm"])

    # dataset
    parser.add_argument("--data-dir", type=str, default=d["data_dir"])
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=d["resolution"])
    parser.add_argument("--batch-size", type=int, default=d["batch_size"])

    # precision
    # BooleanOptionalAction (instead of store_true) keeps `--allow-tf32`
    # working and adds `--no-allow-tf32`, so the CLI can override a YAML `true`.
    parser.add_argument("--allow-tf32", action=argparse.BooleanOptionalAction, default=d["allow_tf32"])
    parser.add_argument("--mixed-precision", type=str, default=d["mixed_precision"],
                        choices=["no", "fp16", "bf16"])

    # optimization
    parser.add_argument("--epochs", type=int, default=d["epochs"])
    parser.add_argument("--max-train-steps", type=int, default=d["max_train_steps"])
    parser.add_argument("--checkpointing-steps", type=int, default=d["checkpointing_steps"])
    parser.add_argument("--gradient-accumulation-steps", type=int, default=d["gradient_accumulation_steps"])
    parser.add_argument("--learning-rate", type=float, default=d["learning_rate"])
    parser.add_argument("--adam-beta1", type=float, default=d["adam_beta1"])
    parser.add_argument("--adam-beta2", type=float, default=d["adam_beta2"])
    parser.add_argument("--adam-weight-decay", type=float, default=d["adam_weight_decay"])
    parser.add_argument("--adam-epsilon", type=float, default=d["adam_epsilon"])
    parser.add_argument("--max-grad-norm", type=float, default=d["max_grad_norm"])

    # seed
    parser.add_argument("--seed", type=int, default=d["seed"])

    # cpu
    parser.add_argument("--num-workers", type=int, default=d["num_workers"])

    # loss
    parser.add_argument("--path-type", type=str, default=d["path_type"], choices=["linear", "cosine"])
    parser.add_argument("--prediction", type=str, default=d["prediction"], choices=["v"])
    parser.add_argument("--cfg-prob", type=float, default=d["cfg_prob"])
    parser.add_argument("--enc-type", type=str, default=d["enc_type"])
    parser.add_argument("--proj-coeff", type=float, default=d["proj_coeff"])
    parser.add_argument("--weighting", type=str, default=d["weighting"])
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=d["legacy"])

    # DFM-specific
    parser.add_argument("--stages-count", type=int, default=d["stages_count"],
                        help="[DFM] Number of scales for multiscale decomposition")
    parser.add_argument("--stage-weights", type=float, nargs='+', default=d["stage_weights"],
                        help="[DFM] Weights for different stages to be sampled during training")
    parser.add_argument("--stage-sampling-thresholds", type=float, nargs='+',
                        default=d["stage_sampling_thresholds"],
                        help="[DFM] Thresholds for stage switching during inference sampling")
    parser.add_argument("--num-steps-per-scale", type=int, nargs='+',
                        default=d["num_steps_per_scale"],
                        help="[DFM] Number of sampling steps per scale")

    # ELIT-specific
    # The help text advertises None as a value, which a plain `type=float`
    # rejects on the command line; _float_or_none accepts it.
    parser.add_argument("--elit-max-mask-prob", type=_float_or_none, default=d["elit_max_mask_prob"],
                        help="[ELIT] Maximum masking probability for ELIT during training. "
                             "Use None (CLI: 'none') to sample all valid budgets up to (1 - 1/window_tokens).")
    parser.add_argument("--elit-min-mask-prob", type=_float_or_none, default=d["elit_min_mask_prob"],
                        help="[ELIT] Minimum masking probability. Defaults to elit_max_mask_prob (single budget). "
                             "If different from max, mask prob is uniformly sampled from valid levels in [min, max].")
    parser.add_argument("--elit-group-size", type=int, default=d["elit_group_size"],
                        help="[ELIT] Group size for ELIT token masking")
    return parser
def _normalize_key(key):
"""Convert between YAML underscore keys and argparse hyphen keys."""
return key.replace('-', '_')
def parse_args(input_args=None):
    """
    Parse arguments with YAML + CLI support.
    Priority: CLI args > YAML config > defaults.

    Args:
        input_args: optional list of argument strings; ``None`` makes argparse
            read ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` when the YAML root is not a mapping, when
    ``exp_name`` or ``model`` is missing, or when ``model`` is not a key of
    ``ALL_MODELS``. Unknown YAML keys are kept on the namespace but trigger a
    warning, since they are usually typos.
    """
    import warnings

    parser = build_parser()

    # First, do a preliminary parse just to get --config
    preliminary, _ = parser.parse_known_args(input_args)

    if preliminary.config is not None:
        with open(preliminary.config, 'r', encoding='utf-8') as f:
            yaml_config = yaml.safe_load(f) or {}
        if not isinstance(yaml_config, dict):
            parser.error(
                f"config file {preliminary.config} must contain a top-level mapping, "
                f"got {type(yaml_config).__name__}"
            )
        yaml_config = {_normalize_key(str(k)): v for k, v in yaml_config.items()}

        unknown_keys = sorted(set(yaml_config) - set(DEFAULTS) - {"config"})
        if unknown_keys:
            warnings.warn(
                f"Unrecognized keys in {preliminary.config}: {', '.join(unknown_keys)}"
            )

        # YAML values become the parser defaults, so anything given explicitly
        # on the command line still wins.
        merged = dict(DEFAULTS)
        merged.update(yaml_config)
        parser.set_defaults(**merged)

    args = parser.parse_args(input_args)

    if args.exp_name is None:
        parser.error("--exp-name is required (via CLI or YAML config)")
    if args.model is None:
        parser.error("--model is required (via CLI or YAML config)")
    # argparse checks `choices` only for values given on the command line, so
    # a model name coming from YAML has to be validated here.
    if args.model not in ALL_MODELS:
        parser.error(
            f"unknown model {args.model!r}; choose from: {', '.join(ALL_MODELS)}"
        )
    return args
# Script entry point: parse CLI/YAML options, then run training.
if __name__ == "__main__":
    args = parse_args()
    main(args)