2020from ...modules .diffusionmodules .util import (extract_into_tensor ,
2121 make_beta_schedule )
2222from ...modules .distributions .distributions import DiagonalGaussianDistribution
23- from ...util import (append_dims , autocast , count_params , default ,
24- disabled_train , expand_dims_like , instantiate_from_config )
23+ from ...util import (
24+ append_dims ,
25+ autocast ,
26+ count_params ,
27+ default ,
28+ disabled_train ,
29+ expand_dims_like ,
30+ get_default_device_name ,
31+ instantiate_from_config ,
32+ safe_autocast ,
33+ )
2534
2635
2736class AbstractEmbModel (nn .Module ):
@@ -225,7 +234,9 @@ def forward(self, c):
225234 c = c [:, None , :]
226235 return c
227236
228- def get_unconditional_conditioning (self , bs , device = "cuda" ):
237+ def get_unconditional_conditioning (self , bs , device = None ):
238+ if device is None :
239+ device = get_default_device_name ()
229240 uc_class = (
230241 self .n_classes - 1
231242 ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
@@ -250,9 +261,10 @@ class FrozenT5Embedder(AbstractEmbModel):
250261 """Uses the T5 transformer encoder for text"""
251262
252263 def __init__ (
253- self , version = "google/t5-v1_1-xxl" , device = "cuda" , max_length = 77 , freeze = True
264+ self , version = "google/t5-v1_1-xxl" , device = None , max_length = 77 , freeze = True
254265 ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
255266 super ().__init__ ()
267+ device = device or get_default_device_name ()
256268 self .tokenizer = T5Tokenizer .from_pretrained (version )
257269 self .transformer = T5EncoderModel .from_pretrained (version )
258270 self .device = device
@@ -277,7 +289,7 @@ def forward(self, text):
277289 return_tensors = "pt" ,
278290 )
279291 tokens = batch_encoding ["input_ids" ].to (self .device )
280- with torch . autocast ( "cuda" , enabled = False ):
292+ with safe_autocast ( get_default_device_name () , enabled = False ):
281293 outputs = self .transformer (input_ids = tokens )
282294 z = outputs .last_hidden_state
283295 return z
@@ -292,9 +304,10 @@ class FrozenByT5Embedder(AbstractEmbModel):
292304 """
293305
294306 def __init__ (
295- self , version = "google/byt5-base" , device = "cuda" , max_length = 77 , freeze = True
307+ self , version = "google/byt5-base" , device = None , max_length = 77 , freeze = True
296308 ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
297309 super ().__init__ ()
310+ device = device or get_default_device_name ()
298311 self .tokenizer = ByT5Tokenizer .from_pretrained (version )
299312 self .transformer = T5EncoderModel .from_pretrained (version )
300313 self .device = device
@@ -319,7 +332,7 @@ def forward(self, text):
319332 return_tensors = "pt" ,
320333 )
321334 tokens = batch_encoding ["input_ids" ].to (self .device )
322- with torch . autocast ( "cuda" , enabled = False ):
335+ with safe_autocast ( get_default_device_name () , enabled = False ):
323336 outputs = self .transformer (input_ids = tokens )
324337 z = outputs .last_hidden_state
325338 return z
@@ -336,14 +349,15 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
336349 def __init__ (
337350 self ,
338351 version = "openai/clip-vit-large-patch14" ,
339- device = "cuda" ,
352+ device = None ,
340353 max_length = 77 ,
341354 freeze = True ,
342355 layer = "last" ,
343356 layer_idx = None ,
344357 always_return_pooled = False ,
345358 ): # clip-vit-base-patch32
346359 super ().__init__ ()
360+ device = device or get_default_device_name ()
347361 assert layer in self .LAYERS
348362 self .tokenizer = CLIPTokenizer .from_pretrained (version )
349363 self .transformer = CLIPTextModel .from_pretrained (version )
@@ -404,14 +418,15 @@ def __init__(
404418 self ,
405419 arch = "ViT-H-14" ,
406420 version = "laion2b_s32b_b79k" ,
407- device = "cuda" ,
421+ device = None ,
408422 max_length = 77 ,
409423 freeze = True ,
410424 layer = "last" ,
411425 always_return_pooled = False ,
412426 legacy = True ,
413427 ):
414428 super ().__init__ ()
429+ device = device or get_default_device_name ()
415430 assert layer in self .LAYERS
416431 model , _ , _ = open_clip .create_model_and_transforms (
417432 arch ,
@@ -506,12 +521,13 @@ def __init__(
506521 self ,
507522 arch = "ViT-H-14" ,
508523 version = "laion2b_s32b_b79k" ,
509- device = "cuda" ,
524+ device = None ,
510525 max_length = 77 ,
511526 freeze = True ,
512527 layer = "last" ,
513528 ):
514529 super ().__init__ ()
530+ device = device or get_default_device_name ()
515531 assert layer in self .LAYERS
516532 model , _ , _ = open_clip .create_model_and_transforms (
517533 arch , device = torch .device ("cpu" ), pretrained = version
@@ -576,7 +592,7 @@ def __init__(
576592 self ,
577593 arch = "ViT-H-14" ,
578594 version = "laion2b_s32b_b79k" ,
579- device = "cuda" ,
595+ device = None ,
580596 max_length = 77 ,
581597 freeze = True ,
582598 antialias = True ,
@@ -588,6 +604,7 @@ def __init__(
588604 init_device = None ,
589605 ):
590606 super ().__init__ ()
607+ device = device or get_default_device_name ()
591608 model , _ , _ = open_clip .create_model_and_transforms (
592609 arch ,
593610 device = torch .device (default (init_device , "cpu" )),
@@ -733,11 +750,12 @@ def __init__(
733750 self ,
734751 clip_version = "openai/clip-vit-large-patch14" ,
735752 t5_version = "google/t5-v1_1-xl" ,
736- device = "cuda" ,
753+ device = None ,
737754 clip_max_length = 77 ,
738755 t5_max_length = 77 ,
739756 ):
740757 super ().__init__ ()
758+ device = device or get_default_device_name ()
741759 self .clip_encoder = FrozenCLIPEmbedder (
742760 clip_version , device , max_length = clip_max_length
743761 )
@@ -999,7 +1017,7 @@ def forward(
9991017 noise = torch .randn_like (vid )
10001018 vid = vid + noise * append_dims (sigmas , vid .ndim )
10011019
1002- with torch . autocast ( "cuda" , enabled = not self .disable_encoder_autocast ):
1020+ with safe_autocast ( get_default_device_name () , enabled = not self .disable_encoder_autocast ):
10031021 n_samples = (
10041022 self .en_and_decode_n_samples_a_time
10051023 if self .en_and_decode_n_samples_a_time is not None
0 commit comments