
Commit 1c19b9b

committed
fixed loading error for customized pipelines + add conllu conversion for trankit outputs
1 parent cbbbd30 commit 1c19b9b

File tree: 8 files changed (+211, -38 lines)


docs/source/training.md (+28, -1)

@@ -118,7 +118,8 @@ import trankit
 
 trankit.verify_customized_pipeline(
     category='customized-mwt-ner', # pipeline category
-    save_dir='./save_dir' # directory used for saving models in previous steps
+    save_dir='./save_dir', # directory used for saving models in previous steps
+    embedding_name='xlm-roberta-base' # embedding used to train the customized pipeline; `xlm-roberta-base` by default
 )
 ```
 If the verification is successful, this will print out the following:
@@ -130,3 +131,29 @@ from trankit import Pipeline
 p = Pipeline(lang='customized-mwt-ner', cache_dir='./save_dir')
 ```
 From now on, the customized pipeline can be used as a normal pretrained pipeline.
+
+Verification fails if some of the expected model files for the pipeline are missing. This can be fixed with the helper function `download_missing_files`, which borrows model files from the pretrained pipelines provided by Trankit. Suppose the language of your customized pipeline is English; the function can then be used as below:
+```
+import trankit
+
+trankit.download_missing_files(
+    category='customized-ner',
+    save_dir='./save_dir',
+    embedding_name='xlm-roberta-base',
+    language='english'
+)
+```
+where `category` is the category specified for the customized pipeline, `save_dir` is the directory where the customized models were saved, `embedding_name` is the embedding used for the customized pipeline (`xlm-roberta-base` by default if it was not specified during training), and `language` is the language whose pretrained models we want to borrow. For example, if we only trained a NER model for the customized pipeline, the snippet above would borrow the pretrained models for all other pipeline components and print out the following message:
+```
+Missing ./save_dir/xlm-roberta-base/customized-ner/customized-ner.tokenizer.mdl
+Missing ./save_dir/xlm-roberta-base/customized-ner/customized-ner.tagger.mdl
+Missing ./save_dir/xlm-roberta-base/customized-ner/customized-ner.vocabs.json
+Missing ./save_dir/xlm-roberta-base/customized-ner/customized-ner_lemmatizer.pt
+http://nlp.uoregon.edu/download/trankit/v1.0.0/xlm-roberta-base/english.zip
+Downloading: 100%|████████████████| 47.9M/47.9M [00:00<00:00, 114MiB/s]
+Copying ./save_dir/xlm-roberta-base/english/english.tokenizer.mdl to ./save_dir/xlm-roberta-base/customized-ner/customized-ner.tokenizer.mdl
+Copying ./save_dir/xlm-roberta-base/english/english.tagger.mdl to ./save_dir/xlm-roberta-base/customized-ner/customized-ner.tagger.mdl
+Copying ./save_dir/xlm-roberta-base/english/english.vocabs.json to ./save_dir/xlm-roberta-base/customized-ner/customized-ner.vocabs.json
+Copying ./save_dir/xlm-roberta-base/english/english_lemmatizer.pt to ./save_dir/xlm-roberta-base/customized-ner/customized-ner_lemmatizer.pt
+```
+After this, we can go back and run the verification step again.
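
Putting these steps together, here is a minimal end-to-end sketch of the customized-pipeline workflow; it mirrors the new trankit/tests/test_training.py in this commit, and the BIO data paths are placeholders:

```
import trankit

# Train only a NER component for a customized pipeline
# ('./train.bio' and './dev.bio' are placeholder paths).
trainer = trankit.TPipeline(
    training_config={
        'category': 'customized-ner',
        'task': 'ner',
        'save_dir': './save_dir',
        'train_bio_fpath': './train.bio',
        'dev_bio_fpath': './dev.bio',
    }
)
trainer.train()

# Borrow the remaining components from the pretrained English models,
# then verify and load the customized pipeline.
trankit.download_missing_files(
    category='customized-ner',
    save_dir='./save_dir',
    embedding_name='xlm-roberta-base',
    language='english'
)
trankit.verify_customized_pipeline(
    category='customized-ner',
    save_dir='./save_dir',
    embedding_name='xlm-roberta-base'
)
p = trankit.Pipeline(lang='customized-ner', cache_dir='./save_dir')
```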

trankit/__init__.py (+101, -31)

@@ -1,47 +1,119 @@
 from .pipeline import Pipeline
 from .tpipeline import TPipeline
 from .pipeline import supported_langs, langwithner, remove_with_path
+from .utils.base_utils import download, trankit2conllu
+from .utils.tbinfo import supported_embeddings, supported_langs, saved_model_version
+import os
+from shutil import copyfile
 
-__version__ = "1.0.1"
+__version__ = "1.1.0"
 
 
-def verify_customized_pipeline(category, save_dir):
+def download_missing_files(category, save_dir, embedding_name, language):
+    assert language in supported_langs, '{} is not a pretrained language. Current pretrained languages: {}'.format(language, supported_langs)
+    assert embedding_name in supported_embeddings, '{} has not been supported. Current supported embeddings: {}'.format(embedding_name, supported_embeddings)
+
     import os
     assert category in {'customized', 'customized-ner', 'customized-mwt',
-                        'customized-mwt-ner'}, "Pipeline category must be one of the following: 'customized', 'customized-ner', 'customized-mwt', 'customized-mwt-ner'"
+                        'customized-mwt-ner'}, "Pipeline category must be one of the following: 'customized', 'customized-ner', 'customized-mwt', 'customized-mwt-ner'"
+    if category == 'customized':
+        file_list = [
+            ('{}.tokenizer.mdl', os.path.join(save_dir, embedding_name, category, '{}.tokenizer.mdl'.format(category))),
+            ('{}.tagger.mdl', os.path.join(save_dir, embedding_name, category, '{}.tagger.mdl'.format(category))),
+            ('{}.vocabs.json', os.path.join(save_dir, embedding_name, category, '{}.vocabs.json'.format(category))),
+            ('{}_lemmatizer.pt', os.path.join(save_dir, embedding_name, category, '{}_lemmatizer.pt'.format(category)))
+        ]
+    elif category == 'customized-ner':
+        file_list = [
+            ('{}.tokenizer.mdl', os.path.join(save_dir, embedding_name, category, '{}.tokenizer.mdl'.format(category))),
+            ('{}.tagger.mdl', os.path.join(save_dir, embedding_name, category, '{}.tagger.mdl'.format(category))),
+            ('{}.vocabs.json', os.path.join(save_dir, embedding_name, category, '{}.vocabs.json'.format(category))),
+            ('{}_lemmatizer.pt', os.path.join(save_dir, embedding_name, category, '{}_lemmatizer.pt'.format(category))),
+            ('{}.ner.mdl', os.path.join(save_dir, embedding_name, category, '{}.ner.mdl'.format(category))),
+            ('{}.ner-vocab.json', os.path.join(save_dir, embedding_name, category, '{}.ner-vocab.json'.format(category)))
+        ]
+    elif category == 'customized-mwt':
+        file_list = [
+            ('{}.tokenizer.mdl', os.path.join(save_dir, embedding_name, category, '{}.tokenizer.mdl'.format(category))),
+            ('{}_mwt_expander.pt', os.path.join(save_dir, embedding_name, category, '{}_mwt_expander.pt'.format(category))),
+            ('{}.tagger.mdl', os.path.join(save_dir, embedding_name, category, '{}.tagger.mdl'.format(category))),
+            ('{}.vocabs.json', os.path.join(save_dir, embedding_name, category, '{}.vocabs.json'.format(category))),
+            ('{}_lemmatizer.pt', os.path.join(save_dir, embedding_name, category, '{}_lemmatizer.pt'.format(category)))
+        ]
+    elif category == 'customized-mwt-ner':
+        file_list = [
+            ('{}.tokenizer.mdl', os.path.join(save_dir, embedding_name, category, '{}.tokenizer.mdl'.format(category))),
+            ('{}_mwt_expander.pt', os.path.join(save_dir, embedding_name, category, '{}_mwt_expander.pt'.format(category))),
+            ('{}.tagger.mdl', os.path.join(save_dir, embedding_name, category, '{}.tagger.mdl'.format(category))),
+            ('{}.vocabs.json', os.path.join(save_dir, embedding_name, category, '{}.vocabs.json'.format(category))),
+            ('{}_lemmatizer.pt', os.path.join(save_dir, embedding_name, category, '{}_lemmatizer.pt'.format(category))),
+            ('{}.ner.mdl', os.path.join(save_dir, embedding_name, category, '{}.ner.mdl'.format(category))),
+            ('{}.ner-vocab.json', os.path.join(save_dir, embedding_name, category, '{}.ner-vocab.json'.format(category)))
+        ]
+    else:
+        assert 'Unknown customized lang!'
+    missing_filenamess = []
+    for filename, filepath in file_list:
+        if not os.path.exists(filepath):
+            print('Missing {}'.format(filepath))
+            missing_filenamess.append(filename)
+
+    download(
+        cache_dir=save_dir,
+        language=language,
+        saved_model_version=saved_model_version,  # manually set this to avoid duplicated storage
+        embedding_name=embedding_name
+    )
+    # borrow pretrained files
+    src_dir = os.path.join(save_dir, embedding_name, language)
+    tgt_dir = os.path.join(save_dir, embedding_name, category)
+    for fname in missing_filenamess:
+        copyfile(os.path.join(src_dir, fname.format(language)), os.path.join(tgt_dir, fname.format(category)))
+        print('Copying {} to {}'.format(
+            os.path.join(src_dir, fname.format(language)),
+            os.path.join(tgt_dir, fname.format(category))
+        ))
+    remove_with_path(src_dir)
+
+
+def verify_customized_pipeline(category, save_dir, embedding_name):
+    assert embedding_name in supported_embeddings, '{} has not been supported. Current supported embeddings: {}'.format(
+        embedding_name, supported_embeddings)
+    assert category in {'customized', 'customized-ner', 'customized-mwt',
+                        'customized-mwt-ner'}, "Pipeline category must be one of the following: 'customized', 'customized-ner', 'customized-mwt', 'customized-mwt-ner'"
     if category == 'customized':
         file_list = [
-            os.path.join(save_dir, category, '{}.tokenizer.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}.tagger.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}.vocabs.json'.format(category)),
-            os.path.join(save_dir, category, '{}_lemmatizer.pt'.format(category))
+            os.path.join(save_dir, embedding_name, category, '{}.tokenizer.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.tagger.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.vocabs.json'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}_lemmatizer.pt'.format(category))
         ]
     elif category == 'customized-ner':
         file_list = [
-            os.path.join(save_dir, category, '{}.tokenizer.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}.tagger.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}.vocabs.json'.format(category)),
-            os.path.join(save_dir, category, '{}_lemmatizer.pt'.format(category)),
-            os.path.join(save_dir, category, '{}.ner.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}.ner-vocab.json'.format(category))
+            os.path.join(save_dir, embedding_name, category, '{}.tokenizer.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.tagger.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.vocabs.json'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}_lemmatizer.pt'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.ner.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.ner-vocab.json'.format(category))
         ]
     elif category == 'customized-mwt':
         file_list = [
-            os.path.join(save_dir, category, '{}.tokenizer.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}_mwt_expander.pt'.format(category)),
-            os.path.join(save_dir, category, '{}.tagger.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}.vocabs.json'.format(category)),
-            os.path.join(save_dir, category, '{}_lemmatizer.pt'.format(category))
+            os.path.join(save_dir, embedding_name, category, '{}.tokenizer.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}_mwt_expander.pt'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.tagger.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.vocabs.json'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}_lemmatizer.pt'.format(category))
        ]
     elif category == 'customized-mwt-ner':
         file_list = [
-            os.path.join(save_dir, category, '{}.tokenizer.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}_mwt_expander.pt'.format(category)),
-            os.path.join(save_dir, category, '{}.tagger.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}.vocabs.json'.format(category)),
-            os.path.join(save_dir, category, '{}_lemmatizer.pt'.format(category)),
-            os.path.join(save_dir, category, '{}.ner.mdl'.format(category)),
-            os.path.join(save_dir, category, '{}.ner-vocab.json'.format(category))
+            os.path.join(save_dir, embedding_name, category, '{}.tokenizer.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}_mwt_expander.pt'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.tagger.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.vocabs.json'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}_lemmatizer.pt'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.ner.mdl'.format(category)),
+            os.path.join(save_dir, embedding_name, category, '{}.ner-vocab.json'.format(category))
         ]
     else:
         assert 'Unknown customized lang!'
@@ -52,13 +124,11 @@ def verify_customized_pipeline(category, save_dir):
             verified = False
            print('Missing {}'.format(filepath))
     if verified:
-        with open(os.path.join(save_dir, category, '{}.downloaded'.format(category)), 'w') as f:
+        with open(os.path.join(save_dir, embedding_name, category, '{}.downloaded'.format(category)), 'w') as f:
             f.write('')
-        remove_with_path(os.path.join(save_dir, category, 'train.txt.character'))
-        remove_with_path(os.path.join(save_dir, category, 'logs'))
-        remove_with_path(os.path.join(save_dir, category, 'preds'))
-        remove_with_path(os.path.join(save_dir, category, 'xlm-roberta-large'))
-        remove_with_path(os.path.join(save_dir, category, 'xlm-roberta-base'))
+        remove_with_path(os.path.join(save_dir, embedding_name, category, 'train.txt.character'))
+        remove_with_path(os.path.join(save_dir, embedding_name, category, 'logs'))
+        remove_with_path(os.path.join(save_dir, embedding_name, category, 'preds'))
         print(
             "Customized pipeline is ready to use!\nIt can be initialized as follows:\n-----------------------------------\nfrom trankit import Pipeline\np = Pipeline(lang='{}', cache_dir='{}')".format(
                 category, save_dir))
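
The recurring change in this file is that every expected model path now includes the embedding name, so customized models live under save_dir/{embedding_name}/{category}/ rather than save_dir/{category}/. A small sketch of the new layout, using example values only:

```
import os

save_dir = './save_dir'               # example value
embedding_name = 'xlm-roberta-base'   # default embedding
category = 'customized-ner'           # example pipeline category

model_dir = os.path.join(save_dir, embedding_name, category)
print(model_dir)
# ./save_dir/xlm-roberta-base/customized-ner
print(os.path.join(model_dir, '{}.ner.mdl'.format(category)))
# ./save_dir/xlm-roberta-base/customized-ner/customized-ner.ner.mdl
```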

trankit/pipeline.py (+1, -1)

@@ -70,7 +70,7 @@ def __init__(self, lang, cache_dir=None, gpu=True, embedding='xlm-roberta-base')
             download(
                 cache_dir=self._config._cache_dir,
                 language=lang,
-                saved_model_version='v1.0.0',  # manually set this to avoid duplicated storage
+                saved_model_version=saved_model_version,  # manually set this to avoid duplicated storage
                 embedding_name=master_config.embedding_name
             )

trankit/tests/test_training.py (+33, new file)

@@ -0,0 +1,33 @@
+import trankit
+
+# initialize a trainer for the task
+trainer = trankit.TPipeline(
+    training_config={
+        'category': 'customized-ner',  # pipeline category
+        'task': 'ner',  # task name
+        'save_dir': './save_dir',  # directory to save the trained model
+        'train_bio_fpath': './train.bio',  # training data in BIO format
+        'dev_bio_fpath': './dev.bio',  # development data in BIO format
+        'max_epoch': 1
+    }
+)
+
+# start training
+trainer.train()
+
+trankit.download_missing_files(
+    category='customized-ner',
+    save_dir='./save_dir',
+    embedding_name='xlm-roberta-base',
+    language='english'
+)
+
+trankit.verify_customized_pipeline(
+    category='customized-ner',  # pipeline category
+    save_dir='./save_dir',  # directory used for saving models in previous steps
+    embedding_name='xlm-roberta-base'
+)
+
+p = trankit.Pipeline(lang='customized-ner', cache_dir='./save_dir')
+
+print(trankit.trankit2conllu(p('I love you more than I can say. Do you know that?')))

trankit/tpipeline.py (+2, -4)

@@ -160,7 +160,7 @@ def _set_up_config(self, training_config):
 
         # device and save dir
         self._save_dir = training_config['save_dir'] if 'save_dir' in training_config else './cache/'
-        self._save_dir = os.path.join(self._save_dir, self._lang)
+        self._save_dir = os.path.join(self._save_dir, master_config.embedding_name, self._lang)
         self._cache_dir = self._save_dir
         self._gpu = training_config['gpu'] if 'gpu' in training_config else True
         self._use_gpu = training_config['gpu'] if 'gpu' in training_config else True
@@ -211,9 +211,7 @@ def _set_up_config(self, training_config):
         # wordpiece splitter
         if self._task not in ['mwt', 'lemmatize']:
             master_config.wordpiece_splitter = XLMRobertaTokenizer.from_pretrained(master_config.embedding_name,
-                                                                                   cache_dir=os.path.join(
-                                                                                       master_config._save_dir,
-                                                                                       master_config.embedding_name))
+                                                                                   cache_dir=master_config._save_dir)
 
     def _prepare_tokenize(self):
         self.train_set = TokenizeDataset(

trankit/utils/base_utils.py (+36, -1)

@@ -18,6 +18,40 @@
 SPACE_RE = re.compile(r'\s')
 
 
+def trankit2conllu(trankit_output):
+    assert type(trankit_output) == dict, "`trankit_output` must be a Python dictionary!"
+    if SENTENCES in trankit_output and len(trankit_output[SENTENCES]) > 0 and TOKENS in trankit_output[SENTENCES][0]:
+        output_type = 'document'
+    elif TOKENS in trankit_output:
+        output_type = 'sentence'
+    else:
+        print("Unknown format of `trankit_output`!")
+        return None
+    try:
+        if output_type == 'document':
+            json_doc = trankit_output[SENTENCES]
+        else:
+            assert output_type == 'sentence'
+            json_doc = [trankit_output]
+
+        conllu_doc = []
+        for sentence in json_doc:
+            conllu_sentence = []
+            for token in sentence[TOKENS]:
+                if type(token[ID]) == int or len(token[ID]) == 1:
+                    conllu_sentence.append(token)
+                else:
+                    conllu_sentence.append(token)
+                    for word in token[EXPANDED]:
+                        conllu_sentence.append(word)
+            conllu_doc.append(conllu_sentence)
+
+        return CoNLL.dict2conllstring(conllu_doc)
+    except:
+        print('Unsuccessful conversion! Please check the format of `trankit_output`')
+        return None
+
+
 def remove_with_path(path):
     if os.path.exists(path):
         if os.path.isdir(path):
@@ -62,7 +96,8 @@ def download(cache_dir, language, saved_model_version, embedding_name): # put a
     save_fpath = os.path.join(lang_dir, '{}.zip'.format(language))
 
     if not os.path.exists(os.path.join(lang_dir, '{}.downloaded'.format(language))):
-        url = "http://nlp.uoregon.edu/download/trankit/{}/{}/{}.zip".format(saved_model_version, embedding_name, language)
+        url = "http://nlp.uoregon.edu/download/trankit/{}/{}/{}.zip".format(saved_model_version, embedding_name,
+                                                                            language)
         print(url)
 
         response = requests.get(url, stream=True)
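
A minimal usage sketch for the new `trankit2conllu` helper, assuming a pretrained English pipeline; the cache directory and output filename are placeholders:

```
import trankit

p = trankit.Pipeline(lang='english', cache_dir='./cache')

doc = p('Hello world. This is Trankit.')     # document-level output (a dict)
conllu_string = trankit.trankit2conllu(doc)  # convert to a CoNLL-U formatted string

# write the CoNLL-U string to a file ('output.conllu' is a placeholder name)
with open('output.conllu', 'w') as f:
    f.write(conllu_string)
```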

trankit/utils/conll.py (+8)

@@ -166,3 +166,11 @@ def dict2conll(doc_dict, filename):
         conll_string = CoNLL.conll_as_string(doc_conll)
         with open(filename, 'w') as outfile:
             outfile.write(conll_string)
+
+    @staticmethod
+    def dict2conllstring(doc_dict):
+        """ Convert the dictionary format input data to CoNLL-U format and return it as a string.
+        """
+        doc_conll = CoNLL.convert_dict(doc_dict)
+        conll_string = CoNLL.conll_as_string(doc_conll)
+        return conll_string
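
For comparison, a short sketch of how the new static method relates to the existing `dict2conll` helper; here `doc_dict` is a placeholder for the list-of-sentences structure that `trankit2conllu` assembles above:

```
from trankit.utils.conll import CoNLL

doc_dict = [...]  # placeholder: list of sentences, each a list of token dicts

conllu_str = CoNLL.dict2conllstring(doc_dict)  # new in this commit: returns the CoNLL-U text
CoNLL.dict2conll(doc_dict, 'output.conllu')    # existing helper: writes it to a file
```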

0 commit comments
