Skip to content

Commit d91a3c6

Browse files
committed
feat: add multiple classes conditioning
1 parent ab18fe1 commit d91a3c6

File tree

6 files changed

+57
-20
lines changed

6 files changed

+57
-20
lines changed

models/modules/diffusion_generator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,7 @@ def __init__(
6363
self.cond_embed_gammas_in = inner_channel
6464
else:
6565
self.cond_embed_dim = cond_embed_dim
66-
67-
if any(cond in self.denoise_fn.conditioning for cond in ["class", "ref"]):
68-
self.cond_embed_gammas = self.cond_embed_dim // 2
69-
else:
70-
self.cond_embed_gammas = self.cond_embed_dim
66+
self.cond_embed_gammas = self.denoise_fn.cond_embed_gammas
7167

7268
self.cond_embed = nn.Sequential(
7369
nn.Linear(self.cond_embed_gammas, self.cond_embed_gammas),

models/modules/palette_denoise_fn.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,38 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses)
4242
self.conditioning = conditioning
4343
self.cond_embed_dim = cond_embed_dim
4444
self.ref_embed_net = ref_embed_net
45+
self.cond_embed_gammas = cond_embed_dim
4546

4647
# Label embedding
4748
if "class" in conditioning:
48-
cond_embed_class = cond_embed_dim // 2
49-
self.netl_embedder_class = LabelEmbedder(
50-
nclasses,
51-
cond_embed_class, # * image_size * image_size
52-
)
53-
nn.init.normal_(self.netl_embedder_class.embedding_table.weight, std=0.02)
49+
if type(nclasses) == list:
50+
# TODO this is arbitrary, half for class & half for detector
51+
cond_embed_class = cond_embed_dim // (len(nclasses) + 1)
52+
self.netl_embedders_class = nn.ModuleList(
53+
[LabelEmbedder(nc, cond_embed_class) for nc in nclasses]
54+
)
55+
for embed in self.netl_embedders_class:
56+
self.cond_embed_gammas -= cond_embed_class
57+
nn.init.normal_(embed.embedding_table.weight, std=0.02)
58+
else:
59+
# TODO this can be included in the general case
60+
cond_embed_class = cond_embed_dim // 2
61+
self.netl_embedder_class = LabelEmbedder(
62+
nclasses,
63+
cond_embed_class, # * image_size * image_size
64+
)
65+
self.cond_embed_gammas -= cond_embed_class
66+
nn.init.normal_(
67+
self.netl_embedder_class.embedding_table.weight, std=0.02
68+
)
5469

5570
if "mask" in conditioning:
5671
cond_embed_mask = cond_embed_dim
5772
self.netl_embedder_mask = LabelEmbedder(
5873
nclasses,
5974
cond_embed_mask, # * image_size * image_size
6075
)
76+
self.cond_embed_gammas -= cond_embed_class
6177
nn.init.normal_(self.netl_embedder_mask.embedding_table.weight, std=0.02)
6278

6379
# Instantiate model
@@ -90,6 +106,7 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses)
90106
self.emb_layers = nn.Sequential(
91107
torch.nn.SiLU(), nn.Linear(ref_embed_dim, cond_embed_class)
92108
)
109+
self.cond_embed_gammas -= cond_embed_class
93110

94111
def forward(self, input, embed_noise_level, cls, mask, ref):
95112
cls_embed, mask_embed, ref_embed = self.compute_cond(input, cls, mask, ref)
@@ -114,7 +131,14 @@ def forward(self, input, embed_noise_level, cls, mask, ref):
114131

115132
def compute_cond(self, input, cls, mask, ref):
116133
if "class" in self.conditioning and cls is not None:
117-
cls_embed = self.netl_embedder_class(cls)
134+
if hasattr(self, "netl_embedders_class"):
135+
cls_embed = []
136+
for i in range(len(self.netl_embedders_class)):
137+
cls_embed.append(self.netl_embedders_class[i](cls[:, i]))
138+
cls_embed = torch.cat(cls_embed, dim=1)
139+
else:
140+
# TODO general case
141+
cls_embed = self.netl_embedder_class(cls)
118142
else:
119143
cls_embed = None
120144

models/palette_model.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,11 @@ def __init__(self, opt, rank):
100100

101101
max_visual_outputs = max(self.opt.train_batch_size, self.opt.num_test_images)
102102

103-
self.num_classes = max(
104-
self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses
105-
)
103+
# self.num_classes = max(
104+
# self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses
105+
# )
106+
# TODO decide if we keep cls_semantic_nclasses (not used atm)
107+
self.num_classes = self.opt.f_s_semantic_nclasses
106108

107109
self.use_ref = (
108110
self.opt.alg_diffusion_cond_image_creation == "ref"
@@ -583,10 +585,18 @@ def inference(self, nb_imgs, offset=0):
583585

584586
# task: super resolution, pix2pix
585587
elif self.task in ["super_resolution", "pix2pix"]:
588+
cls = None
589+
590+
if "class" in self.opt.alg_diffusion_cond_embed:
591+
cls = []
592+
for i in self.num_classes:
593+
cls.append(torch.randint_like(self.cls[:, 0], 0, i))
594+
cls = torch.stack(cls, dim=1)
595+
586596
self.output, self.visuals = netG.restoration(
587597
y_cond=self.cond_image[:nb_imgs],
588598
sample_num=self.sample_num,
589-
cls=None,
599+
cls=cls,
590600
)
591601
self.fake_B = self.output
592602

options/common_options.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,10 @@ def initialize(self, parser):
509509
)
510510
parser.add_argument(
511511
"--f_s_semantic_nclasses",
512-
default=2,
512+
default=[2],
513+
nargs="+",
513514
type=int,
514-
help="number of classes of the semantic loss classifier",
515+
help="number of classes of the semantic loss classifiers",
515516
)
516517
parser.add_argument(
517518
"--f_s_class_weights",

options/inference_diffusion_options.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def initialize(self, parser):
100100
parser.add_argument(
101101
"--cls",
102102
type=int,
103-
default=-1,
103+
nargs="+",
104+
default=[-1],
104105
help="override input bbox classe for generation",
105106
)
106107

scripts/gen_single_image_diffusion.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,12 @@ def generate(
606606

607607
if opt.model_type == "palette":
608608
if "class" in model.denoise_fn.conditioning:
609-
cls_tensor = torch.ones(1, dtype=torch.int64, device=device) * cls
609+
if len(cls_value) > 1:
610+
cls_tensor = torch.tensor(
611+
cls_value, dtype=torch.int64, device=device
612+
).unsqueeze(0)
613+
else:
614+
cls_tensor = torch.ones(1, dtype=torch.int64, device=device) * cls_value
610615
else:
611616
cls_tensor = None
612617
if ref is not None:

0 commit comments

Comments
 (0)