Commit 92a3ec0

Author: m.habedank
Commit message: removed torchtext
1 parent 8e90e70 commit 92a3ec0

File tree: 1 file changed

ludwig/utils/tokenizers.py (+25 −317 lines changed)
@@ -15,24 +15,16 @@
 
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, List, Union
 
 import torch
-import torchtext
 
-from ludwig.constants import PADDING_SYMBOL, UNKNOWN_SYMBOL
-from ludwig.utils.data_utils import load_json
 from ludwig.utils.hf_utils import load_pretrained_hf_tokenizer
 from ludwig.utils.nlp_utils import load_nlp_pipeline, process_text
 
 logger = logging.getLogger(__name__)
-torchtext_version = torch.torch_version.TorchVersion(torchtext.__version__)
 
 TORCHSCRIPT_COMPATIBLE_TOKENIZERS = {"space", "space_punct", "comma", "underscore", "characters"}
-TORCHTEXT_0_12_0_TOKENIZERS = {"sentencepiece", "clip", "gpt2bpe"}
-TORCHTEXT_0_13_0_TOKENIZERS = {"bert"}
-
-HF_TOKENIZER_SAMPLE_INPUTS = ["UNwant\u00E9d,running", "ah\u535A\u63A8zz", " \tHeLLo!how \n Are yoU? [UNK]"]
 
 
 class BaseTokenizer:
@@ -913,7 +905,7 @@ def convert_token_to_id(self, token: str) -> int:
 
 
 tokenizer_registry = {
-    # Torchscript-compatible tokenizers. Torchtext tokenizers are also available below (requires torchtext>=0.12.0).
+    # Torchscript-compatible tokenizers.
     "space": SpaceStringToListTokenizer,
     "space_punct": SpacePunctuationStringToListTokenizer,
     "ngram": NgramTokenizer,
@@ -1021,231 +1013,40 @@ def convert_token_to_id(self, token: str) -> int:
     "multi_lemmatize_remove_stopwords": MultiLemmatizeRemoveStopwordsTokenizer,
 }
 
-"""torchtext 0.12.0 tokenizers.
-
-Only available with torchtext>=0.12.0.
-"""
-
-
-class SentencePieceTokenizer(torch.nn.Module):
-    def __init__(self, **kwargs):
-        super().__init__()
-        self.tokenizer = load_pretrained_hf_tokenizer("FacebookAI/xlm-roberta-base")
-
-    def forward(self, v: Union[str, List[str], torch.Tensor]):
-        if isinstance(v, torch.Tensor):
-            raise ValueError(f"Unsupported input: {v}")
-        return self.tokenizer.tokenize(v)
-
-
-class _BPETokenizer(torch.nn.Module):
-    """Superclass for tokenizers that use BPE, such as CLIPTokenizer and GPT2BPETokenizer."""
-
-    def __init__(self, pretrained_model_name_or_path: str, vocab_file: str):
-        super().__init__()
-        self.str2idx, self.idx2str = self._init_vocab(vocab_file)
-        self.tokenizer = self._init_tokenizer(pretrained_model_name_or_path, vocab_file)
-
-    def _init_vocab(self, vocab_file: str) -> Dict[str, str]:
-        """Loads the vocab from the vocab file."""
-        str2idx = load_json(torchtext.utils.get_asset_local_path(vocab_file))
-        _, idx2str = zip(*sorted((v, k) for k, v in str2idx.items()))
-        return str2idx, idx2str
-
-    def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str) -> Any:
-        """Initializes and returns the tokenizer."""
-        raise NotImplementedError
-
-    def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
-        """Implements forward pass for tokenizer.
-
-        BPE tokenizers from torchtext return ids directly, which is inconsistent with the Ludwig tokenizer API. The
-        below implementation works around this by converting the ids back to their original string tokens.
-        """
-        if isinstance(v, torch.Tensor):
-            raise ValueError(f"Unsupported input: {v}")
-
-        inputs: List[str] = []
-        # Ludwig calls map on List[str] objects, so we need to handle individual strings as well.
-        if isinstance(v, str):
-            inputs.append(v)
-        else:
-            inputs.extend(v)
-
-        token_ids = self.tokenizer(inputs)
-        assert torch.jit.isinstance(token_ids, List[List[str]])
-
-        tokens = [[self.idx2str[int(unit_idx)] for unit_idx in sequence] for sequence in token_ids]
-        return tokens[0] if isinstance(v, str) else tokens
-
-    def get_vocab(self) -> Dict[str, str]:
-        return self.str2idx
-
-
-class CLIPTokenizer(_BPETokenizer):
-    def __init__(self, pretrained_model_name_or_path: Optional[str] = None, vocab_file: Optional[str] = None, **kwargs):
-        if pretrained_model_name_or_path is None:
-            pretrained_model_name_or_path = "http://download.pytorch.org/models/text/clip_merges.bpe"
-        if vocab_file is None:
-            vocab_file = "http://download.pytorch.org/models/text/clip_encoder.json"
-        super().__init__(pretrained_model_name_or_path, vocab_file)
-
-    def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str):
-        return torchtext.transforms.CLIPTokenizer(
-            encoder_json_path=vocab_file, merges_path=pretrained_model_name_or_path
-        )
-
-
-class GPT2BPETokenizer(_BPETokenizer):
-    def __init__(self, pretrained_model_name_or_path: Optional[str] = None, vocab_file: Optional[str] = None, **kwargs):
-        if pretrained_model_name_or_path is None:
-            pretrained_model_name_or_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe"
-        if vocab_file is None:
-            vocab_file = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json"
-        super().__init__(pretrained_model_name_or_path, vocab_file)
-
-    def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str):
-        return torchtext.transforms.GPT2BPETokenizer(
-            encoder_json_path=vocab_file, vocab_bpe_path=pretrained_model_name_or_path
-        )
-
-
-tokenizer_registry.update(
-    {
-        "sentencepiece": SentencePieceTokenizer,
-        "clip": CLIPTokenizer,
-        "gpt2bpe": GPT2BPETokenizer,
-    }
-)
-TORCHSCRIPT_COMPATIBLE_TOKENIZERS.update(TORCHTEXT_0_12_0_TOKENIZERS)
-
-
-class BERTTokenizer(torch.nn.Module):
-    def __init__(
-        self,
-        vocab_file: Optional[str] = None,
-        is_hf_tokenizer: Optional[bool] = False,
-        hf_tokenizer_attrs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ):
-        super().__init__()
-
-        if vocab_file is None:
-            # If vocab_file not passed in, use default "bert-base-uncased" vocab and kwargs.
-            kwargs = _get_bert_config("bert-base-uncased")
-            vocab_file = kwargs["vocab_file"]
-            vocab = self._init_vocab(vocab_file)
-            hf_tokenizer_attrs = {
-                "pad_token": "[PAD]",
-                "unk_token": "[UNK]",
-                "sep_token_id": vocab["[SEP]"],
-                "cls_token_id": vocab["[CLS]"],
-            }
-        else:
-            vocab = self._init_vocab(vocab_file)
-
-        self.vocab = vocab
-
-        self.is_hf_tokenizer = is_hf_tokenizer
-        if self.is_hf_tokenizer:
-            # Values used by Ludwig extracted from the corresponding HF model.
-            self.pad_token = hf_tokenizer_attrs["pad_token"]  # Used as padding symbol
-            self.unk_token = hf_tokenizer_attrs["unk_token"]  # Used as unknown symbol
-            self.cls_token_id = hf_tokenizer_attrs["cls_token_id"]  # Used as start symbol. Only used if HF.
-            self.sep_token_id = hf_tokenizer_attrs["sep_token_id"]  # Used as stop symbol. Only used if HF.
-            self.never_split = hf_tokenizer_attrs["all_special_tokens"]
-        else:
-            self.pad_token = PADDING_SYMBOL
-            self.unk_token = UNKNOWN_SYMBOL
-            self.cls_token_id = None
-            self.sep_token_id = None
-            self.never_split = [UNKNOWN_SYMBOL]
-
-        tokenizer_kwargs = {}
-        if "do_lower_case" in kwargs:
-            tokenizer_kwargs["do_lower_case"] = kwargs["do_lower_case"]
-        if "strip_accents" in kwargs:
-            tokenizer_kwargs["strip_accents"] = kwargs["strip_accents"]
-
-        # Return tokens as raw tokens only if not being used as a HF tokenizer.
-        self.return_tokens = not self.is_hf_tokenizer
-
-        tokenizer_init_kwargs = {
-            **tokenizer_kwargs,
-            "vocab_path": vocab_file,
-            "return_tokens": self.return_tokens,
-        }
-        if torchtext_version >= (0, 14, 0):
-            # never_split kwarg added in torchtext 0.14.0
-            tokenizer_init_kwargs["never_split"] = self.never_split
-
-        self.tokenizer = torchtext.transforms.BERTTokenizer(**tokenizer_init_kwargs)
-
-    def _init_vocab(self, vocab_file: str) -> Dict[str, int]:
-        from transformers.models.bert.tokenization_bert import load_vocab
-
-        return load_vocab(vocab_file)
-
-    def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
-        """Implements forward pass for tokenizer.
-
-        If the is_hf_tokenizer flag is set to True, then the output follows the HF convention, i.e. the output is a
-        List[List[int]] of tokens and the cls and sep tokens are automatically added as the start and stop symbols.
-
-        If the is_hf_tokenizer flag is set to False, then the output follows the Ludwig convention, i.e. the output
-        is a List[List[str]] of tokens.
-        """
-        if isinstance(v, torch.Tensor):
-            raise ValueError(f"Unsupported input: {v}")
-
-        inputs: List[str] = []
-        # Ludwig calls map on List[str] objects, so we need to handle individual strings as well.
-        if isinstance(v, str):
-            inputs.append(v)
-        else:
-            inputs.extend(v)
-
-        if self.is_hf_tokenizer:
-            token_ids_str = self.tokenizer(inputs)
-            assert torch.jit.isinstance(token_ids_str, List[List[str]])
-            # Must cast token_ids to ints because they are used directly as indices.
-            token_ids: List[List[int]] = []
-            for token_ids_str_i in token_ids_str:
-                token_ids_i = [int(token_id_str) for token_id_str in token_ids_str_i]
-                token_ids_i = self._add_special_token_ids(token_ids_i)
-                token_ids.append(token_ids_i)
-            return token_ids[0] if isinstance(v, str) else token_ids
-
-        tokens = self.tokenizer(inputs)
-        assert torch.jit.isinstance(tokens, List[List[str]])
-        return tokens[0] if isinstance(v, str) else tokens
-
-    def get_vocab(self) -> Dict[str, int]:
-        return self.vocab
-
-    def get_pad_token(self) -> str:
-        return self.pad_token
-
-    def get_unk_token(self) -> str:
-        return self.unk_token
-
-    def _add_special_token_ids(self, token_ids: List[int]) -> List[int]:
-        """Adds special token ids to the token_ids list."""
-        if torch.jit.isinstance(self.cls_token_id, int) and torch.jit.isinstance(self.sep_token_id, int):
-            token_ids.insert(0, self.cls_token_id)
-            token_ids.append(self.sep_token_id)
-        return token_ids
-
-    def convert_token_to_id(self, token: str) -> int:
-        return self.vocab[token]
+class HFTokenizerShortcutFactory:
+    """This factory can be used to build HuggingFace tokenizers from a shortcut string.
+
+    Those shortcuts were originally used for torchtext tokenizers. They also guarantee backward compatibility.
+    """
+
+    MODELS = {
+        "sentencepiece": "FacebookAI/xlm-roberta-base",
+        "clip": "openai/clip-vit-base-patch32",
+        "gpt2bpe": "openai-community/gpt2",
+        "bert": "bert-base-uncased",
+    }
+
+    @classmethod
+    def create_class(cls, model_name: str):
+        """Creates a tokenizer class from a model name."""
+
+        class DynamicHFTokenizer(torch.nn.Module):
+            def __init__(self, **kwargs):
+                super().__init__()
+                self.tokenizer = load_pretrained_hf_tokenizer(model_name, use_fast=False)
+
+            def forward(self, v: Union[str, List[str], torch.Tensor]):
+                if isinstance(v, torch.Tensor):
+                    raise ValueError(f"Unsupported input: {v}")
+                return self.tokenizer.tokenize(v)
+
+        return DynamicHFTokenizer
 
 
 tokenizer_registry.update(
-    {
-        "bert": BERTTokenizer,
-    }
+    {name: HFTokenizerShortcutFactory.create_class(model) for name, model in HFTokenizerShortcutFactory.MODELS.items()}
 )
-TORCHSCRIPT_COMPATIBLE_TOKENIZERS.update(TORCHTEXT_0_13_0_TOKENIZERS)
 
 
 def get_hf_tokenizer(pretrained_model_name_or_path, **kwargs):
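
A minimal usage sketch (illustration only, not part of the diff) of how a shortcut now resolves through HFTokenizerShortcutFactory. It assumes a Ludwig build containing this change plus transformers installed; the printed tokens are indicative only.

    from ludwig.utils.tokenizers import get_tokenizer_from_registry

    # "gpt2bpe" now maps to a DynamicHFTokenizer wrapping "openai-community/gpt2"
    # (see HFTokenizerShortcutFactory.MODELS above) instead of a torchtext BPE tokenizer.
    tokenizer_cls = get_tokenizer_from_registry("gpt2bpe")
    tokenizer = tokenizer_cls()  # loads the HF tokenizer with use_fast=False

    # forward() returns string tokens, keeping the Ludwig tokenizer API.
    print(tokenizer("Hello world"))  # e.g. ['Hello', 'Ġworld']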
@@ -1256,82 +1057,8 @@ def get_hf_tokenizer(pretrained_model_name_or_path, **kwargs):
 
     Returns:
         A torchscript-able HF tokenizer if it is available. Else, returns vanilla HF tokenizer.
     """
-    from transformers import BertTokenizer, DistilBertTokenizer, ElectraTokenizer
-
-    # HuggingFace has implemented a DO Repeat Yourself policy for models
-    # https://github.com/huggingface/transformers/issues/19303
-    # We now need to manually track BERT-like tokenizers to map onto the TorchText implementation
-    # until PyTorch improves TorchScript to be able to compile HF tokenizers. This would require
-    # 1. Support for string inputs for torch.jit.trace, or
-    # 2. Support for `kwargs` in torch.jit.script
-    # This is populated in the `get_hf_tokenizer` since the set requires `transformers` to be installed
-    HF_BERTLIKE_TOKENIZER_CLS_SET = {BertTokenizer, DistilBertTokenizer, ElectraTokenizer}
-
-    hf_name = pretrained_model_name_or_path
-    # use_fast=False to leverage python class inheritance
-    # cannot tokenize HF tokenizers directly because HF lacks strict typing and List[str] cannot be traced
-    hf_tokenizer = load_pretrained_hf_tokenizer(hf_name, use_fast=False)
-
-    torchtext_tokenizer = None
-    if "bert" in TORCHSCRIPT_COMPATIBLE_TOKENIZERS and any(
-        isinstance(hf_tokenizer, cls) for cls in HF_BERTLIKE_TOKENIZER_CLS_SET
-    ):
-        tokenizer_kwargs = _get_bert_config(hf_name)
-        torchtext_tokenizer = BERTTokenizer(
-            **tokenizer_kwargs,
-            is_hf_tokenizer=True,
-            hf_tokenizer_attrs={
-                "pad_token": hf_tokenizer.pad_token,
-                "unk_token": hf_tokenizer.unk_token,
-                "cls_token_id": hf_tokenizer.cls_token_id,
-                "sep_token_id": hf_tokenizer.sep_token_id,
-                "all_special_tokens": hf_tokenizer.all_special_tokens,
-            },
-        )
-
-    use_torchtext = torchtext_tokenizer is not None
-    if use_torchtext:
-        # If a torchtext tokenizer is instantiable, tentatively we will use it. However,
-        # if the tokenizer does not pass (lightweight) validation, then we will fall back to the vanilla HF tokenizer.
-        # TODO(geoffrey): can we better validate tokenizer parity before swapping in the TorchText tokenizer?
-        # Samples from https://github.com/huggingface/transformers/blob/main/tests/models/bert/test_tokenization_bert.py
-        for sample_input in HF_TOKENIZER_SAMPLE_INPUTS:
-            hf_output = hf_tokenizer.encode(sample_input)
-            tt_output = torchtext_tokenizer(sample_input)
-            if hf_output != tt_output:
-                use_torchtext = False
-                logger.warning("Falling back to HuggingFace tokenizer because TorchText tokenizer failed validation.")
-                logger.warning(f"Sample input: {sample_input}\nHF output: {hf_output}\nTT output: {tt_output}")
-                break
-
-    if use_torchtext:
-        logger.info(f"Loaded TorchText implementation of {hf_name} tokenizer")
-        return torchtext_tokenizer
-    else:
-        # If hf_name does not have a torchtext equivalent implementation, load the
-        # HuggingFace implementation.
-        logger.info(f"Loaded HuggingFace implementation of {hf_name} tokenizer")
-        return HFTokenizer(hf_name)
-
-
-def _get_bert_config(hf_name):
-    """Gets configs from BERT tokenizers in HuggingFace.
-
-    `vocab_file` is required for BERT tokenizers. `tokenizer_config.json` are optional keyword arguments used to
-    initialize the tokenizer object. If no `tokenizer_config.json` is found, then we instantiate the tokenizer with
-    default arguments.
-    """
-    from huggingface_hub import hf_hub_download
-    from huggingface_hub.utils import EntryNotFoundError
-
-    vocab_file = hf_hub_download(repo_id=hf_name, filename="vocab.txt")
-
-    try:
-        tokenizer_config = load_json(hf_hub_download(repo_id=hf_name, filename="tokenizer_config.json"))
-    except EntryNotFoundError:
-        tokenizer_config = {}
-
-    return {"vocab_file": vocab_file, **tokenizer_config}
+    return HFTokenizer(pretrained_model_name_or_path)
 
 
 tokenizer_registry.update(
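
For comparison, the simplified call path (illustration only, not part of the diff): get_hf_tokenizer no longer probes for a TorchText equivalent or validates parity on sample inputs; it unconditionally wraps the HuggingFace implementation.

    from ludwig.utils.tokenizers import get_hf_tokenizer

    # Always returns the HFTokenizer wrapper now; the torchtext fallback logic is gone.
    tokenizer = get_hf_tokenizer("bert-base-uncased")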
@@ -1349,24 +1076,5 @@ def get_tokenizer_from_registry(tokenizer_name: str) -> torch.nn.Module:
     """
     if tokenizer_name in tokenizer_registry:
         return tokenizer_registry[tokenizer_name]
-
-    if (
-        torch.torch_version.TorchVersion(torchtext.__version__) < (0, 12, 0)
-        and tokenizer_name in TORCHTEXT_0_12_0_TOKENIZERS
-    ):
-        raise KeyError(
-            f"torchtext>=0.12.0 is not installed, so '{tokenizer_name}' and the following tokenizers are not "
-            f"available: {TORCHTEXT_0_12_0_TOKENIZERS}"
-        )
-
-    if (
-        torch.torch_version.TorchVersion(torchtext.__version__) < (0, 13, 0)
-        and tokenizer_name in TORCHTEXT_0_13_0_TOKENIZERS
-    ):
-        raise KeyError(
-            f"torchtext>=0.13.0 is not installed, so '{tokenizer_name}' and the following tokenizers are not "
-            f"available: {TORCHTEXT_0_13_0_TOKENIZERS}"
-        )
-
     # Tokenizer does not exist or is unavailable.
     raise KeyError(f"Invalid tokenizer name: '{tokenizer_name}'. Available tokenizers: {tokenizer_registry.keys()}")
