
Commit 92ec007

feat(ml): fine-tuning consistency model based on the pretrained DDPM
1 parent 6bf4412 commit 92ec007

File tree

5 files changed: +114 / -35 lines changed

models/base_diffusion_model.py
Lines changed: 6 additions & 0 deletions

@@ -218,6 +218,12 @@ def modify_commandline_options_train(parser):
             nargs="+",
             help="the range of probabilities for dropping the canny for each frame",
         )
+        parser.add_argument(
+            "--alg_diffusion_ddpm_cm_ft",
+            action="store_true",
+            default=False,
+            help="whether to apply consistency model finetuning from pretrained DDPM",
+        )
 
         return parser
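
The option added above is a standard argparse store_true flag, so it stays False unless passed on the command line. A minimal sketch of that behaviour with a bare parser (the standalone parser below is illustrative, not the repository's option machinery):

# Sketch only: reproduces the flag's behaviour outside the repository's option code.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--alg_diffusion_ddpm_cm_ft",
    action="store_true",
    default=False,
    help="whether to apply consistency model finetuning from pretrained DDPM",
)

print(parser.parse_args([]).alg_diffusion_ddpm_cm_ft)  # False: flag absent
print(parser.parse_args(["--alg_diffusion_ddpm_cm_ft"]).alg_diffusion_ddpm_cm_ft)  # True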

models/base_model.py
Lines changed: 70 additions & 22 deletions

@@ -969,32 +969,80 @@ def load_networks(self, epoch):
                     state_dict[new_key] = state_dict[key].clone()
                     del state_dict[key]
 
-                state1 = list(state_dict.keys())
-                state2 = list(net.state_dict().keys())
-                state1.sort()
-                state2.sort()
-
-                for key1, key2 in zip(state1, state2):
-                    if key1 != key2:
-                        print(key1 == key2, key1, key2)
-
-                if hasattr(state_dict, "_ema"):
-                    net.load_state_dict(
-                        state_dict["_ema"], strict=self.opt.model_load_no_strictness
+                if getattr(self.opt, "alg_diffusion_ddpm_cm_ft", False):
+                    model_dict = net.state_dict()
+                    filtered = {}
+
+                    for k, v in state_dict.items():
+                        if "denoise_fn.model.cond_embed" in k:
+                            new_k = k.replace(
+                                "denoise_fn.model.cond_embed",
+                                "cm_cond_embed.projection",
+                            )
+                        elif k.startswith("cond_embed."):
+                            new_k = k.replace("cond_embed", "cm_cond_embed.projection")
+                        elif "denoise_fn.model." in k:
+                            new_k = k.replace("denoise_fn.model.", "cm_model.")
+                        else:
+                            new_k = k
+
+                        if new_k in model_dict and v.shape == model_dict[new_k].shape:
+                            filtered[new_k] = v
+                        else:
+                            if "cond_embed" in k:
+                                print(f"⚠️ unmatched cond_embed key {k}{new_k}")
+                            else:
+                                print(
+                                    f"⚠️ skipping {new_k}: shape {v.shape if hasattr(v, 'shape') else 'N/A'}"
+                                )
+
+                    missing = set(model_dict.keys()) - set(filtered.keys())
+                    extra = set(state_dict.keys()) - set(model_dict.keys())
+
+                    print(
+                        f"Loaded {len(filtered)}/{len(model_dict)} params; {len(missing)} missing.",
+                        flush=True,
+                    )
+
+                    if missing:
+                        print("\n⚠️ Missing keys:")
+                        for k in sorted(missing):
+                            print(" ", k)
+
+                    net.load_state_dict(filtered, strict=False)
+
+                    print(
+                        "✅ Loaded pretrained DDPM weights (with partial embedding transfer).",
+                        flush=True,
                     )
+
                 else:
-                    if (
-                        name == "G_A"
-                        and hasattr(net, "unet")
-                        and hasattr(net, "vae")
-                        and any("lora" in n for n, _ in net.unet.named_parameters())
-                    ):
-                        net.load_lora_config(load_path)
-                        print("loading the lora")
-                    else:
+                    state1 = list(state_dict.keys())
+                    state2 = list(net.state_dict().keys())
+                    state1.sort()
+                    state2.sort()
+
+                    for key1, key2 in zip(state1, state2):
+                        if key1 != key2:
+                            print(key1 == key2, key1, key2)
+
+                    if hasattr(state_dict, "_ema"):
                         net.load_state_dict(
-                            state_dict, strict=self.opt.model_load_no_strictness
+                            state_dict["_ema"], strict=self.opt.model_load_no_strictness
                         )
+                    else:
+                        if (
+                            name == "G_A"
+                            and hasattr(net, "unet")
+                            and hasattr(net, "vae")
+                            and any("lora" in n for n, _ in net.unet.named_parameters())
+                        ):
+                            net.load_lora_config(load_path)
+                            print("loading the lora")
+                        else:
+                            net.load_state_dict(
+                                state_dict, strict=self.opt.model_load_no_strictness
+                            )
 
     def get_nets(self):
         return_nets = {}
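
The key-remapping branch above renames DDPM checkpoint entries onto the consistency model's module names before a shape-checked, non-strict load. A minimal sketch of those renaming rules applied to made-up keys (the key names below are hypothetical, not taken from a real checkpoint):

# Sketch only: the same renaming rules as in the diff, applied to hypothetical keys.
def remap(k: str) -> str:
    if "denoise_fn.model.cond_embed" in k:
        return k.replace("denoise_fn.model.cond_embed", "cm_cond_embed.projection")
    if k.startswith("cond_embed."):
        return k.replace("cond_embed", "cm_cond_embed.projection")
    if "denoise_fn.model." in k:
        return k.replace("denoise_fn.model.", "cm_model.")
    return k


for key in [
    "denoise_fn.model.cond_embed.0.weight",  # hypothetical DDPM key
    "denoise_fn.model.final_conv.weight",    # hypothetical DDPM key
    "cond_embed.2.bias",                     # hypothetical DDPM key
    "optimizer_state",                       # hypothetical key, left unchanged
]:
    print(key, "->", remap(key))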

models/cm_model.py
Lines changed: 5 additions & 2 deletions

@@ -168,8 +168,11 @@ def __init__(self, opt, rank):
         # Define network
         opt.alg_palette_sampling_method = ""
         opt.alg_diffusion_cond_embed = opt.alg_diffusion_cond_image_creation
-        opt.alg_diffusion_cond_embed_dim = 256
-        self.netG_A = diffusion_networks.define_G(**vars(opt)).to(self.device)
+        if opt.alg_diffusion_ddpm_cm_ft and self.opt.G_netG == "unet_vid":
+            opt.alg_diffusion_cond_embed_dim = 32
+        else:
+            opt.alg_diffusion_cond_embed_dim = 256
+        self.netG_A = diffusion_networks.define_G(opt=opt, **vars(opt)).to(self.device)
         if opt.isTrain:
             self.netG_A.current_t = max(self.netG_A.current_t, opt.total_iters)
         else:
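
Setting alg_diffusion_cond_embed_dim to 32 here pairs with the slimmer NoiseLevelEmbedding projection introduced in models/modules/cm_generator.py below. A minimal sketch (not part of the commit) of the parameter shapes of the two projection variants for channels = 32; the parameter-free Rearrange layer is omitted:

# Sketch only: shapes of the two projection variants for channels = 32.
import torch.nn as nn

channels = 32
ft_projection = nn.Sequential(       # variant used when fine-tuning from a DDPM
    nn.Linear(channels, channels), nn.SiLU(), nn.Linear(channels, channels)
)
default_projection = nn.Sequential(  # original consistency-model variant
    nn.Linear(channels, 4 * channels), nn.SiLU(), nn.Linear(4 * channels, channels)
)

print([tuple(p.shape) for p in ft_projection.parameters()])       # [(32, 32), (32,), (32, 32), (32,)]
print([tuple(p.shape) for p in default_projection.parameters()])  # [(128, 32), (128,), (32, 128), (32,)]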

models/diffusion_networks.py
Lines changed: 7 additions & 1 deletion

@@ -55,6 +55,7 @@ def define_G(
     alg_diffusion_cond_embed,
     alg_diffusion_cond_embed_dim,
     alg_diffusion_ref_embed_net,
+    alg_diffusion_ddpm_cm_ft,
     model_prior_321_backwardcompatibility,
     dropout=0,
     channel_mults=(1, 2, 4, 8),
@@ -67,6 +68,7 @@ def define_G(
     use_new_attention_order=False,
     f_s_semantic_nclasses=-1,
     train_feat_wavelet=False,
+    opt=None,
     **unused_options,
 ):
     """Create a generator
@@ -95,7 +97,10 @@
     if model_type == "palette":
         in_channel = model_input_nc + model_output_nc
     else:  # CM
-        in_channel = model_input_nc
+        if alg_diffusion_ddpm_cm_ft:
+            in_channel = model_input_nc + model_output_nc
+        else:
+            in_channel = model_input_nc
     if (
         alg_diffusion_cond_embed != "" and alg_diffusion_cond_embed != "y_t"
     ) or alg_diffusion_task == "pix2pix":
@@ -270,6 +275,7 @@
             sampling_method="",
             image_size=data_crop_size,
             G_ngf=G_ngf,
+            opt=opt,
         )
     else:
         raise NotImplementedError(model_type + " not implemented")

models/modules/cm_generator.py
Lines changed: 26 additions & 10 deletions

@@ -209,17 +209,25 @@ def skip_scaling(
 
 
 class NoiseLevelEmbedding(nn.Module):
-    def __init__(self, channels: int, scale: float = 0.02) -> None:
+    def __init__(self, channels: int, opt, scale: float = 0.02) -> None:
+
         super().__init__()
 
         self.W = nn.Parameter(torch.randn(channels // 2) * scale, requires_grad=False)
-
-        self.projection = nn.Sequential(
-            nn.Linear(channels, 4 * channels),
-            nn.SiLU(),
-            nn.Linear(4 * channels, channels),
-            Rearrange("b c -> b c () ()"),
-        )
+        if opt.alg_diffusion_ddpm_cm_ft:
+            self.projection = nn.Sequential(
+                nn.Linear(channels, channels),
+                nn.SiLU(),
+                nn.Linear(channels, channels),
+                Rearrange("b c -> b c () ()"),
+            )
+        else:
+            self.projection = nn.Sequential(
+                nn.Linear(channels, 4 * channels),
+                nn.SiLU(),
+                nn.Linear(4 * channels, channels),
+                Rearrange("b c -> b c () ()"),
+            )
 
     def forward(self, x: Tensor) -> Tensor:
         h = x[:, None] * self.W[None, :] * 2 * torch.pi
@@ -235,12 +243,14 @@ def __init__(
         sampling_method,
         image_size,
         G_ngf,
+        opt=None,
     ):
         super().__init__()
 
         self.cm_model = cm_model
         self.sampling_method = sampling_method
         self.image_size = image_size
+        self.opt = opt
 
         self.sigma_min = 0.002
         self.sigma_max = 80.0
@@ -254,7 +264,7 @@ def __init__(
         self.lognormal_std = 2.0
 
         self.cond_embed_dim = self.cm_model.cond_embed_dim
-        self.cm_cond_embed = NoiseLevelEmbedding(self.cond_embed_dim)
+        self.cm_cond_embed = NoiseLevelEmbedding(self.cond_embed_dim, self.opt)
 
         self.current_t = 2  # default value, set from cm_model upon resume
 
@@ -273,7 +283,13 @@ def cm_forward(self, x, sigma, sigma_data, sigma_min, x_cond=None):
             else:
                 x_with_cond = torch.cat([x_cond, x], dim=2)
         else:
-            x_with_cond = x
+            if len(x.shape) != 5:
+                x_with_cond = x
+            elif self.opt.G_netG == "unet_vid" and self.opt.alg_diffusion_ddpm_cm_ft:
+                x_with_cond = torch.cat([x, x], dim=2)
+            else:
+                x_with_cond = x
+
         return c_skip * x + c_out * self.cm_model(
             x_with_cond, embed_noise_level
         )  # , **kwargs)
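
The duplicated-input branch in cm_forward above keeps a 5-D video batch compatible with the doubled in_channel set in models/diffusion_networks.py when fine-tuning from a DDPM. A shape-only sketch (tensor sizes below are examples, not taken from the code):

# Sketch only: duplicating a (batch, frames, channels, height, width) tensor
# along dim=2 doubles the channel count, matching
# in_channel = model_input_nc + model_output_nc when the flag is enabled.
import torch

x = torch.randn(2, 4, 3, 64, 64)        # example video batch
x_with_cond = torch.cat([x, x], dim=2)  # placeholder self-conditioning
print(x_with_cond.shape)                # torch.Size([2, 4, 6, 64, 64])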