Skip to content

Commit cfc5a14

Browse files
authored
Add preload models step (#150)
* Add preload models step to generator_process main * Only run preload_models if models are detected as missing * Show progress in info
1 parent 89db264 commit cfc5a14

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

generator_process.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,52 @@ def view_step(samples, step):
206206
writeUInt(4,step)
207207
stdout.flush()
208208

209+
def preload_models():
210+
from huggingface_hub.utils.tqdm import tqdm
211+
212+
current_model_name = ""
213+
def start_preloading(model_name):
214+
nonlocal current_model_name
215+
current_model_name = model_name
216+
writeInfo(f"Downloading {model_name} (0%)")
217+
218+
def update_decorator(original):
219+
def update(self, n=1):
220+
result = original(self, n)
221+
nonlocal current_model_name
222+
frac = self.n / self.total
223+
percentage = int(frac * 100)
224+
if self.n - self.last_print_n >= self.miniters:
225+
writeInfo(f"Downloading {current_model_name} ({percentage}%)")
226+
return result
227+
return update
228+
old_update = tqdm.update
229+
tqdm.update = update_decorator(tqdm.update)
230+
231+
import warnings
232+
import transformers
233+
transformers.logging.set_verbosity_error()
234+
235+
start_preloading("BERT tokenizer")
236+
transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
237+
238+
writeInfo("Preloading `kornia` requirements")
239+
with warnings.catch_warnings():
240+
warnings.filterwarnings('ignore', category=DeprecationWarning)
241+
import kornia
242+
243+
start_preloading("CLIP")
244+
clip_version = 'openai/clip-vit-large-patch14'
245+
transformers.CLIPTokenizer.from_pretrained(clip_version)
246+
transformers.CLIPTextModel.from_pretrained(clip_version)
247+
248+
tqdm.update = old_update
249+
250+
from transformers.utils.hub import TRANSFORMERS_CACHE
251+
model_paths = {'bert-base-uncased', 'openai--clip-vit-large-patch14'}
252+
if any(not os.path.isdir(os.path.join(TRANSFORMERS_CACHE, f'models--{path}')) for path in model_paths):
253+
preload_models()
254+
209255
generator = None
210256
while True:
211257
json_len = int.from_bytes(stdin.read(8),sys.byteorder,signed=False)

operators/dream_texture.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def image_writer(seed, width, height, pixels, upscaled=False):
118118
area.spaces.active.image = image
119119
scene.dream_textures_progress = 0
120120
scene.dream_textures_prompt.seed = str(seed) # update property in case seed was sourced randomly or from hash
121+
history_entry.seed = str(seed)
122+
history_entry.random_seed = False
121123

122124
def view_step(step, width=None, height=None, pixels=None):
123125
info() # clear variable

0 commit comments

Comments
 (0)