Skip to content

Commit 021f12a

Browse files
committed
Un-hardcode "cuda" as default device name
Allow configuring with `SGM_DEFAULT_DEVICE`
1 parent 059d8e9 commit 021f12a

7 files changed

Lines changed: 107 additions & 40 deletions

File tree

sgm/inference/api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
EulerEDMSampler,
1414
HeunEDMSampler,
1515
LinearMultistepSampler)
16-
from sgm.util import load_model_from_config
16+
from sgm.util import load_model_from_config, get_default_device_name
1717

1818

1919
class ModelArchitecture(str, Enum):
@@ -158,7 +158,7 @@ def __init__(
158158
model_id: ModelArchitecture,
159159
model_path="checkpoints",
160160
config_path="configs/inference",
161-
device="cuda",
161+
device: Optional[str] = None,
162162
use_fp16=True,
163163
) -> None:
164164
if model_id not in model_specs:
@@ -167,10 +167,10 @@ def __init__(
167167
self.specs = model_specs[self.model_id]
168168
self.config = str(pathlib.Path(config_path, self.specs.config))
169169
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
170-
self.device = device
170+
self.device = device or get_default_device_name()
171171
self.model = self._load_model(device=device, use_fp16=use_fp16)
172172

173-
def _load_model(self, device="cuda", use_fp16=True):
173+
def _load_model(self, *, device, use_fp16=True):
174174
config = OmegaConf.load(self.config)
175175
model = load_model_from_config(config, self.ckpt)
176176
if model is None:

sgm/inference/helpers.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44

55
import numpy as np
66
import torch
7+
from PIL import Image
78
from einops import rearrange
89
from imwatermark import WatermarkEncoder
910
from omegaconf import ListConfig
10-
from PIL import Image
11-
from torch import autocast
1211

13-
from sgm.util import append_dims
12+
from sgm.util import append_dims, safe_autocast, get_default_device_name
1413

1514

1615
class WatermarkEmbedder:
@@ -111,21 +110,24 @@ def do_sample(
111110
batch2model_input: Optional[List] = None,
112111
return_latents=False,
113112
filter=None,
114-
device="cuda",
113+
device: Optional[str] = None,
115114
):
115+
if not device:
116+
device = get_default_device_name()
116117
if force_uc_zero_embeddings is None:
117118
force_uc_zero_embeddings = []
118119
if batch2model_input is None:
119120
batch2model_input = []
120121

121122
with torch.no_grad():
122-
with autocast(device) as precision_scope:
123+
with safe_autocast(device):
123124
with model.ema_scope():
124125
num_samples = [num_samples]
125126
batch, batch_uc = get_batch(
126127
get_unique_embedder_keys_from_conditioner(model.conditioner),
127128
value_dict,
128129
num_samples,
130+
device=device,
129131
)
130132
for key in batch:
131133
if isinstance(batch[key], torch.Tensor):
@@ -170,7 +172,13 @@ def denoiser(input, sigma, c):
170172
return samples
171173

172174

173-
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
175+
def get_batch(
176+
keys,
177+
value_dict,
178+
N: Union[List, ListConfig],
179+
*,
180+
device: str,
181+
):
174182
# Hardcoded demo setups; might undergo some changes in the future
175183

176184
batch = {}
@@ -227,7 +235,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
227235
return batch, batch_uc
228236

229237

230-
def get_input_image_tensor(image: Image.Image, device="cuda"):
238+
def get_input_image_tensor(image: Image.Image, device: Optional[str] = None):
239+
if not device:
240+
device = get_default_device_name()
231241
w, h = image.size
232242
print(f"loaded input image of size ({w}, {h})")
233243
width, height = map(
@@ -252,15 +262,18 @@ def do_img2img(
252262
return_latents=False,
253263
skip_encode=False,
254264
filter=None,
255-
device="cuda",
265+
device: Optional[str] = None,
256266
):
267+
if not device:
268+
device = get_default_device_name()
257269
with torch.no_grad():
258-
with autocast(device) as precision_scope:
270+
with safe_autocast(device):
259271
with model.ema_scope():
260272
batch, batch_uc = get_batch(
261273
get_unique_embedder_keys_from_conditioner(model.conditioner),
262274
value_dict,
263275
[num_samples],
276+
device=device,
264277
)
265278
c, uc = model.conditioner.get_unconditional_conditioning(
266279
batch,

sgm/models/diffusion.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from contextlib import contextmanager
3-
from typing import Any, Dict, List, Optional, Tuple, Union
3+
from typing import Any, Dict, List, Tuple, Union, Optional
44

55
import pytorch_lightning as pl
66
import torch
@@ -12,8 +12,15 @@
1212
from ..modules.autoencoding.temporal_ae import VideoDecoder
1313
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
1414
from ..modules.ema import LitEma
15-
from ..util import (default, disabled_train, get_obj_from_str,
16-
instantiate_from_config, log_txt_as_img)
15+
from ..util import (
16+
default,
17+
disabled_train,
18+
get_default_device_name,
19+
get_obj_from_str,
20+
instantiate_from_config,
21+
log_txt_as_img,
22+
safe_autocast,
23+
)
1724

1825

1926
class DiffusionEngine(pl.LightningModule):
@@ -114,14 +121,20 @@ def get_input(self, batch):
114121
# image tensors should be scaled to -1 ... 1 and in bchw format
115122
return batch[self.input_key]
116123

124+
def _first_stage_autocast_context(self):
125+
return safe_autocast(
126+
device=get_default_device_name(),
127+
enabled=not self.disable_first_stage_autocast,
128+
)
129+
117130
@torch.no_grad()
118131
def decode_first_stage(self, z):
119132
z = 1.0 / self.scale_factor * z
120133
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
121134

122135
n_rounds = math.ceil(z.shape[0] / n_samples)
123136
all_out = []
124-
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
137+
with self._first_stage_autocast_context():
125138
for n in range(n_rounds):
126139
if isinstance(self.first_stage_model.decoder, VideoDecoder):
127140
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
@@ -139,7 +152,7 @@ def encode_first_stage(self, x):
139152
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
140153
n_rounds = math.ceil(x.shape[0] / n_samples)
141154
all_out = []
142-
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
155+
with self._first_stage_autocast_context():
143156
for n in range(n_rounds):
144157
out = self.first_stage_model.encode(
145158
x[n * n_samples : (n + 1) * n_samples]

sgm/modules/diffusionmodules/openaimodel.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,15 @@
1010
from torch.utils.checkpoint import checkpoint
1111

1212
from ...modules.attention import SpatialTransformer
13-
from ...modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear,
14-
normalization,
15-
timestep_embedding, zero_module)
13+
from ...modules.diffusionmodules.util import (
14+
avg_pool_nd,
15+
checkpoint,
16+
conv_nd,
17+
linear,
18+
normalization,
19+
timestep_embedding,
20+
zero_module,
21+
)
1622
from ...modules.video_attention import SpatialVideoTransformer
1723
from ...util import exists
1824

sgm/modules/diffusionmodules/sampling.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
from omegaconf import ListConfig, OmegaConf
1010
from tqdm import tqdm
1111

12-
from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,
13-
linear_multistep_coeff,
14-
to_d, to_neg_log_sigma,
15-
to_sigma)
16-
from ...util import append_dims, default, instantiate_from_config
12+
from ...modules.diffusionmodules.sampling_utils import (
13+
get_ancestral_step,
14+
linear_multistep_coeff,
15+
to_d,
16+
to_neg_log_sigma,
17+
to_sigma,
18+
)
19+
from ...util import append_dims, default, instantiate_from_config, get_default_device_name
1720

1821
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
1922

@@ -25,8 +28,10 @@ def __init__(
2528
num_steps: Union[int, None] = None,
2629
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
2730
verbose: bool = False,
28-
device: str = "cuda",
31+
device: Union[str, None] = None,
2932
):
33+
if device is None:
34+
device = get_default_device_name()
3035
self.num_steps = num_steps
3136
self.discretization = instantiate_from_config(discretization_config)
3237
self.guider = instantiate_from_config(

sgm/modules/encoders/modules.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,17 @@
2020
from ...modules.diffusionmodules.util import (extract_into_tensor,
2121
make_beta_schedule)
2222
from ...modules.distributions.distributions import DiagonalGaussianDistribution
23-
from ...util import (append_dims, autocast, count_params, default,
24-
disabled_train, expand_dims_like, instantiate_from_config)
23+
from ...util import (
24+
append_dims,
25+
autocast,
26+
count_params,
27+
default,
28+
disabled_train,
29+
expand_dims_like,
30+
get_default_device_name,
31+
instantiate_from_config,
32+
safe_autocast,
33+
)
2534

2635

2736
class AbstractEmbModel(nn.Module):
@@ -225,7 +234,9 @@ def forward(self, c):
225234
c = c[:, None, :]
226235
return c
227236

228-
def get_unconditional_conditioning(self, bs, device="cuda"):
237+
def get_unconditional_conditioning(self, bs, device=None):
238+
if device is None:
239+
device = get_default_device_name()
229240
uc_class = (
230241
self.n_classes - 1
231242
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
@@ -250,9 +261,10 @@ class FrozenT5Embedder(AbstractEmbModel):
250261
"""Uses the T5 transformer encoder for text"""
251262

252263
def __init__(
253-
self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
264+
self, version="google/t5-v1_1-xxl", device=None, max_length=77, freeze=True
254265
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
255266
super().__init__()
267+
device = device or get_default_device_name()
256268
self.tokenizer = T5Tokenizer.from_pretrained(version)
257269
self.transformer = T5EncoderModel.from_pretrained(version)
258270
self.device = device
@@ -277,7 +289,7 @@ def forward(self, text):
277289
return_tensors="pt",
278290
)
279291
tokens = batch_encoding["input_ids"].to(self.device)
280-
with torch.autocast("cuda", enabled=False):
292+
with safe_autocast(get_default_device_name(), enabled=False):
281293
outputs = self.transformer(input_ids=tokens)
282294
z = outputs.last_hidden_state
283295
return z
@@ -292,9 +304,10 @@ class FrozenByT5Embedder(AbstractEmbModel):
292304
"""
293305

294306
def __init__(
295-
self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
307+
self, version="google/byt5-base", device=None, max_length=77, freeze=True
296308
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
297309
super().__init__()
310+
device = device or get_default_device_name()
298311
self.tokenizer = ByT5Tokenizer.from_pretrained(version)
299312
self.transformer = T5EncoderModel.from_pretrained(version)
300313
self.device = device
@@ -319,7 +332,7 @@ def forward(self, text):
319332
return_tensors="pt",
320333
)
321334
tokens = batch_encoding["input_ids"].to(self.device)
322-
with torch.autocast("cuda", enabled=False):
335+
with safe_autocast(get_default_device_name(), enabled=False):
323336
outputs = self.transformer(input_ids=tokens)
324337
z = outputs.last_hidden_state
325338
return z
@@ -336,14 +349,15 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
336349
def __init__(
337350
self,
338351
version="openai/clip-vit-large-patch14",
339-
device="cuda",
352+
device=None,
340353
max_length=77,
341354
freeze=True,
342355
layer="last",
343356
layer_idx=None,
344357
always_return_pooled=False,
345358
): # clip-vit-base-patch32
346359
super().__init__()
360+
device = device or get_default_device_name()
347361
assert layer in self.LAYERS
348362
self.tokenizer = CLIPTokenizer.from_pretrained(version)
349363
self.transformer = CLIPTextModel.from_pretrained(version)
@@ -404,14 +418,15 @@ def __init__(
404418
self,
405419
arch="ViT-H-14",
406420
version="laion2b_s32b_b79k",
407-
device="cuda",
421+
device=None,
408422
max_length=77,
409423
freeze=True,
410424
layer="last",
411425
always_return_pooled=False,
412426
legacy=True,
413427
):
414428
super().__init__()
429+
device = device or get_default_device_name()
415430
assert layer in self.LAYERS
416431
model, _, _ = open_clip.create_model_and_transforms(
417432
arch,
@@ -506,12 +521,13 @@ def __init__(
506521
self,
507522
arch="ViT-H-14",
508523
version="laion2b_s32b_b79k",
509-
device="cuda",
524+
device=None,
510525
max_length=77,
511526
freeze=True,
512527
layer="last",
513528
):
514529
super().__init__()
530+
device = device or get_default_device_name()
515531
assert layer in self.LAYERS
516532
model, _, _ = open_clip.create_model_and_transforms(
517533
arch, device=torch.device("cpu"), pretrained=version
@@ -576,7 +592,7 @@ def __init__(
576592
self,
577593
arch="ViT-H-14",
578594
version="laion2b_s32b_b79k",
579-
device="cuda",
595+
device=None,
580596
max_length=77,
581597
freeze=True,
582598
antialias=True,
@@ -588,6 +604,7 @@ def __init__(
588604
init_device=None,
589605
):
590606
super().__init__()
607+
device = device or get_default_device_name()
591608
model, _, _ = open_clip.create_model_and_transforms(
592609
arch,
593610
device=torch.device(default(init_device, "cpu")),
@@ -733,11 +750,12 @@ def __init__(
733750
self,
734751
clip_version="openai/clip-vit-large-patch14",
735752
t5_version="google/t5-v1_1-xl",
736-
device="cuda",
753+
device=None,
737754
clip_max_length=77,
738755
t5_max_length=77,
739756
):
740757
super().__init__()
758+
device = device or get_default_device_name()
741759
self.clip_encoder = FrozenCLIPEmbedder(
742760
clip_version, device, max_length=clip_max_length
743761
)
@@ -999,7 +1017,7 @@ def forward(
9991017
noise = torch.randn_like(vid)
10001018
vid = vid + noise * append_dims(sigmas, vid.ndim)
10011019

1002-
with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
1020+
with safe_autocast(get_default_device_name(), enabled=not self.disable_encoder_autocast):
10031021
n_samples = (
10041022
self.en_and_decode_n_samples_a_time
10051023
if self.en_and_decode_n_samples_a_time is not None

0 commit comments

Comments
 (0)