Skip to content

Commit b197b50

Browse files
committed
feat(ml): diffusion in latent space
1 parent da1fed1 commit b197b50

File tree

6 files changed

+279
-32
lines changed

6 files changed

+279
-32
lines changed

examples/example_ddpm_noglasses2glasses.json

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"dropout": false,
3333
"nblocks": 2,
3434
"netE": "resnet_256",
35-
"netG": "unet_mha",
35+
"netG": "hdit",
3636
"ngf": 64,
3737
"norm": "instance",
3838
"padding_type": "reflect",
@@ -132,11 +132,11 @@
132132
"mask_square_B": false,
133133
"rand_mask_A": true
134134
},
135-
"crop_size": 128,
135+
"crop_size": 256,
136136
"dataset_mode": "self_supervised_labeled_mask",
137137
"direction": "BtoA",
138138
"inverted_mask": false,
139-
"load_size": 128,
139+
"load_size": 256,
140140
"max_dataset_size": 1000000000,
141141
"num_threads": 4,
142142
"online_context_pixels": 0,
@@ -167,8 +167,8 @@
167167
"aim_port": 53800,
168168
"aim_server": "http://localhost",
169169
"diff_fake_real": false,
170-
"env": "noglasses2glasses",
171-
"freq": 10000,
170+
"env": "noglasses2glasses_hdit_latent",
171+
"freq": 1000,
172172
"id": 1,
173173
"ncols": 0,
174174
"networks": false,
@@ -235,7 +235,7 @@
235235
"epoch_count": 1,
236236
"export_jit": false,
237237
"gan_mode": "lsgan",
238-
"iter_size": 16,
238+
"iter_size": 8,
239239
"load_iter": 0,
240240
"lr_decay_iters": 50,
241241
"lr_policy": "linear",
@@ -282,12 +282,12 @@
282282
"ddp_port": "12355",
283283
"gpu_ids": "0",
284284
"model_type": "palette",
285-
"name": "noglasses2glasses",
285+
"name": "noglasses2glasses_hdit_latent",
286286
"phase": "train",
287287
"suffix": "",
288288
"test_batch_size": 1,
289289
"warning_mode": false,
290-
"with_amp": false,
291-
"with_tf32": false,
290+
"with_amp": true,
291+
"with_tf32": true,
292292
"with_torch_compile": false
293293
}

models/base_diffusion_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,19 @@ def modify_commandline_options_train(parser):
219219
help="the range of probabilities for dropping the canny for each frame",
220220
)
221221

222+
parser.add_argument(
223+
"--alg_diffusion_latent_dc_ae_path",
224+
type=str,
225+
default="",
226+
help="Path to the pretrained DC-AE model for latent space encoding. If empty, this feature is disabled.",
227+
)
228+
parser.add_argument(
229+
"--alg_diffusion_latent_dc_ae_torch_dtype",
230+
type=str,
231+
default="float32",
232+
help="Torch dtype for the DC-AE model.",
233+
)
234+
222235
return parser
223236

224237
def __init__(self, opt, rank):

0 commit comments

Comments
 (0)