
Commit e3b9999

Small fixes
1 parent b343daa commit e3b9999

1 file changed (+29, -33)

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/stable_diffusion_evaluator.py

@@ -65,7 +65,7 @@ def __init__(self, network_info, launcher, models_args, delayed_model_loading=Fa
     def create_pipeline(self, launcher, netowrk_info=None):
         tokenizer_config = self.config.get("tokenizer_id", "openai/clip-vit-large-patch14")
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_config)
-        scheduler_config = self.config.get("sheduler_config", {})
+        scheduler_config = self.config.get("scheduler_config", {})
         scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
         netowrk_info = netowrk_info or self.network_info
         self.pipe = OVStableDiffusionPipeline(
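
Note: this hunk fixes a typo in the config key ("sheduler_config" -> "scheduler_config"), so a correctly spelled key in an evaluator config is actually read. A minimal sketch exercising the corrected key; the scheduler values below are hypothetical and not taken from the commit:

    from diffusers import LMSDiscreteScheduler

    config = {
        "tokenizer_id": "openai/clip-vit-large-patch14",
        # With the fix this correctly spelled key is now read; before it,
        # only the misspelled "sheduler_config" was consulted.
        "scheduler_config": {"beta_start": 0.00085, "beta_end": 0.012, "num_train_timesteps": 1000},
    }
    scheduler = LMSDiscreteScheduler.from_config(config.get("scheduler_config", {}))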
@@ -164,7 +164,7 @@ def __init__(
         self,
         launcher: "BaseLauncher",  # noqa: F821
         tokenizer: "CLIPTokenizer",  # noqa: F821
-        scheduler: Union["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"],  # noqa: F821
+        scheduler: Union["LMSDiscreteScheduler"],  # noqa: F821
         model_info: Dict,
         seed = None,
         num_inference_steps = 50
@@ -216,30 +216,6 @@ def reset_compiled_models(self):
         self.vae_decoder = None
         self.vae_encoder = None

-    def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
-        """
-        see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-        Args:
-            timesteps: torch.Tensor: generate embedding vectors at these timesteps
-            embedding_dim: int: dimension of the embeddings to generate
-            dtype: data type of the generated embeddings
-        Returns:
-            embedding vectors with shape `(len(timesteps), embedding_dim)`
-        """
-        assert len(w.shape) == 1
-        w = w * 1000.0
-
-        half_dim = embedding_dim // 2
-        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
-        emb = w.to(dtype)[:, None] * emb[None, :]
-        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-        if embedding_dim % 2 == 1:  # zero pad
-            emb = torch.nn.functional.pad(emb, (0, 1))
-        assert emb.shape == (w.shape[0], embedding_dim)
-        return emb
-
-
     def __call__(
         self,
         prompt: Union[str, List[str]],
@@ -291,16 +267,12 @@ def __call__(
                 latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                inputs = {
-                    "sample": latent_model_input,
-                    "timestep": np.array(t, dtype=np.float32),
-                    "encoder_hidden_states": text_embeddings,
-                }
+                inputs = [ latent_model_input, np.array(t, dtype=np.float32), text_embeddings ]
                 if is_extra_input:
-                    inputs["timestep_cond"] = w_embedding
+                    inputs.append(w_embedding)

                 # predict the noise residual
-                noise_pred = self.unet([v for v in inputs.values()])[self._unet_output]
+                noise_pred = self.unet(inputs)[self._unet_output]
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
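
Note: this hunk swaps the named-input dict for a positional list, so the UNet launcher now receives inputs by position rather than by name. A sketch of the assumed mapping; the ordering below simply mirrors the removed dict keys and is an assumption, not stated in the commit:

    inputs = [
        latent_model_input,              # was inputs["sample"]
        np.array(t, dtype=np.float32),   # was inputs["timestep"]
        text_embeddings,                 # was inputs["encoder_hidden_states"]
    ]
    if is_extra_input:
        inputs.append(w_embedding)       # was inputs["timestep_cond"]
    # the list is assumed to match the compiled model's declared input order
    noise_pred = self.unet(inputs)[self._unet_output]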
@@ -493,3 +465,27 @@ def print_input_output_info(self):
             model = getattr(self, part_model_id, None)
             if model is not None:
                 self.launcher.print_input_output_info(model, part)
+
+    @staticmethod
+    def get_w_embedding(w, embedding_dim=512, dtype=torch.float32):
+        """
+        see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+        Args:
+            timesteps: torch.Tensor: generate embedding vectors at these timesteps
+            embedding_dim: int: dimension of the embeddings to generate
+            dtype: data type of the generated embeddings
+        Returns:
+            embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
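
Note: get_w_embedding is moved to the end of the class and converted to a @staticmethod, dropping the unused self; the body is otherwise unchanged. A quick standalone sanity check, as a sketch, assuming the enclosing class is OVStableDiffusionPipeline as the create_pipeline hunk suggests:

    import torch

    w = torch.tensor([1.0, 3.5, 7.5])  # 1-D tensor of guidance weights
    emb = OVStableDiffusionPipeline.get_w_embedding(w, embedding_dim=512)
    assert emb.shape == (3, 512)       # (len(w), embedding_dim), per the docstring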
