Skip to content

Commit cfecc41

Browse files
Update docstrings
1 parent ad84715 commit cfecc41

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

diffusion/models/models.py

+22-5
Original file line number | Diff line number | Diff line change
@@ -366,6 +366,9 @@ def stable_diffusion_xl(
366366
use_xformers (bool): Whether to use xformers for attention. Defaults to True.
367367
lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
368368
lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
369+
cache_dir (str): Directory to cache local files in. Default: `'/tmp/hf_files'`.
370+
local_files_only (bool): Whether to only use local files. Default: `False`.
371+
369372
"""
370373
latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
371374

@@ -379,7 +382,9 @@ def stable_diffusion_xl(
379382
val_metrics = [MeanSquaredError()]
380383

381384
# Make the tokenizer and text encoder
382-
tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names, cache_dir=cache_dir, local_files_only=local_files_only)
385+
tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names,
386+
cache_dir=cache_dir,
387+
local_files_only=local_files_only)
383388
text_encoder = MultiTextEncoder(model_names=text_encoder_names,
384389
encode_latents_in_fp16=encode_latents_in_fp16,
385390
pretrained_sdxl=pretrained,
@@ -412,9 +417,15 @@ def stable_diffusion_xl(
412417
downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
413418

414419
# Make the unet
415-
unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)[0]
420+
unet_config = PretrainedConfig.get_config_dict(unet_model_name,
421+
subfolder='unet',
422+
cache_dir=cache_dir,
423+
local_files_only=local_files_only)[0]
416424
if pretrained:
417-
unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)
425+
unet = UNet2DConditionModel.from_pretrained(unet_model_name,
426+
subfolder='unet',
427+
cache_dir=cache_dir,
428+
local_files_only=local_files_only)
418429
if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
419430
raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
420431
else:
@@ -667,6 +678,7 @@ def precomputed_text_latent_diffusion(
667678
use_xformers (bool): Whether to use xformers for attention. Defaults to True.
668679
lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
669680
lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
681+
local_files_only (bool): Whether to only use local files. Default: `False`.
670682
"""
671683
latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
672684

@@ -700,7 +712,10 @@ def precomputed_text_latent_diffusion(
700712
downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
701713

702714
# Make the unet
703-
unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)[0]
715+
unet_config = PretrainedConfig.get_config_dict(unet_model_name,
716+
subfolder='unet',
717+
cache_dir=cache_dir,
718+
local_files_only=local_files_only)[0]
704719

705720
if isinstance(vae, AutoEncoder):
706721
# Adapt the unet config to account for differing number of latent channels if necessary
@@ -797,7 +812,9 @@ def precomputed_text_latent_diffusion(
797812
if include_text_encoders:
798813
dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
799814
dtype = dtype_map[text_encoder_dtype]
800-
t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=local_files_only)
815+
t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl',
816+
cache_dir=cache_dir,
817+
local_files_only=local_files_only)
801818
clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
802819
subfolder='tokenizer',
803820
cache_dir=cache_dir,

0 commit comments

Comments (0)