Commit cab88be (parent: 1b904cd)

add chart2table (#3941)

15 files changed: +3355 additions, -31 deletions

.precommit/check_imports.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@
     "shapely": "shapely",
     "soundfile": "soundfile",
     "starlette": "starlette",
+    "tiktoken": "tiktoken",
     "tokenizers": "tokenizers",
     "tqdm": "tqdm",
     "typing_extensions": "typing-extensions",

paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py

Lines changed: 33 additions & 10 deletions
@@ -18,14 +18,26 @@
 
 
 class DocVLMBatchSampler(BaseBatchSampler):
-    def __init__(self):
+
+    model_names_only_supports_batchsize_of_one = {"PP-DocBee-2B", "PP-DocBee-7B"}
+
+    def __init__(self, model_name, batch_size: int = 1) -> None:
         """Initializes the BaseBatchSampler.
 
         Args:
+            model_name (str): The name of the model.
             batch_size (int, optional): The size of each batch. Only support 1.
         """
-        super().__init__()
-        self.batch_size = 1
+        self.model_name = model_name
+        if (
+            self.model_name in self.model_names_only_supports_batchsize_of_one
+            and batch_size != 1
+        ):
+            logging.warning(
+                f"doc vlm batch sampler only support batch size 1 for {self.model_name}, but got {batch_size} and it will not take effect."
+            )
+            batch_size = 1
+        super().__init__(batch_size)
 
     def sample(self, inputs):
         """Generate list of input file path.

@@ -37,14 +49,22 @@ def sample(self, inputs):
             list: list of file path.
         """
         if isinstance(inputs, dict):
-            yield [inputs]
-        elif isinstance(inputs, list) and all(isinstance(i, dict) for i in inputs):
-            yield inputs
-        else:
+            inputs = [inputs]
+        if not (isinstance(inputs, list) and all(isinstance(i, dict) for i in inputs)):
             raise TypeError(
-                f"Not supported input data type! Only `dict` are supported, but got: {type(inputs)}."
+                f"Not supported input data type! Only `Dict` or `List[Dict]` are supported, but got: {type(inputs)}."
             )
 
+        batch = []
+        for input_ in inputs:
+            batch.append(input_)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+
+        if len(batch) > 0:
+            yield batch
+
     @BaseBatchSampler.batch_size.setter
     def batch_size(self, batch_size):
         """Sets the batch size.

@@ -56,9 +76,12 @@ def batch_size(self, batch_size):
             Warning: If the batch size is not equal 1.
         """
         # only support batch size 1
-        if batch_size != 1:
+        if (
+            self.model_name in self.model_names_only_supports_batchsize_of_one
+            and batch_size != 1
+        ):
             logging.warning(
-                f"doc vlm batch sampler only support batch size 1, but got {batch_size}."
+                f"doc vlm batch sampler only support batch size 1 for {self.model_name}, but got {batch_size} and it will not take effect."
             )
         else:
             self._batch_size = batch_size
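
With this change, models outside the pinned set get real batching, while PP-DocBee-2B/7B stay clamped to batch size 1. A short usage sketch (the model name and input dicts below are illustrative; the sampler behavior follows the diff above):

# Sketch: dict inputs are normalized to a list, then chunked into
# batches of self.batch_size; a trailing partial batch is still yielded.
from paddlex.inference.common.batch_sampler.doc_vlm_batch_sampler import (
    DocVLMBatchSampler,
)

sampler = DocVLMBatchSampler(model_name="PP-Chart2Table", batch_size=2)
inputs = [{"image": f"chart_{i}.png"} for i in range(5)]  # hypothetical payloads

for batch in sampler.sample(inputs):
    print(len(batch))  # -> 2, 2, 1

# For the pinned models, any other batch size is ignored with a warning:
docbee = DocVLMBatchSampler(model_name="PP-DocBee-2B", batch_size=8)
print(docbee.batch_size)  # -> 1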

paddlex/inference/models/common/tokenizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,4 +16,5 @@
 from .clip_tokenizer import CLIPTokenizer
 from .gpt_tokenizer import GPTTokenizer
 from .qwen2_tokenizer import MIXQwen2Tokenizer, Qwen2Tokenizer
+from .qwen_tokenizer import QWenTokenizer
 from .tokenizer_utils import PretrainedTokenizer
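
With the re-export in place, QWenTokenizer is importable alongside the other tokenizers. A quick sanity check (a sketch; it only touches a class attribute, so no vocab file or tiktoken install is needed):

from paddlex.inference.models.common.tokenizer import QWenTokenizer

# resource_files_names is defined in qwen_tokenizer.py (next file)
print(QWenTokenizer.resource_files_names)  # {'vocab_file': 'qwen.tiktoken'}
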
paddlex/inference/models/common/tokenizer/qwen_tokenizer.py

Lines changed: 288 additions & 0 deletions

@@ -0,0 +1,288 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import importlib.util
import os
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union

from .tokenizer_utils import PretrainedTokenizer
from .tokenizer_utils_base import AddedToken

__all__ = ["QWenTokenizer"]


VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}

PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
SPECIAL_TOKENS = (
    ENDOFTEXT,
    IMSTART,
    IMEND,
) + EXTRAS

tiktoken = None


def is_tiktoken_available():
    return importlib.util.find_spec("tiktoken") is not None


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
    with open(tiktoken_bpe_file, "rb") as f:
        contents = f.read()
    return {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }


class QWenTokenizer(PretrainedTokenizer):
    """QWen tokenizer."""

    model_input_names = ["input_ids", "attention_mask", "position_ids"]
    resource_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        padding_side="left",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not is_tiktoken_available():
            raise ValueError(
                "tiktoken is not installed, please install it with: pip install tiktoken"
            )

        import tiktoken as tk

        tiktoken = tk

        self.errors = errors  # how to handle errors in decoding

        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
        self.special_tokens = {
            token: index
            for index, token in enumerate(
                SPECIAL_TOKENS, start=len(self.mergeable_ranks)
            )
        }

        enc = tiktoken.Encoding(
            "Qwen",
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        assert (
            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
        ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

        self.decoder = {
            v: k for k, v in self.mergeable_ranks.items()
        }  # type: dict[int, bytes|str]
        self.decoder.update({v: k for k, v in self.special_tokens.items()})

        self.tokenizer = enc  # type: tiktoken.Encoding

        self.eod_id = self.tokenizer.eot_token
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]

        if "pad_token_id" in kwargs:
            self.pad_token_id = kwargs["pad_token_id"]
        if "eos_token_id" in kwargs:
            self.eos_token_id = kwargs["eos_token_id"]

    def __len__(self) -> int:
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        return self.mergeable_ranks

    def convert_tokens_to_ids(
        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
    ) -> List[int]:
        ids = []
        if isinstance(tokens, (str, bytes)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.mergeable_ranks.get(tokens)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.mergeable_ranks.get(token))
        return ids

    def _update_tiktoken(self, tokens: List[str], special_tokens: bool = False) -> int:
        if special_tokens:
            added_tokens = []
            for token in tokens:
                if token in self.special_tokens:
                    continue

                token_id = len(self.mergeable_ranks) + len(self.special_tokens)
                self.special_tokens[token] = token_id
                self.decoder[token_id] = token

                added_tokens.append(token)

            import tiktoken

            self.tokenizer = tiktoken.Encoding(
                "Qwen",
                pat_str=PAT_STR,
                mergeable_ranks=self.mergeable_ranks,
                special_tokens=self.special_tokens,
            )

            return len(added_tokens)
        else:
            raise ValueError("Adding regular tokens is not supported")

    def _add_tokens(
        self,
        new_tokens: Union[List[str], List[AddedToken]],
        special_tokens: bool = False,
    ) -> int:
        if not special_tokens and new_tokens:
            raise ValueError("Adding regular tokens is not supported")
        new_tokens_str = []
        for token in new_tokens:
            surface_form = token.content if isinstance(token, AddedToken) else token
            new_tokens_str.append(surface_form)

        return self._update_tiktoken(new_tokens_str, special_tokens)

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """
        Save only the vocabulary of the tokenizer.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        file_path = os.path.join(save_directory, "qwen.tiktoken")
        with open(file_path, "w", encoding="utf8") as w:
            for k, v in self.mergeable_ranks.items():
                line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
                w.write(line)
        return (file_path,)

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """
        Converts a string into a sequence of tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            allowed_special (`Literal["all"]` or `set`):
                The surface forms of the tokens to be encoded as special tokens in regular texts.
                Defaults to "all".
            disallowed_special (`Literal["all"]` or `Collection`):
                The surface forms of the tokens that should not be in regular texts and trigger errors.
                Defaults to an empty tuple.

            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific encode method.

        Returns:
            `List[bytes|str]`: The list of tokens.
        """
        tokens = []
        text = unicodedata.normalize("NFC", text)

        # this implementation takes a detour: text -> token id -> token surface forms
        for t in self.tokenizer.encode(
            text, allowed_special=allowed_special, disallowed_special=disallowed_special
        ):
            tokens.append(self.decoder[t])
        return tokens

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """
        Converts a sequence of tokens into a single string.
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors=self.errors)
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if temp:
            text += temp.decode("utf-8", errors=self.errors)
        return text

    @property
    def vocab_size(self):
        return self.tokenizer.n_vocab

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Converts an id to a token, special tokens included"""
        if index in self.decoder:
            return self.decoder[index]
        raise ValueError("unknown ids")

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Converts a token to an id using the vocab, special tokens included"""
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError("unknown token")

    def _tokenize(self, text: str, **kwargs):
        """
        Converts a string into a sequence of tokens (string), using the tokenizer. Split into words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePiece/WordPieces).

        Does NOT take care of added tokens.
        """
        raise NotImplementedError

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: str = None,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            token_ids = [i for i in token_ids if i < self.eod_id]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)