@@ -42,22 +42,38 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses)
         self.conditioning = conditioning
         self.cond_embed_dim = cond_embed_dim
         self.ref_embed_net = ref_embed_net
+        self.cond_embed_gammas = cond_embed_dim
 
         # Label embedding
         if "class" in conditioning:
-            cond_embed_class = cond_embed_dim // 2
-            self.netl_embedder_class = LabelEmbedder(
-                nclasses,
-                cond_embed_class,  # * image_size * image_size
-            )
-            nn.init.normal_(self.netl_embedder_class.embedding_table.weight, std=0.02)
+            if type(nclasses) == list:
+                # TODO this is arbitrary, half for class & half for detector
+                cond_embed_class = cond_embed_dim // (len(nclasses) + 1)
+                self.netl_embedders_class = nn.ModuleList(
+                    [LabelEmbedder(nc, cond_embed_class) for nc in nclasses]
+                )
+                for embed in self.netl_embedders_class:
+                    self.cond_embed_gammas -= cond_embed_class
+                    nn.init.normal_(embed.embedding_table.weight, std=0.02)
+            else:
+                # TODO this can be included in the general case
+                cond_embed_class = cond_embed_dim // 2
+                self.netl_embedder_class = LabelEmbedder(
+                    nclasses,
+                    cond_embed_class,  # * image_size * image_size
+                )
+                self.cond_embed_gammas -= cond_embed_class
+                nn.init.normal_(
+                    self.netl_embedder_class.embedding_table.weight, std=0.02
+                )
 
         if "mask" in conditioning:
             cond_embed_mask = cond_embed_dim
             self.netl_embedder_mask = LabelEmbedder(
                 nclasses,
                 cond_embed_mask,  # * image_size * image_size
             )
+            self.cond_embed_gammas -= cond_embed_class
             nn.init.normal_(self.netl_embedder_mask.embedding_table.weight, std=0.02)
 
         # Instantiate model
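In the new multi-label branch, `cond_embed_dim` is split evenly across `len(nclasses) + 1` consumers: one slice per label embedder, with the remainder tracked in `self.cond_embed_gammas` for the noise-level (gamma) embedding. A minimal sketch of that bookkeeping, with hypothetical values (`cond_embed_dim=256`, two label types):

```python
# Sketch of the dimension bookkeeping above; all values are hypothetical.
cond_embed_dim = 256
nclasses = [10, 3]  # e.g. one class vocabulary and one detector vocabulary

cond_embed_class = cond_embed_dim // (len(nclasses) + 1)  # 256 // 3 == 85
cond_embed_gammas = cond_embed_dim
for _ in nclasses:
    cond_embed_gammas -= cond_embed_class  # one slice claimed per embedder

# Two label embeddings (85 dims each) leave 86 dims for the
# noise-level (gamma) embedding: 256 - 2 * 85 == 86.
assert cond_embed_gammas == cond_embed_dim - len(nclasses) * cond_embed_class
```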
@@ -90,6 +106,7 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses)
             self.emb_layers = nn.Sequential(
                 torch.nn.SiLU(), nn.Linear(ref_embed_dim, cond_embed_class)
             )
+            self.cond_embed_gammas -= cond_embed_class
 
     def forward(self, input, embed_noise_level, cls, mask, ref):
         cls_embed, mask_embed, ref_embed = self.compute_cond(input, cls, mask, ref)
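The reference branch follows the same pattern: a SiLU + Linear head projects the reference feature (of width `ref_embed_dim`, presumably produced earlier in `__init__` from `ref_embed_net`) down to a `cond_embed_class`-sized slice, which is likewise deducted from `cond_embed_gammas`. A self-contained sketch with hypothetical widths:

```python
import torch
import torch.nn as nn

# Hypothetical widths; ref_embed_dim would come from the reference encoder.
ref_embed_dim, cond_embed_class = 512, 85

emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(ref_embed_dim, cond_embed_class))

ref_feat = torch.randn(4, ref_embed_dim)  # pooled reference features, batch of 4
ref_embed = emb_layers(ref_feat)          # -> shape (4, 85)
print(ref_embed.shape)
```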
@@ -114,7 +131,14 @@ def forward(self, input, embed_noise_level, cls, mask, ref):
 
     def compute_cond(self, input, cls, mask, ref):
         if "class" in self.conditioning and cls is not None:
-            cls_embed = self.netl_embedder_class(cls)
+            if hasattr(self, "netl_embedders_class"):
+                cls_embed = []
+                for i in range(len(self.netl_embedders_class)):
+                    cls_embed.append(self.netl_embedders_class[i](cls[:, i]))
+                cls_embed = torch.cat(cls_embed, dim=1)
+            else:
+                # TODO general case
+                cls_embed = self.netl_embedder_class(cls)
         else:
             cls_embed = None
 
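In the multi-label path, `cls` is expected to carry one column per label type, so each embedder consumes its own column and the per-type embeddings are concatenated along the feature dimension. A self-contained sketch of that lookup-and-concat step (the embedder here is a stand-in built from `nn.Embedding`, not the project's `LabelEmbedder`):

```python
import torch
import torch.nn as nn

nclasses = [10, 3]      # hypothetical label vocabularies
cond_embed_class = 85

# Stand-in for LabelEmbedder: a plain embedding table per label type.
embedders = nn.ModuleList(
    [nn.Embedding(nc, cond_embed_class) for nc in nclasses]
)

cls = torch.tensor([[2, 1], [7, 0]])  # batch of 2, one column per label type
cls_embed = torch.cat(
    [embedders[i](cls[:, i]) for i in range(len(embedders))], dim=1
)
print(cls_embed.shape)  # torch.Size([2, 170]) == (batch, len(nclasses) * cond_embed_class)
```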