@@ -206,6 +206,52 @@ def view_step(samples, step):
206206 writeUInt (4 ,step )
207207 stdout .flush ()
208208
def preload_models():
    """Download and cache every model this process needs (BERT tokenizer,
    kornia's imports, CLIP tokenizer + text model), reporting progress
    through writeInfo().

    Progress percentages are obtained by monkey-patching the ``update``
    method of huggingface_hub's tqdm subclass; the patch is restored in a
    ``finally`` block so a failed download cannot leave tqdm broken for
    the rest of the process.
    """
    from huggingface_hub.utils.tqdm import tqdm

    current_model_name = ""

    def start_preloading(model_name):
        # Record which model the patched tqdm.update should report on.
        nonlocal current_model_name
        current_model_name = model_name
        writeInfo(f"Downloading {model_name} (0%)")

    def update_decorator(original):
        def update(self, n=1):
            result = original(self, n)
            # tqdm's `total` may be None or 0 when the server does not
            # report a content length — skip percentage reporting then
            # instead of raising TypeError/ZeroDivisionError.
            if self.total:
                percentage = int(self.n / self.total * 100)
                # Throttle output the same way tqdm throttles display.
                if self.n - self.last_print_n >= self.miniters:
                    writeInfo(f"Downloading {current_model_name} ({percentage}%)")
            return result
        return update

    old_update = tqdm.update
    tqdm.update = update_decorator(tqdm.update)

    import warnings
    import transformers
    transformers.logging.set_verbosity_error()

    try:
        start_preloading("BERT tokenizer")
        transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

        writeInfo("Preloading `kornia` requirements")
        with warnings.catch_warnings():
            # kornia emits DeprecationWarnings on import; keep stderr clean.
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            import kornia  # imported purely to warm its import-time downloads

        start_preloading("CLIP")
        clip_version = 'openai/clip-vit-large-patch14'
        transformers.CLIPTokenizer.from_pretrained(clip_version)
        transformers.CLIPTextModel.from_pretrained(clip_version)
    finally:
        # Always undo the monkey-patch, even if a download failed.
        tqdm.update = old_update
249+
# First-run preload: if any required model is missing from the local
# HuggingFace cache, download everything up front so later requests
# never stall on a cold cache.
from transformers.utils.hub import TRANSFORMERS_CACHE

required_models = ('bert-base-uncased', 'openai--clip-vit-large-patch14')
all_cached = all(
    os.path.isdir(os.path.join(TRANSFORMERS_CACHE, f'models--{name}'))
    for name in required_models
)
if not all_cached:
    preload_models()
254+
209255 generator = None
210256 while True :
211257 json_len = int .from_bytes (stdin .read (8 ),sys .byteorder ,signed = False )
0 commit comments