|
11 | 11 |
|
12 | 12 | import omegaconf |
13 | 13 | import tokenizers |
| 14 | +import torch |
14 | 15 | import transformers |
15 | 16 | from omegaconf import OmegaConf |
16 | 17 | from tokenizers import Encoding, Tokenizer |
17 | 18 | from torch import Tensor |
| 19 | +from transformers import BatchEncoding |
18 | 20 |
|
19 | 21 | from fuse.data.tokenizers.modular_tokenizer.special_tokens import special_wrap_input |
20 | 22 |
|
@@ -53,7 +55,7 @@ def list_to_tokenizer_string(lst: List[ModularTokenizerInput]) -> str: |
53 | 55 | return out |
54 | 56 |
|
55 | 57 |
|
56 | | -class ModularTokenizer(transformers.PreTrainedTokenizerFast): |
| 58 | +class ModularTokenizer(transformers.PreTrainedTokenizerBase): |
57 | 59 | def __init__( |
58 | 60 | self, |
59 | 61 | tokenizers_info: Union[List, omegaconf.listconfig.ListConfig], |
@@ -91,8 +93,7 @@ def __init__( |
91 | 93 | If it is not set, new special tokens may be mapped to IDs higher that regular token IDs. If Defaults to None (i.e. no limit is set). |
92 | 94 | seed: random generator seed - used for random truncation in random truncation mode (not the default mode) |
93 | 95 | """ |
94 | | - # ModularTokenizer inherits the interface of PreTrainedTokenizerBase, but not the underlying logic, therefore super.__init__() is not called |
95 | | - |
| 96 | + super().__init__() |
96 | 97 | # If there is only one tokenizer, remapping it is not needed - if there's only one, we can just load its json using load_from_jsons. |
97 | 98 | if isinstance(tokenizers_info, omegaconf.listconfig.ListConfig) or isinstance( |
98 | 99 | tokenizers_info, omegaconf.dictconfig.DictConfig |
@@ -1602,47 +1603,6 @@ def enable_truncation( |
1602 | 1603 | assert direction == "right", "direction setting not implemented" |
1603 | 1604 | self.max_len = max_length |
1604 | 1605 |
|
1605 | | - def encode_batch( |
1606 | | - self, |
1607 | | - input: List, |
1608 | | - is_pretokenized: Optional[bool] = False, |
1609 | | - add_special_tokens: Optional[bool] = True, |
1610 | | - ) -> List: |
1611 | | - """ |
1612 | | - Encode the given batch of inputs. This method accept both raw text sequences |
1613 | | - as well as already pre-tokenized sequences. |
1614 | | -
|
1615 | | - Example: |
1616 | | - Here are some examples of the inputs that are accepted:: |
1617 | | -
|
1618 | | - encode_batch([ |
1619 | | - "A single sequence", |
1620 | | - ("A tuple with a sequence", "And its pair"), |
1621 | | - [ "A", "pre", "tokenized", "sequence" ], |
1622 | | - ([ "A", "pre", "tokenized", "sequence" ], "And its pair") |
1623 | | - ]) |
1624 | | -
|
1625 | | - Args: |
1626 | | - input (A :obj:`List`/:obj:`Tuple` of :obj:`~tokenizers.EncodeInput`): |
1627 | | - A list of single sequences or pair sequences to encode. Each sequence |
1628 | | - can be either raw text or pre-tokenized, according to the ``is_pretokenized`` |
1629 | | - argument: |
1630 | | -
|
1631 | | - - If ``is_pretokenized=False``: :class:`~tokenizers.TextEncodeInput` |
1632 | | - - If ``is_pretokenized=True``: :class:`~tokenizers.PreTokenizedEncodeInput` |
1633 | | -
|
1634 | | - is_pretokenized (:obj:`bool`, defaults to :obj:`False`): |
1635 | | - Whether the input is already pre-tokenized |
1636 | | -
|
1637 | | - add_special_tokens (:obj:`bool`, defaults to :obj:`True`): |
1638 | | - Whether to add the special tokens |
1639 | | -
|
1640 | | - Returns: |
1641 | | - A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch |
1642 | | -
|
1643 | | - """ |
1644 | | - raise Exception("Not implemented") |
1645 | | - |
1646 | 1606 | @staticmethod |
1647 | 1607 | def from_buffer(buffer: object) -> object: |
1648 | 1608 | """ |
@@ -2098,3 +2058,108 @@ def truncation(self) -> Optional[Dict]: |
2098 | 2058 | A dict with the current truncation parameters if truncation is enabled |
2099 | 2059 | """ |
2100 | 2060 | raise Exception("Not implemented") |
| 2061 | + |
| 2062 | + def __repr__(self) -> str: |
| 2063 | + return ( |
| 2064 | + f"{self.__class__.__name__},\n" |
| 2065 | + + f"Subtokenizer = {[subtokenizer for subtokenizer in self.tokenizers_info.keys()]}" # " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)" |
| 2066 | + ) |
| 2067 | + |
| 2068 | + @property |
| 2069 | + def pad_token_id(self) -> Optional[int]: |
| 2070 | + """ |
| 2071 | + `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set. |
| 2072 | + """ |
| 2073 | + if self._pad_token is None: |
| 2074 | + return None |
| 2075 | + return self.token_to_id(self.pad_token) |
| 2076 | + |
| 2077 | + def convert_tokens_to_ids( |
| 2078 | + self, *args: list, **kwargs: dict |
| 2079 | + ) -> Union[List[int], int]: |
| 2080 | + enc_ids = self.encode(*args, **kwargs).ids |
| 2081 | + if len(enc_ids) == 1: |
| 2082 | + return enc_ids[0] |
| 2083 | + return enc_ids |
| 2084 | + |
| 2085 | + def convert_ids_to_tokens( |
| 2086 | + self, ids: Union[List[int], int], *args: list, **kwargs: dict |
| 2087 | + ) -> str: |
| 2088 | + if isinstance(ids, int): |
| 2089 | + ids = [ids] |
| 2090 | + return self.decode(ids=ids, *args, **kwargs) |
| 2091 | + |
| 2092 | + def __call__( |
| 2093 | + self, |
| 2094 | + text: Union[str, List[str], List[List[str]]], |
| 2095 | + # text_pair: Optional[Union[str, List[str], List[List[str]]]] = None, |
| 2096 | + # add_special_tokens: bool = True, |
| 2097 | + padding: Union[ |
| 2098 | + bool, str, transformers.tokenization_utils_base.PaddingStrategy |
| 2099 | + ] = False, |
| 2100 | + truncation: Union[ |
| 2101 | + bool, str, transformers.tokenization_utils_base.TruncationStrategy |
| 2102 | + ] = False, |
| 2103 | + max_length: Optional[int] = None, |
| 2104 | + stride: int = 0, |
| 2105 | + # is_pretokenized: bool = False, |
| 2106 | + # pad_to_multiple_of: Optional[int] = None, |
| 2107 | + return_tensors: Optional[ |
| 2108 | + Union[str, transformers.tokenization_utils_base.TensorType] |
| 2109 | + ] = None, |
| 2110 | + # return_token_type_ids: Optional[bool] = None, |
| 2111 | + return_attention_mask: Optional[bool] = None, |
| 2112 | + # return_overflowing_tokens: bool = False, |
| 2113 | + # return_special_tokens_mask: bool = False, |
| 2114 | + # return_offsets_mapping: bool = False, |
| 2115 | + # return_length: bool = False, |
| 2116 | + verbose: bool = True, |
| 2117 | + **kwargs: dict, |
| 2118 | + ) -> BatchEncoding: |
| 2119 | + """ |
| 2120 | + Main method to tokenize and prepare for the model one or several sequence(s) sequences. |
| 2121 | + The method supports subeset of the arguments defined in PreTrainedTokenizerBase |
| 2122 | + """ |
| 2123 | + assert return_tensors in [None, "pt"], f"Error: unsupported {return_tensors=}" |
| 2124 | + |
| 2125 | + is_single_input = isinstance(text, str) |
| 2126 | + text = [text] if is_single_input else text |
| 2127 | + |
| 2128 | + encoding_list = [] |
| 2129 | + for sequence in text: |
| 2130 | + encoding = self.encode(sequence, **kwargs) |
| 2131 | + if truncation and max_length is not None: |
| 2132 | + encoding.truncate(max_length=max_length, stride=stride) |
| 2133 | + encoding_list.append(encoding) |
| 2134 | + |
| 2135 | + # Padding |
| 2136 | + if padding: |
| 2137 | + max_len = ( |
| 2138 | + max(len(encoding.ids) for encoding in encoding_list) |
| 2139 | + if max_length is None |
| 2140 | + else max_length |
| 2141 | + ) |
| 2142 | + for encoding in encoding_list: |
| 2143 | + encoding.pad( |
| 2144 | + length=max_len, pad_id=self.pad_token_id, pad_token=self.pad_token |
| 2145 | + ) |
| 2146 | + |
| 2147 | + result = { |
| 2148 | + "input_ids": [encoding.ids for encoding in encoding_list], |
| 2149 | + } |
| 2150 | + if return_attention_mask is None or return_attention_mask: |
| 2151 | + result["attention_mask"] = [ |
| 2152 | + encoding.attention_mask for encoding in encoding_list |
| 2153 | + ] |
| 2154 | + |
| 2155 | + if return_tensors == "pt": |
| 2156 | + result = {k: torch.tensor(v, dtype=torch.long) for k, v in result.items()} |
| 2157 | + |
| 2158 | + # Return single sample if input was a string |
| 2159 | + if is_single_input: |
| 2160 | + result = { |
| 2161 | + k: v[0] if isinstance(v, list) and not isinstance(v[0], list) else v[0] |
| 2162 | + for k, v in result.items() |
| 2163 | + } |
| 2164 | + |
| 2165 | + return BatchEncoding(result, tensor_type=return_tensors, encoding=encoding_list) |
0 commit comments