Commit ad84715

done
1 parent ba8ca02 commit ad84715

2 files changed: +22 -15 lines changed


diffusion/models/models.py (+14 -9)

@@ -307,6 +307,8 @@ def stable_diffusion_xl(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    cache_dir: str = '/tmp/hf_files',
+    local_files_only: bool = False,
 ):
     """Stable diffusion 2 training setup + SDXL UNet and VAE.
 
@@ -377,10 +379,12 @@ def stable_diffusion_xl(
     val_metrics = [MeanSquaredError()]
 
     # Make the tokenizer and text encoder
-    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)
+    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names, cache_dir=cache_dir, local_files_only=local_files_only)
     text_encoder = MultiTextEncoder(model_names=text_encoder_names,
                                     encode_latents_in_fp16=encode_latents_in_fp16,
-                                    pretrained_sdxl=pretrained)
+                                    pretrained_sdxl=pretrained,
+                                    cache_dir=cache_dir,
+                                    local_files_only=local_files_only)
 
     precision = torch.float16 if encode_latents_in_fp16 else None
     # Make the autoencoder
@@ -408,9 +412,9 @@ def stable_diffusion_xl(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
 
     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)[0]
     if pretrained:
-        unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
+        unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)
         if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
             raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
     else:
@@ -612,6 +616,7 @@ def precomputed_text_latent_diffusion(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    local_files_only: bool = False,
 ):
     """Latent diffusion model training using precomputed text latents from T5-XXL and CLIP.
 
@@ -695,7 +700,7 @@ def precomputed_text_latent_diffusion(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
 
     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet', cache_dir=cache_dir, local_files_only=local_files_only)[0]
 
     if isinstance(vae, AutoEncoder):
         # Adapt the unet config to account for differing number of latent channels if necessary
@@ -792,20 +797,20 @@ def precomputed_text_latent_diffusion(
     if include_text_encoders:
         dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
         dtype = dtype_map[text_encoder_dtype]
-        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
+        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=local_files_only)
         clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                        subfolder='tokenizer',
                                                        cache_dir=cache_dir,
-                                                       local_files_only=False)
+                                                       local_files_only=local_files_only)
         t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
                                                torch_dtype=dtype,
                                                cache_dir=cache_dir,
-                                               local_files_only=False).encoder.eval()
+                                               local_files_only=local_files_only).encoder.eval()
         clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                      subfolder='text_encoder',
                                                      torch_dtype=dtype,
                                                      cache_dir=cache_dir,
-                                                     local_files_only=False).cuda().eval()
+                                                     local_files_only=local_files_only).cuda().eval()
     # Make the composer model
     model = PrecomputedTextLatentDiffusion(
         unet=unet,
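
This change threads cache_dir and local_files_only through every Hugging Face get_config_dict / from_pretrained call, so the model can be built from a pre-populated cache without network access. A minimal caller-side sketch of how that might be used, assuming the function is importable from diffusion.models.models and that its remaining arguments keep their defaults (both assumptions, not part of this commit):

# Hypothetical usage sketch: warm a shared Hugging Face cache once, then build
# the model offline from that cache on later runs.
from diffusion.models.models import stable_diffusion_xl

CACHE_DIR = '/mnt/shared/hf_files'  # assumed shared path; the new default is '/tmp/hf_files'

# First run (network available): downloads tokenizers, text encoders, and UNet
# weights into CACHE_DIR.
model = stable_diffusion_xl(cache_dir=CACHE_DIR, local_files_only=False)

# Later runs (e.g. air-gapped trainers): resolve everything from CACHE_DIR only;
# Hugging Face raises an error if a required file is missing from the cache.
model = stable_diffusion_xl(cache_dir=CACHE_DIR, local_files_only=True)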

diffusion/models/text_encoder.py (+8 -6)

@@ -31,6 +31,8 @@ def __init__(
         model_dim_keys: Optional[Union[str, List[str]]] = None,
         encode_latents_in_fp16: bool = True,
         pretrained_sdxl: bool = False,
+        cache_dir: str = '/tmp/hf_files',
+        local_files_only: bool = False
     ):
         super().__init__()
         self.pretrained_sdxl = pretrained_sdxl
@@ -50,7 +52,7 @@ def __init__(
             name_split = model_name.split('/')
             base_name = '/'.join(name_split[:2])
             subfolder = '/'.join(name_split[2:])
-            text_encoder_config = PretrainedConfig.get_config_dict(base_name, subfolder=subfolder)[0]
+            text_encoder_config = PretrainedConfig.get_config_dict(base_name, subfolder=subfolder, cache_dir=cache_dir, local_files_only=local_files_only)[0]
 
             # Add text_encoder output dim to total dim
             dim_found = False
@@ -70,14 +72,14 @@ def __init__(
             architectures = text_encoder_config['architectures']
             if architectures == ['CLIPTextModel']:
                 self.text_encoders.append(
-                    CLIPTextModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
+                    CLIPTextModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype, cache_dir=cache_dir, local_files_only=local_files_only))
             elif architectures == ['CLIPTextModelWithProjection']:
                 self.text_encoders.append(
                     CLIPTextModelWithProjection.from_pretrained(base_name, subfolder=subfolder,
-                                                                torch_dtype=torch_dtype))
+                                                                torch_dtype=torch_dtype, cache_dir=cache_dir, local_files_only=local_files_only))
             else:
                 self.text_encoders.append(
-                    AutoModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
+                    AutoModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype, cache_dir=cache_dir, local_files_only=local_files_only))
             self.architectures += architectures
 
     @property
@@ -125,7 +127,7 @@ class MultiTokenizer:
             "org_name/repo_name/subfolder" where the subfolder is excluded if it is not used in the repo.
     """
 
-    def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
+    def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]], cache_dir: str = '/tmp/hf_files', local_files_only: bool = False):
         if isinstance(tokenizer_names_or_paths, str):
             tokenizer_names_or_paths = (tokenizer_names_or_paths,)
 
@@ -134,7 +136,7 @@ def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
            path_split = tokenizer_name_or_path.split('/')
            base_name = '/'.join(path_split[:2])
            subfolder = '/'.join(path_split[2:])
-           self.tokenizers.append(AutoTokenizer.from_pretrained(base_name, subfolder=subfolder))
+           self.tokenizers.append(AutoTokenizer.from_pretrained(base_name, subfolder=subfolder, cache_dir=cache_dir, local_files_only=local_files_only))
 
        self.model_max_length = min([t.model_max_length for t in self.tokenizers])
 
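
The same two arguments are now exposed directly on MultiTokenizer and MultiTextEncoder. A sketch of constructing them standalone against a local cache, assuming the classes are importable from diffusion.models.text_encoder and that the SDXL weights are already cached (both assumptions for illustration); the paths follow the "org_name/repo_name/subfolder" convention from the class docstring:

# Hypothetical sketch: build the multi-tokenizer / multi-encoder pair offline.
from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer

tokenizer = MultiTokenizer(
    tokenizer_names_or_paths=('stabilityai/stable-diffusion-xl-base-1.0/tokenizer',
                              'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'),
    cache_dir='/tmp/hf_files',
    local_files_only=True,  # only works once the tokenizer files are already cached
)

text_encoder = MultiTextEncoder(
    model_names=('stabilityai/stable-diffusion-xl-base-1.0/text_encoder',
                 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'),
    encode_latents_in_fp16=True,
    pretrained_sdxl=True,
    cache_dir='/tmp/hf_files',
    local_files_only=True,
)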
