Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions models/base_diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ def modify_commandline_options_train(parser):
nargs="+",
help="the range of probabilities for dropping the canny for each frame",
)
parser.add_argument(
"--alg_diffusion_ddpm_cm_ft",
action="store_true",
default=False,
help="whether to apply consistency model finetuning from pretrained DDPM",
)

return parser

Expand Down
92 changes: 70 additions & 22 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,32 +969,80 @@ def load_networks(self, epoch):
state_dict[new_key] = state_dict[key].clone()
del state_dict[key]

state1 = list(state_dict.keys())
state2 = list(net.state_dict().keys())
state1.sort()
state2.sort()

for key1, key2 in zip(state1, state2):
if key1 != key2:
print(key1 == key2, key1, key2)

if hasattr(state_dict, "_ema"):
net.load_state_dict(
state_dict["_ema"], strict=self.opt.model_load_no_strictness
if getattr(self.opt, "alg_diffusion_ddpm_cm_ft", False):
model_dict = net.state_dict()
filtered = {}

for k, v in state_dict.items():
if "denoise_fn.model.cond_embed" in k:
new_k = k.replace(
"denoise_fn.model.cond_embed",
"cm_cond_embed.projection",
)
elif k.startswith("cond_embed."):
new_k = k.replace("cond_embed", "cm_cond_embed.projection")
elif "denoise_fn.model." in k:
new_k = k.replace("denoise_fn.model.", "cm_model.")
else:
new_k = k

if new_k in model_dict and v.shape == model_dict[new_k].shape:
filtered[new_k] = v
else:
if "cond_embed" in k:
print(f"⚠️ unmatched cond_embed key {k}{new_k}")
else:
print(
f"⚠️ skipping {new_k}: shape {v.shape if hasattr(v, 'shape') else 'N/A'}"
)

missing = set(model_dict.keys()) - set(filtered.keys())
extra = set(state_dict.keys()) - set(model_dict.keys())

print(
f"Loaded {len(filtered)}/{len(model_dict)} params; {len(missing)} missing.",
flush=True,
)

if missing:
print("\n⚠️ Missing keys:")
for k in sorted(missing):
print(" ", k)

net.load_state_dict(filtered, strict=False)

print(
"✅ Loaded pretrained DDPM weights (with partial embedding transfer).",
flush=True,
)

else:
if (
name == "G_A"
and hasattr(net, "unet")
and hasattr(net, "vae")
and any("lora" in n for n, _ in net.unet.named_parameters())
):
net.load_lora_config(load_path)
print("loading the lora")
else:
state1 = list(state_dict.keys())
state2 = list(net.state_dict().keys())
state1.sort()
state2.sort()

for key1, key2 in zip(state1, state2):
if key1 != key2:
print(key1 == key2, key1, key2)

if hasattr(state_dict, "_ema"):
net.load_state_dict(
state_dict, strict=self.opt.model_load_no_strictness
state_dict["_ema"], strict=self.opt.model_load_no_strictness
)
else:
if (
name == "G_A"
and hasattr(net, "unet")
and hasattr(net, "vae")
and any("lora" in n for n, _ in net.unet.named_parameters())
):
net.load_lora_config(load_path)
print("loading the lora")
else:
net.load_state_dict(
state_dict, strict=self.opt.model_load_no_strictness
)

def get_nets(self):
return_nets = {}
Expand Down
7 changes: 5 additions & 2 deletions models/cm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,11 @@ def __init__(self, opt, rank):
# Define network
opt.alg_palette_sampling_method = ""
opt.alg_diffusion_cond_embed = opt.alg_diffusion_cond_image_creation
opt.alg_diffusion_cond_embed_dim = 256
self.netG_A = diffusion_networks.define_G(**vars(opt)).to(self.device)
if opt.alg_diffusion_ddpm_cm_ft and self.opt.G_netG == "unet_vid":
opt.alg_diffusion_cond_embed_dim = 32
else:
opt.alg_diffusion_cond_embed_dim = 256
self.netG_A = diffusion_networks.define_G(opt=opt, **vars(opt)).to(self.device)
if opt.isTrain:
self.netG_A.current_t = max(self.netG_A.current_t, opt.total_iters)
else:
Expand Down
8 changes: 7 additions & 1 deletion models/diffusion_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def define_G(
alg_diffusion_cond_embed,
alg_diffusion_cond_embed_dim,
alg_diffusion_ref_embed_net,
alg_diffusion_ddpm_cm_ft,
model_prior_321_backwardcompatibility,
dropout=0,
channel_mults=(1, 2, 4, 8),
Expand All @@ -67,6 +68,7 @@ def define_G(
use_new_attention_order=False,
f_s_semantic_nclasses=-1,
train_feat_wavelet=False,
opt=None,
**unused_options,
):
"""Create a generator
Expand Down Expand Up @@ -95,7 +97,10 @@ def define_G(
if model_type == "palette":
in_channel = model_input_nc + model_output_nc
else: # CM
in_channel = model_input_nc
if alg_diffusion_ddpm_cm_ft:
in_channel = model_input_nc + model_output_nc
else:
in_channel = model_input_nc
if (
alg_diffusion_cond_embed != "" and alg_diffusion_cond_embed != "y_t"
) or alg_diffusion_task == "pix2pix":
Expand Down Expand Up @@ -270,6 +275,7 @@ def define_G(
sampling_method="",
image_size=data_crop_size,
G_ngf=G_ngf,
opt=opt,
)
else:
raise NotImplementedError(model_type + " not implemented")
Expand Down
36 changes: 26 additions & 10 deletions models/modules/cm_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,25 @@ def skip_scaling(


class NoiseLevelEmbedding(nn.Module):
def __init__(self, channels: int, scale: float = 0.02) -> None:
def __init__(self, channels: int, opt, scale: float = 0.02) -> None:

super().__init__()

self.W = nn.Parameter(torch.randn(channels // 2) * scale, requires_grad=False)

self.projection = nn.Sequential(
nn.Linear(channels, 4 * channels),
nn.SiLU(),
nn.Linear(4 * channels, channels),
Rearrange("b c -> b c () ()"),
)
if getattr(opt, "alg_diffusion_ddpm_cm_ft", False):
self.projection = nn.Sequential(
nn.Linear(channels, channels),
nn.SiLU(),
nn.Linear(channels, channels),
Rearrange("b c -> b c () ()"),
)
else:
self.projection = nn.Sequential(
nn.Linear(channels, 4 * channels),
nn.SiLU(),
nn.Linear(4 * channels, channels),
Rearrange("b c -> b c () ()"),
)

def forward(self, x: Tensor) -> Tensor:
h = x[:, None] * self.W[None, :] * 2 * torch.pi
Expand All @@ -235,12 +243,14 @@ def __init__(
sampling_method,
image_size,
G_ngf,
opt=None,
):
super().__init__()

self.cm_model = cm_model
self.sampling_method = sampling_method
self.image_size = image_size
self.opt = opt

self.sigma_min = 0.002
self.sigma_max = 80.0
Expand All @@ -254,7 +264,7 @@ def __init__(
self.lognormal_std = 2.0

self.cond_embed_dim = self.cm_model.cond_embed_dim
self.cm_cond_embed = NoiseLevelEmbedding(self.cond_embed_dim)
self.cm_cond_embed = NoiseLevelEmbedding(self.cond_embed_dim, self.opt)

self.current_t = 2 # default value, set from cm_model upon resume

Expand All @@ -273,7 +283,13 @@ def cm_forward(self, x, sigma, sigma_data, sigma_min, x_cond=None):
else:
x_with_cond = torch.cat([x_cond, x], dim=2)
else:
x_with_cond = x
if len(x.shape) != 5:
x_with_cond = x
elif self.opt.G_netG == "unet_vid" and self.opt.alg_diffusion_ddpm_cm_ft:
x_with_cond = torch.cat([x, x], dim=2)
else:
x_with_cond = x

return c_skip * x + c_out * self.cm_model(
x_with_cond, embed_noise_level
) # , **kwargs)
Expand Down