-
Notifications
You must be signed in to change notification settings - Fork 253
/
Copy pathconfig.py
193 lines (175 loc) · 6 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class BaseConfig:
    """Base class for config dataclasses: dict-like access plus JSON pretty-printing."""

    def get(self, attribute_name, default=None):
        """Return the value of ``attribute_name``, or ``default`` if it is absent."""
        return getattr(self, attribute_name, default)

    def pop(self, attribute_name, default=None):
        """Remove ``attribute_name`` and return its value, or ``default`` if absent.

        Note: like the original, this raises AttributeError if the name resolves
        only on the class (dataclass default) and not on the instance, because
        ``delattr`` on the instance cannot remove a class attribute.
        """
        # EAFP: one lookup instead of hasattr + getattr.
        try:
            value = getattr(self, attribute_name)
        except AttributeError:
            return default
        delattr(self, attribute_name)
        return value

    def __str__(self):
        """Serialize all dataclass fields as indented JSON."""
        return json.dumps(asdict(self), indent=4)
@dataclass
class DataConfig(BaseConfig):
    """Dataset location and loading options (inherits get/pop from BaseConfig)."""

    data_dir: List[Optional[str]] = field(default_factory=list)  # one or more dataset root dirs
    buckets_file: str = "buckets.json"
    caption_proportion: Dict[str, int] = field(default_factory=lambda: {"prompt": 1})  # caption-type sampling weights — TODO confirm semantics
    external_caption_suffixes: List[str] = field(default_factory=list)
    external_clipscore_suffixes: List[str] = field(default_factory=list)
    clip_thr_temperature: float = 1.0  # temperature for CLIP-score thresholding — presumably; verify against loader
    clip_thr: float = 0.0  # 0.0 = no CLIP-score filtering — assumption, confirm in loader
    sort_dataset: bool = False
    load_text_feat: bool = False  # load precomputed text features instead of raw captions — TODO confirm
    load_vae_feat: bool = False  # load precomputed VAE latents instead of images — TODO confirm
    transform: str = "default_train"  # name of the image transform pipeline
    type: str = "SanaWebDatasetMS"  # dataset class name, resolved elsewhere
    image_size: int = 512
    hq_only: bool = False
    valid_num: int = 0
    data: Any = None
    extra: Any = None  # catch-all for options not modeled above
@dataclass
class ModelConfig(BaseConfig):
    """Diffusion-model architecture and checkpoint settings."""

    model: str = "SanaMS_600M_P1_D28"  # model-class name, presumably resolved by a registry elsewhere
    image_size: int = 512
    mixed_precision: str = "fp16"  # ['fp16', 'fp32', 'bf16']
    fp32_attention: bool = True  # keep attention in fp32 even under mixed precision — assumption, confirm
    load_from: Optional[str] = None  # path/id of weights to load
    # Resume state: which checkpoint to restore and whether to also restore
    # EMA weights, LR-scheduler state, and optimizer state.
    resume_from: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {
            "checkpoint": None,
            "load_ema": False,
            "resume_lr_scheduler": True,
            "resume_optimizer": True,
        }
    )
    aspect_ratio_type: str = "ASPECT_RATIO_1024"  # name of an aspect-ratio bucket table defined elsewhere
    multi_scale: bool = True
    pe_interpolation: float = 1.0  # positional-embedding interpolation factor — TODO confirm
    micro_condition: bool = False
    attn_type: str = "linear"
    autocast_linear_attn: bool = False
    ffn_type: str = "glumbconv"
    mlp_acts: List[Optional[str]] = field(default_factory=lambda: ["silu", "silu", None])  # per-layer activations; None = identity — assumption
    mlp_ratio: float = 2.5
    use_pe: bool = False
    qk_norm: bool = False
    class_dropout_prob: float = 0.0
    linear_head_dim: int = 32
    cross_norm: bool = False
    cfg_scale: int = 4  # guidance scale (see guidance_type)
    guidance_type: str = "classifier-free"
    pag_applied_layers: List[int] = field(default_factory=lambda: [14])  # layer indices for PAG — semantics defined elsewhere
    extra: Any = None  # catch-all for options not modeled above
@dataclass
class AEConfig(BaseConfig):
    """Autoencoder (VAE) settings."""

    vae_type: str = "dc-ae"
    vae_pretrained: str = "mit-han-lab/dc-ae-f32c32-sana-1.0"  # HF-style model id or local path
    scale_factor: float = 0.41407  # latent scaling factor — value's provenance not visible here
    vae_latent_dim: int = 32
    vae_downsample_rate: int = 32
    sample_posterior: bool = True  # presumably sample from the posterior rather than take its mean — confirm
    extra: Any = None  # catch-all for options not modeled above
@dataclass
class TextEncoderConfig(BaseConfig):
    """Text-encoder settings: model name, caption-embedding shape, prompt template."""

    text_encoder_name: str = "gemma-2-2b-it"
    caption_channels: int = 2304  # channel width of caption embeddings — must match the encoder
    y_norm: bool = True  # presumably normalizes caption embeddings — confirm with the trainer
    y_norm_scale_factor: float = 1.0
    model_max_length: int = 300  # max caption token length — TODO confirm against tokenizer usage
    # Idiom: `list` is the canonical empty-list factory; `lambda: []` was redundant.
    chi_prompt: List[Optional[str]] = field(default_factory=list)
    extra: Any = None  # catch-all for options not modeled above
@dataclass
class SchedulerConfig(BaseConfig):
    """Diffusion noise-schedule and timestep-sampling settings."""

    train_sampling_steps: int = 1000
    predict_v: bool = True  # v-prediction objective — assumption from the name, confirm
    noise_schedule: str = "linear_flow"
    pred_sigma: bool = False
    learn_sigma: bool = True
    vis_sampler: str = "flow_dpm-solver"  # sampler used for visualization runs
    flow_shift: float = 1.0
    # Logit-normal timestep weighting: mean/std of the logit-normal distribution below.
    weighting_scheme: Optional[str] = "logit_normal"
    logit_mean: float = 0.0
    logit_std: float = 1.0
    extra: Any = None  # catch-all for options not modeled above
@dataclass
class TrainingConfig(BaseConfig):
    """Training-loop hyperparameters: optimizer, schedule, checkpointing, eval."""

    # -- data loading / batching --
    num_workers: int = 4
    seed: int = 43
    train_batch_size: int = 32
    num_epochs: int = 100
    gradient_accumulation_steps: int = 1
    grad_checkpointing: bool = False
    gradient_clip: float = 1.0
    gc_step: int = 1  # unclear from here whether this is grad-checkpoint or GC interval — verify at use site
    # -- optimizer & LR schedule --
    optimizer: Dict[str, Any] = field(
        default_factory=lambda: {"eps": 1.0e-10, "lr": 0.0001, "type": "AdamW", "weight_decay": 0.03}
    )
    lr_schedule: str = "constant"
    lr_schedule_args: Dict[str, int] = field(default_factory=lambda: {"num_warmup_steps": 500})
    auto_lr: Dict[str, str] = field(default_factory=lambda: {"rule": "sqrt"})  # automatic LR scaling rule — semantics defined elsewhere
    ema_rate: float = 0.9999
    # -- evaluation / logging --
    eval_batch_size: int = 16
    use_fsdp: bool = False
    use_flash_attn: bool = False
    eval_sampling_steps: int = 250
    lora_rank: int = 4
    log_interval: int = 50
    # -- masking / loss options --
    mask_type: str = "null"
    mask_loss_coef: float = 0.0
    load_mask_index: bool = False
    snr_loss: bool = False
    real_prompt_ratio: float = 1.0
    # -- checkpointing & visualization --
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    save_model_steps: int = 1000000
    visualize: bool = False
    null_embed_root: str = "output/pretrained_models/"  # where precomputed null-prompt embeddings live — assumption
    valid_prompt_embed_root: str = "output/tmp_embed/"
    # Prompts rendered during validation/visualization.
    validation_prompts: List[str] = field(
        default_factory=lambda: [
            "dog",
            "portrait photo of a girl, photograph, highly detailed face, depth of field",
            "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
        ]
    )
    local_save_vis: bool = False
    deterministic_validation: bool = True
    online_metric: bool = False
    eval_metric_step: int = 5000
    online_metric_dir: str = "metric_helper"
    # -- misc --
    work_dir: str = "/cache/exps/"
    skip_step: int = 0
    loss_type: str = "huber"
    huber_c: float = 0.001  # Huber-loss transition constant — presumably; verify at loss definition
    num_ddim_timesteps: int = 50
    w_max: float = 15.0  # guidance-weight range bounds — semantics defined at use site
    w_min: float = 3.0
    ema_decay: float = 0.95  # NOTE(review): separate from ema_rate above — confirm which one the trainer uses
    debug_nan: bool = False
    extra: Any = None  # catch-all for options not modeled above
@dataclass
class SanaConfig(BaseConfig):
    """Top-level config aggregating all sub-configs.

    The six sub-config fields have no defaults, so they are required
    (positional) arguments when constructing SanaConfig.
    """

    data: DataConfig
    model: ModelConfig
    vae: AEConfig
    text_encoder: TextEncoderConfig
    scheduler: SchedulerConfig
    train: TrainingConfig
    work_dir: str = "output/"
    resume_from: Optional[str] = None
    load_from: Optional[str] = None
    debug: bool = False
    caching: bool = False
    report_to: str = "wandb"  # experiment-tracker backend name
    tracker_project_name: str = "t2i-evit-baseline"
    name: str = "baseline"  # run name
    loss_report_name: str = "loss"  # metric key used when reporting loss — assumption, confirm in trainer