Skip to content

Commit 20e41bb

Browse files
Merge pull request #1 from megagonlabs/feature/fall_back_to_hugging_face_hub
fall back to hugging face hub
2 parents e474449 + a6f2953 commit 20e41bb

3 files changed

Lines changed: 17 additions & 13 deletions

File tree

ginza_transformers/pipeline_component.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,6 @@
2222
2323
[transformer_custom.model]
2424
@architectures = "ginza-transformers.TransformerModel.v1"
25-
name = "electra-base-ud-japanese-discriminator"
26-
tokenizer_config = {"use_fast": false, "tokenizer_class": "sudachitra.tokenization_electra_sudachipy.ElectraSudachipyTokenizer"}
2725
2826
[transformer_custom.model.get_spans]
2927
@span_getters = "spacy-transformers.strided_spans.v1"
@@ -64,7 +62,7 @@ def from_disk(
6462
def load_model(p):
6563
p = Path(p).absolute()
6664
tokenizer, transformer = huggingface_from_pretrained_custom(
67-
p, self.model.attrs["tokenizer_config"]
65+
p, self.model.attrs["tokenizer_config"], self.model.attrs["name"]
6866
)
6967
self.model.attrs["tokenizer"] = tokenizer
7068
self.model.attrs["set_transformer"](self.model, transformer)

ginza_transformers/util.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Union, Optional
22
from pathlib import Path
33
from transformers import AutoModel, AutoTokenizer
44
from thinc.api import get_current_ops, CupyOps
55

66

7-
def huggingface_from_pretrained_custom(source: Union[Path, str], config: Dict):
7+
def huggingface_from_pretrained_custom(source: Union[Path, str], tokenizer_config: Dict, model_name: Optional[str] = None):
88
"""Create a Huggingface transformer model from pretrained weights. Will
99
download the model if it is not already downloaded.
1010
@@ -16,19 +16,25 @@ def huggingface_from_pretrained_custom(source: Union[Path, str], config: Dict):
1616
str_path = str(source.absolute())
1717
else:
1818
str_path = source
19-
19+
2020
try:
21-
tokenizer = AutoTokenizer.from_pretrained(str_path, **config)
21+
tokenizer = AutoTokenizer.from_pretrained(str_path, **tokenizer_config)
2222
except ValueError as e:
23-
if "tokenizer_class" not in config:
23+
if "tokenizer_class" not in tokenizer_config:
2424
raise e
25-
tokenizer_class_name = config["tokenizer_class"].split(".")
25+
tokenizer_class_name = tokenizer_config["tokenizer_class"].split(".")
2626
from importlib import import_module
2727
tokenizer_module = import_module(".".join(tokenizer_class_name[:-1]))
2828
tokenizer_class = getattr(tokenizer_module, tokenizer_class_name[-1])
29-
tokenizer = tokenizer_class(vocab_file=str_path + "/vocab.txt", **config)
29+
tokenizer = tokenizer_class(vocab_file=str_path + "/vocab.txt", **tokenizer_config)
3030

31-
transformer = AutoModel.from_pretrained(str_path)
31+
try:
32+
transformer = AutoModel.from_pretrained(str_path)
33+
except OSError as e:
34+
try:
35+
transformer = AutoModel.from_pretrained(model_name)
36+
except OSError as e2:
37+
raise e
3238
ops = get_current_ops()
3339
if isinstance(ops, CupyOps):
3440
transformer.cuda()

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
],
1515
},
1616
install_requires=[
17-
"spacy-transformers>=1.0.2",
17+
"spacy-transformers>=1.0.4",
1818
],
1919
license="MIT",
2020
name="ginza-transformers",
2121
packages=find_packages(include=["ginza_transformers", "ginza_transformers.layers"]),
2222
url="https://github.com/megagonlabs/ginza-transformers",
23-
version='0.2.0',
23+
version='0.3.0',
2424
)

0 commit comments

Comments
 (0)