Skip to content

Commit 1134756

Browse files
authored
add support for modular tokenizer batch encoding (#399)
1 parent 710fa74 commit 1134756

4 files changed

Lines changed: 116 additions & 50 deletions

File tree

fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py

Lines changed: 109 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111

1212
import omegaconf
1313
import tokenizers
14+
import torch
1415
import transformers
1516
from omegaconf import OmegaConf
1617
from tokenizers import Encoding, Tokenizer
1718
from torch import Tensor
19+
from transformers import BatchEncoding
1820

1921
from fuse.data.tokenizers.modular_tokenizer.special_tokens import special_wrap_input
2022

@@ -53,7 +55,7 @@ def list_to_tokenizer_string(lst: List[ModularTokenizerInput]) -> str:
5355
return out
5456

5557

56-
class ModularTokenizer(transformers.PreTrainedTokenizerFast):
58+
class ModularTokenizer(transformers.PreTrainedTokenizerBase):
5759
def __init__(
5860
self,
5961
tokenizers_info: Union[List, omegaconf.listconfig.ListConfig],
@@ -91,8 +93,7 @@ def __init__(
9193
If it is not set, new special tokens may be mapped to IDs higher that regular token IDs. If Defaults to None (i.e. no limit is set).
9294
seed: random generator seed - used for random truncation in random truncation mode (not the default mode)
9395
"""
94-
# ModularTokenizer inherits the interface of PreTrainedTokenizerBase, but not the underlying logic, therefore super.__init__() is not called
95-
96+
super().__init__()
9697
# If there is only one tokenizer, remapping it is not needed - if there's only one, we can just load its json using load_from_jsons.
9798
if isinstance(tokenizers_info, omegaconf.listconfig.ListConfig) or isinstance(
9899
tokenizers_info, omegaconf.dictconfig.DictConfig
@@ -1602,47 +1603,6 @@ def enable_truncation(
16021603
assert direction == "right", "direction setting not implemented"
16031604
self.max_len = max_length
16041605

1605-
def encode_batch(
1606-
self,
1607-
input: List,
1608-
is_pretokenized: Optional[bool] = False,
1609-
add_special_tokens: Optional[bool] = True,
1610-
) -> List:
1611-
"""
1612-
Encode the given batch of inputs. This method accept both raw text sequences
1613-
as well as already pre-tokenized sequences.
1614-
1615-
Example:
1616-
Here are some examples of the inputs that are accepted::
1617-
1618-
encode_batch([
1619-
"A single sequence",
1620-
("A tuple with a sequence", "And its pair"),
1621-
[ "A", "pre", "tokenized", "sequence" ],
1622-
([ "A", "pre", "tokenized", "sequence" ], "And its pair")
1623-
])
1624-
1625-
Args:
1626-
input (A :obj:`List`/:obj:`Tuple` of :obj:`~tokenizers.EncodeInput`):
1627-
A list of single sequences or pair sequences to encode. Each sequence
1628-
can be either raw text or pre-tokenized, according to the ``is_pretokenized``
1629-
argument:
1630-
1631-
- If ``is_pretokenized=False``: :class:`~tokenizers.TextEncodeInput`
1632-
- If ``is_pretokenized=True``: :class:`~tokenizers.PreTokenizedEncodeInput`
1633-
1634-
is_pretokenized (:obj:`bool`, defaults to :obj:`False`):
1635-
Whether the input is already pre-tokenized
1636-
1637-
add_special_tokens (:obj:`bool`, defaults to :obj:`True`):
1638-
Whether to add the special tokens
1639-
1640-
Returns:
1641-
A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch
1642-
1643-
"""
1644-
raise Exception("Not implemented")
1645-
16461606
@staticmethod
16471607
def from_buffer(buffer: object) -> object:
16481608
"""
@@ -2098,3 +2058,108 @@ def truncation(self) -> Optional[Dict]:
20982058
A dict with the current truncation parameters if truncation is enabled
20992059
"""
21002060
raise Exception("Not implemented")
2061+
2062+
def __repr__(self) -> str:
2063+
return (
2064+
f"{self.__class__.__name__},\n"
2065+
+ f"Subtokenizer = {[subtokenizer for subtokenizer in self.tokenizers_info.keys()]}" # " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)"
2066+
)
2067+
2068+
@property
2069+
def pad_token_id(self) -> Optional[int]:
2070+
"""
2071+
`Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
2072+
"""
2073+
if self._pad_token is None:
2074+
return None
2075+
return self.token_to_id(self.pad_token)
2076+
2077+
def convert_tokens_to_ids(
2078+
self, *args: list, **kwargs: dict
2079+
) -> Union[List[int], int]:
2080+
enc_ids = self.encode(*args, **kwargs).ids
2081+
if len(enc_ids) == 1:
2082+
return enc_ids[0]
2083+
return enc_ids
2084+
2085+
def convert_ids_to_tokens(
2086+
self, ids: Union[List[int], int], *args: list, **kwargs: dict
2087+
) -> str:
2088+
if isinstance(ids, int):
2089+
ids = [ids]
2090+
return self.decode(ids=ids, *args, **kwargs)
2091+
2092+
def __call__(
2093+
self,
2094+
text: Union[str, List[str], List[List[str]]],
2095+
# text_pair: Optional[Union[str, List[str], List[List[str]]]] = None,
2096+
# add_special_tokens: bool = True,
2097+
padding: Union[
2098+
bool, str, transformers.tokenization_utils_base.PaddingStrategy
2099+
] = False,
2100+
truncation: Union[
2101+
bool, str, transformers.tokenization_utils_base.TruncationStrategy
2102+
] = False,
2103+
max_length: Optional[int] = None,
2104+
stride: int = 0,
2105+
# is_pretokenized: bool = False,
2106+
# pad_to_multiple_of: Optional[int] = None,
2107+
return_tensors: Optional[
2108+
Union[str, transformers.tokenization_utils_base.TensorType]
2109+
] = None,
2110+
# return_token_type_ids: Optional[bool] = None,
2111+
return_attention_mask: Optional[bool] = None,
2112+
# return_overflowing_tokens: bool = False,
2113+
# return_special_tokens_mask: bool = False,
2114+
# return_offsets_mapping: bool = False,
2115+
# return_length: bool = False,
2116+
verbose: bool = True,
2117+
**kwargs: dict,
2118+
) -> BatchEncoding:
2119+
"""
2120+
Main method to tokenize and prepare for the model one or several sequence(s) sequences.
2121+
The method supports subeset of the arguments defined in PreTrainedTokenizerBase
2122+
"""
2123+
assert return_tensors in [None, "pt"], f"Error: unsupported {return_tensors=}"
2124+
2125+
is_single_input = isinstance(text, str)
2126+
text = [text] if is_single_input else text
2127+
2128+
encoding_list = []
2129+
for sequence in text:
2130+
encoding = self.encode(sequence, **kwargs)
2131+
if truncation and max_length is not None:
2132+
encoding.truncate(max_length=max_length, stride=stride)
2133+
encoding_list.append(encoding)
2134+
2135+
# Padding
2136+
if padding:
2137+
max_len = (
2138+
max(len(encoding.ids) for encoding in encoding_list)
2139+
if max_length is None
2140+
else max_length
2141+
)
2142+
for encoding in encoding_list:
2143+
encoding.pad(
2144+
length=max_len, pad_id=self.pad_token_id, pad_token=self.pad_token
2145+
)
2146+
2147+
result = {
2148+
"input_ids": [encoding.ids for encoding in encoding_list],
2149+
}
2150+
if return_attention_mask is None or return_attention_mask:
2151+
result["attention_mask"] = [
2152+
encoding.attention_mask for encoding in encoding_list
2153+
]
2154+
2155+
if return_tensors == "pt":
2156+
result = {k: torch.tensor(v, dtype=torch.long) for k, v in result.items()}
2157+
2158+
# Return single sample if input was a string
2159+
if is_single_input:
2160+
result = {
2161+
k: v[0] if isinstance(v, list) and not isinstance(v[0], list) else v[0]
2162+
for k, v in result.items()
2163+
}
2164+
2165+
return BatchEncoding(result, tensor_type=return_tensors, encoding=encoding_list)

fuse/dl/lightning/pl_funcs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,8 @@ def epoch_end_compute_and_log_losses(
274274
losses[key].append(value)
275275

276276
for key in losses:
277-
loss = mean(losses[key])
277+
values = [x for x in losses[key] if x is not None]
278+
loss = mean(values)
278279
pl_module.log(
279280
f"{mode}{sep}losses.{key}",
280281
loss,

fuse/dl/tests/test_cat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def test_all_contexts(self) -> None:
3838
"emb_dim": 128,
3939
"num_tokens_a": 10000,
4040
"num_tokens_b": 20000,
41-
"max_seq_len_a": 512,
42-
"max_seq_len_b": 1024,
41+
"max_seq_len_a": 256,
42+
"max_seq_len_b": 512,
4343
"output_dim": 256,
4444
"kwargs_wrapper_a": dict(emb_dropout=0.1),
4545
"kwargs_wrapper_b": dict(emb_dropout=0.1),

fuse/utils/multiprocessing/run_multiprocessed.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def _store_in_global_storage(store_me: dict) -> None:
209209
if store_me is None:
210210
return
211211

212-
global _multiprocess_global_storage
212+
global _multiprocess_global_storage # noqa
213213

214214
# making sure there are no name conflicts
215215
for key, _ in store_me.items():
@@ -229,7 +229,7 @@ def _remove_from_global_storage(remove_me: List) -> None:
229229
if remove_me is None:
230230
return
231231

232-
global _multiprocess_global_storage
232+
global _multiprocess_global_storage # noqa
233233
for key in remove_me:
234234
del _multiprocess_global_storage[key]
235235

@@ -238,7 +238,7 @@ def get_from_global_storage(key: str) -> Any:
238238
"""
239239
Get args copied by run_multiprocessed
240240
"""
241-
global _multiprocess_global_storage
241+
global _multiprocess_global_storage # noqa
242242
return _multiprocess_global_storage[key]
243243

244244

0 commit comments

Comments
 (0)