@@ -307,6 +307,8 @@ def stable_diffusion_xl(
307
307
use_xformers : bool = True ,
308
308
lora_rank : Optional [int ] = None ,
309
309
lora_alpha : Optional [int ] = None ,
310
+ cache_dir : str = '/tmp/hf_files' ,
311
+ local_files_only : bool = False ,
310
312
):
311
313
"""Stable diffusion 2 training setup + SDXL UNet and VAE.
312
314
@@ -377,10 +379,12 @@ def stable_diffusion_xl(
377
379
val_metrics = [MeanSquaredError ()]
378
380
379
381
# Make the tokenizer and text encoder
380
- tokenizer = MultiTokenizer (tokenizer_names_or_paths = tokenizer_names )
382
+ tokenizer = MultiTokenizer (tokenizer_names_or_paths = tokenizer_names , cache_dir = cache_dir , local_files_only = local_files_only )
381
383
text_encoder = MultiTextEncoder (model_names = text_encoder_names ,
382
384
encode_latents_in_fp16 = encode_latents_in_fp16 ,
383
- pretrained_sdxl = pretrained )
385
+ pretrained_sdxl = pretrained ,
386
+ cache_dir = cache_dir ,
387
+ local_files_only = local_files_only )
384
388
385
389
precision = torch .float16 if encode_latents_in_fp16 else None
386
390
# Make the autoencoder
@@ -408,9 +412,9 @@ def stable_diffusion_xl(
408
412
downsample_factor = 2 ** (len (vae .config ['channel_multipliers' ]) - 1 )
409
413
410
414
# Make the unet
411
- unet_config = PretrainedConfig .get_config_dict (unet_model_name , subfolder = 'unet' )[0 ]
415
+ unet_config = PretrainedConfig .get_config_dict (unet_model_name , subfolder = 'unet' , cache_dir = cache_dir , local_files_only = local_files_only )[0 ]
412
416
if pretrained :
413
- unet = UNet2DConditionModel .from_pretrained (unet_model_name , subfolder = 'unet' )
417
+ unet = UNet2DConditionModel .from_pretrained (unet_model_name , subfolder = 'unet' , cache_dir = cache_dir , local_files_only = local_files_only )
414
418
if isinstance (vae , AutoEncoder ) and vae .config ['latent_channels' ] != 4 :
415
419
raise ValueError (f'Pretrained unet has 4 latent channels but the vae has { vae .latent_channels } .' )
416
420
else :
@@ -612,6 +616,7 @@ def precomputed_text_latent_diffusion(
612
616
use_xformers : bool = True ,
613
617
lora_rank : Optional [int ] = None ,
614
618
lora_alpha : Optional [int ] = None ,
619
+ local_files_only : bool = False ,
615
620
):
616
621
"""Latent diffusion model training using precomputed text latents from T5-XXL and CLIP.
617
622
@@ -695,7 +700,7 @@ def precomputed_text_latent_diffusion(
695
700
downsample_factor = 2 ** (len (vae .config ['channel_multipliers' ]) - 1 )
696
701
697
702
# Make the unet
698
- unet_config = PretrainedConfig .get_config_dict (unet_model_name , subfolder = 'unet' )[0 ]
703
+ unet_config = PretrainedConfig .get_config_dict (unet_model_name , subfolder = 'unet' , cache_dir = cache_dir , local_files_only = local_files_only )[0 ]
699
704
700
705
if isinstance (vae , AutoEncoder ):
701
706
# Adapt the unet config to account for differing number of latent channels if necessary
@@ -792,20 +797,20 @@ def precomputed_text_latent_diffusion(
792
797
if include_text_encoders :
793
798
dtype_map = {'float32' : torch .float32 , 'float16' : torch .float16 , 'bfloat16' : torch .bfloat16 }
794
799
dtype = dtype_map [text_encoder_dtype ]
795
- t5_tokenizer = AutoTokenizer .from_pretrained ('google/t5-v1_1-xxl' , cache_dir = cache_dir , local_files_only = True )
800
+ t5_tokenizer = AutoTokenizer .from_pretrained ('google/t5-v1_1-xxl' , cache_dir = cache_dir , local_files_only = local_files_only )
796
801
clip_tokenizer = AutoTokenizer .from_pretrained ('stabilityai/stable-diffusion-xl-base-1.0' ,
797
802
subfolder = 'tokenizer' ,
798
803
cache_dir = cache_dir ,
799
- local_files_only = False )
804
+ local_files_only = local_files_only )
800
805
t5_encoder = AutoModel .from_pretrained ('google/t5-v1_1-xxl' ,
801
806
torch_dtype = dtype ,
802
807
cache_dir = cache_dir ,
803
- local_files_only = False ).encoder .eval ()
808
+ local_files_only = local_files_only ).encoder .eval ()
804
809
clip_encoder = CLIPTextModel .from_pretrained ('stabilityai/stable-diffusion-xl-base-1.0' ,
805
810
subfolder = 'text_encoder' ,
806
811
torch_dtype = dtype ,
807
812
cache_dir = cache_dir ,
808
- local_files_only = False ).cuda ().eval ()
813
+ local_files_only = local_files_only ).cuda ().eval ()
809
814
# Make the composer model
810
815
model = PrecomputedTextLatentDiffusion (
811
816
unet = unet ,
0 commit comments