1- from typing import Dict , Union
1+ from typing import Dict , Union , Optional
22from pathlib import Path
33from transformers import AutoModel , AutoTokenizer
44from thinc .api import get_current_ops , CupyOps
55
66
7- def huggingface_from_pretrained_custom (source : Union [Path , str ], config : Dict ):
7+ def huggingface_from_pretrained_custom (source : Union [Path , str ], tokenizer_config : Dict , model_name : Optional [ str ] = None ):
88 """Create a Huggingface transformer model from pretrained weights. Will
99 download the model if it is not already downloaded.
1010
@@ -16,19 +16,25 @@ def huggingface_from_pretrained_custom(source: Union[Path, str], config: Dict):
1616 str_path = str (source .absolute ())
1717 else :
1818 str_path = source
19-
19+
2020 try :
21- tokenizer = AutoTokenizer .from_pretrained (str_path , ** config )
21+ tokenizer = AutoTokenizer .from_pretrained (str_path , ** tokenizer_config )
2222 except ValueError as e :
23- if "tokenizer_class" not in config :
23+ if "tokenizer_class" not in tokenizer_config :
2424 raise e
25- tokenizer_class_name = config ["tokenizer_class" ].split ("." )
25+ tokenizer_class_name = tokenizer_config ["tokenizer_class" ].split ("." )
2626 from importlib import import_module
2727 tokenizer_module = import_module ("." .join (tokenizer_class_name [:- 1 ]))
2828 tokenizer_class = getattr (tokenizer_module , tokenizer_class_name [- 1 ])
29- tokenizer = tokenizer_class (vocab_file = str_path + "/vocab.txt" , ** config )
29+ tokenizer = tokenizer_class (vocab_file = str_path + "/vocab.txt" , ** tokenizer_config )
3030
31- transformer = AutoModel .from_pretrained (str_path )
31+ try :
32+ transformer = AutoModel .from_pretrained (str_path )
33+ except OSError as e :
34+ try :
35+ transformer = AutoModel .from_pretrained (model_name )
36+ except OSError as e2 :
37+ raise e
3238 ops = get_current_ops ()
3339 if isinstance (ops , CupyOps ):
3440 transformer .cuda ()
0 commit comments