@@ -366,6 +366,9 @@ def stable_diffusion_xl(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        cache_dir (str): Directory to cache local files in. Default: `'/tmp/hf_files'`.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
+
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
 
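For context, a minimal sketch of how these two new kwargs behave with Hugging Face loaders (the checkpoint name and paths below are illustrative, not taken from this PR): the first call may hit the network and populate `cache_dir`, while `local_files_only=True` makes the same call resolve purely from the cache and fail fast if files are missing.

```python
# Illustrative sketch (not from this PR): cache_dir + local_files_only with a
# standard Hugging Face loader.
from transformers import AutoTokenizer

cache_dir = '/tmp/hf_files'  # matches the documented default above
name = 'stabilityai/stable-diffusion-xl-base-1.0'

# Online: downloads (or reuses) files under cache_dir
tokenizer = AutoTokenizer.from_pretrained(name, subfolder='tokenizer', cache_dir=cache_dir)
# Offline: never touches the network; raises if the files are not cached
offline_tokenizer = AutoTokenizer.from_pretrained(name,
                                                  subfolder='tokenizer',
                                                  cache_dir=cache_dir,
                                                  local_files_only=True)
```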
@@ -379,7 +382,9 @@ def stable_diffusion_xl(
     val_metrics = [MeanSquaredError()]
 
     # Make the tokenizer and text encoder
-    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names, cache_dir=cache_dir, local_files_only=local_files_only)
+    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names,
+                               cache_dir=cache_dir,
+                               local_files_only=local_files_only)
     text_encoder = MultiTextEncoder(model_names=text_encoder_names,
                                     encode_latents_in_fp16=encode_latents_in_fp16,
                                     pretrained_sdxl=pretrained,
@@ -412,9 +417,15 @@ def stable_diffusion_xl(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
 
     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]
     if pretrained:
-        unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)
+        unet = UNet2DConditionModel.from_pretrained(unet_model_name,
+                                                    subfolder='unet',
+                                                    cache_dir=cache_dir,
+                                                    local_files_only=local_files_only)
         if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
             raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
     else:
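The `pretrained` branch above loads released SDXL weights, while the `else` branch (truncated in this hunk) builds a UNet from the same architecture config. A hedged sketch of that pattern using the diffusers API; the actual `else` body in this repo may differ:

```python
# Hedged sketch of the pretrained-vs-scratch split, assuming diffusers' API.
# from_config builds the architecture with freshly initialized weights;
# from_pretrained additionally loads the released checkpoint.
from diffusers import UNet2DConditionModel
from transformers import PretrainedConfig

unet_model_name = 'stabilityai/stable-diffusion-xl-base-1.0'
# get_config_dict returns (config_dict, unused_kwargs); [0] keeps the dict
unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]

pretrained = False  # assumption: toggled by the function argument
if pretrained:
    unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
else:
    unet = UNet2DConditionModel.from_config(unet_config)  # random init, same shapes
```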
@@ -667,6 +678,7 @@ def precomputed_text_latent_diffusion(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
 
@@ -700,7 +712,10 @@ def precomputed_text_latent_diffusion(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
 
     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]
 
     if isinstance(vae, AutoEncoder):
         # Adapt the unet config to account for differing number of latent channels if necessary
@@ -797,7 +812,9 @@ def precomputed_text_latent_diffusion(
     if include_text_encoders:
         dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
         dtype = dtype_map[text_encoder_dtype]
-        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=local_files_only)
+        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl',
+                                                     cache_dir=cache_dir,
+                                                     local_files_only=local_files_only)
         clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                        subfolder='tokenizer',
                                                        cache_dir=cache_dir,
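A small sketch of the dtype handling in this last hunk: the string from the config is mapped to a `torch.dtype`, which can then be passed as `torch_dtype` when the encoder weights are loaded. The encoder-loading call below is assumed from standard transformers usage, not copied from this diff:

```python
# Sketch only: map a config string to a torch dtype and load the T5 encoder
# in that precision. torch_dtype= is standard transformers from_pretrained usage.
import torch
from transformers import T5EncoderModel

dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
text_encoder_dtype = 'bfloat16'  # assumption: supplied via the function's config
t5_encoder = T5EncoderModel.from_pretrained('google/t5-v1_1-xxl',
                                            cache_dir='/tmp/hf_files',
                                            local_files_only=False,
                                            torch_dtype=dtype_map[text_encoder_dtype])
```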