@@ -65,7 +65,7 @@ def __init__(self, network_info, launcher, models_args, delayed_model_loading=Fa
65
65
def create_pipeline (self , launcher , netowrk_info = None ):
66
66
tokenizer_config = self .config .get ("tokenizer_id" , "openai/clip-vit-large-patch14" )
67
67
tokenizer = AutoTokenizer .from_pretrained (tokenizer_config )
68
- scheduler_config = self .config .get ("sheduler_config " , {})
68
+ scheduler_config = self .config .get ("scheduler_config " , {})
69
69
scheduler = LMSDiscreteScheduler .from_config (scheduler_config )
70
70
netowrk_info = netowrk_info or self .network_info
71
71
self .pipe = OVStableDiffusionPipeline (
@@ -164,7 +164,7 @@ def __init__(
164
164
self ,
165
165
launcher : "BaseLauncher" , # noqa: F821
166
166
tokenizer : "CLIPTokenizer" , # noqa: F821
167
- scheduler : Union ["DDIMScheduler" , "PNDMScheduler" , " LMSDiscreteScheduler" ], # noqa: F821
167
+ scheduler : Union ["LMSDiscreteScheduler" ], # noqa: F821
168
168
model_info : Dict ,
169
169
seed = None ,
170
170
num_inference_steps = 50
@@ -216,30 +216,6 @@ def reset_compiled_models(self):
216
216
self .vae_decoder = None
217
217
self .vae_encoder = None
218
218
219
- def get_w_embedding (self , w , embedding_dim = 512 , dtype = torch .float32 ):
220
- """
221
- see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
222
- Args:
223
- timesteps: torch.Tensor: generate embedding vectors at these timesteps
224
- embedding_dim: int: dimension of the embeddings to generate
225
- dtype: data type of the generated embeddings
226
- Returns:
227
- embedding vectors with shape `(len(timesteps), embedding_dim)`
228
- """
229
- assert len (w .shape ) == 1
230
- w = w * 1000.0
231
-
232
- half_dim = embedding_dim // 2
233
- emb = torch .log (torch .tensor (10000.0 )) / (half_dim - 1 )
234
- emb = torch .exp (torch .arange (half_dim , dtype = dtype ) * - emb )
235
- emb = w .to (dtype )[:, None ] * emb [None , :]
236
- emb = torch .cat ([torch .sin (emb ), torch .cos (emb )], dim = 1 )
237
- if embedding_dim % 2 == 1 : # zero pad
238
- emb = torch .nn .functional .pad (emb , (0 , 1 ))
239
- assert emb .shape == (w .shape [0 ], embedding_dim )
240
- return emb
241
-
242
-
243
219
def __call__ (
244
220
self ,
245
221
prompt : Union [str , List [str ]],
@@ -291,16 +267,12 @@ def __call__(
291
267
latent_model_input = np .concatenate ([latents ] * 2 ) if do_classifier_free_guidance else latents
292
268
latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
293
269
294
- inputs = {
295
- "sample" : latent_model_input ,
296
- "timestep" : np .array (t , dtype = np .float32 ),
297
- "encoder_hidden_states" : text_embeddings ,
298
- }
270
+ inputs = [ latent_model_input , np .array (t , dtype = np .float32 ), text_embeddings ]
299
271
if is_extra_input :
300
- inputs [ "timestep_cond" ] = w_embedding
272
+ inputs . append ( w_embedding )
301
273
302
274
# predict the noise residual
303
- noise_pred = self .unet ([ v for v in inputs . values ()] )[self ._unet_output ]
275
+ noise_pred = self .unet (inputs )[self ._unet_output ]
304
276
# perform guidance
305
277
if do_classifier_free_guidance :
306
278
noise_pred_uncond , noise_pred_text = noise_pred [0 ], noise_pred [1 ]
@@ -493,3 +465,27 @@ def print_input_output_info(self):
493
465
model = getattr (self , part_model_id , None )
494
466
if model is not None :
495
467
self .launcher .print_input_output_info (model , part )
468
+
469
+ @staticmethod
470
+ def get_w_embedding (w , embedding_dim = 512 , dtype = torch .float32 ):
471
+ """
472
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
473
+ Args:
474
+ timesteps: torch.Tensor: generate embedding vectors at these timesteps
475
+ embedding_dim: int: dimension of the embeddings to generate
476
+ dtype: data type of the generated embeddings
477
+ Returns:
478
+ embedding vectors with shape `(len(timesteps), embedding_dim)`
479
+ """
480
+ assert len (w .shape ) == 1
481
+ w = w * 1000.0
482
+
483
+ half_dim = embedding_dim // 2
484
+ emb = torch .log (torch .tensor (10000.0 )) / (half_dim - 1 )
485
+ emb = torch .exp (torch .arange (half_dim , dtype = dtype ) * - emb )
486
+ emb = w .to (dtype )[:, None ] * emb [None , :]
487
+ emb = torch .cat ([torch .sin (emb ), torch .cos (emb )], dim = 1 )
488
+ if embedding_dim % 2 == 1 : # zero pad
489
+ emb = torch .nn .functional .pad (emb , (0 , 1 ))
490
+ assert emb .shape == (w .shape [0 ], embedding_dim )
491
+ return emb
0 commit comments