Skip to content

Commit 0648ba7

Browse files
Wojtek KowalukWojtek Kowaluk
Wojtek Kowaluk
authored and
Wojtek Kowaluk
committed
Fixes to run on CPU and MPS
1 parent a4354c0 commit 0648ba7

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

kandinsky2/kandinsky2_1_model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def __init__(
3030
):
3131
self.config = config
3232
self.device = device
33+
if not torch.has_cuda:
34+
self.config["model_config"]["use_fp16"] = False
3335
self.use_fp16 = self.config["model_config"]["use_fp16"]
3436
self.task_type = task_type
3537
self.clip_image_size = config["clip_image_size"]
@@ -54,7 +56,7 @@ def __init__(
5456
clip_mean,
5557
clip_std,
5658
)
57-
self.prior.load_state_dict(torch.load(prior_path), strict=False)
59+
self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False)
5860
if self.use_fp16:
5961
self.prior = self.prior.half()
6062
self.text_encoder = TextEncoder(**self.config["text_enc_params"])
@@ -88,7 +90,7 @@ def __init__(
8890

8991
self.config["model_config"]["cache_text_emb"] = True
9092
self.model = create_model(**self.config["model_config"])
91-
self.model.load_state_dict(torch.load(model_path))
93+
self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
9294
if self.use_fp16:
9395
self.model.convert_to_fp16()
9496
self.image_encoder = self.image_encoder.half()

kandinsky2/model/gaussian_diffusion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
822822
dimension equal to the length of timesteps.
823823
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
824824
"""
825-
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
825+
res = th.from_numpy(arr).to(dtype=th.float32).to(device=timesteps.device)[timesteps]
826826
while len(res.shape) < len(broadcast_shape):
827827
res = res[..., None]
828828
return res.expand(broadcast_shape)

0 commit comments

Comments
 (0)