
Commit 534859c

Protect the fast wordpiece tokenizer from infinite looping.
PiperOrigin-RevId: 705292812
1 parent 6365dba commit 534859c

3 files changed: +39, -13 lines

tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc

Lines changed: 27 additions & 10 deletions
@@ -65,12 +65,13 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       std::vector<int>* output_ids,
                                       std::vector<int>* output_start_offsets,
                                       std::vector<int>* output_end_offsets,
-                                      int input_word_offset_in_text) const {
+                                      int input_word_offset_in_text,
+                                      bool* error) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                      /*kGetOffsets=*/true>(input, output_pieces, output_ids,
                                            output_start_offsets,
-                                           output_end_offsets);
+                                           output_end_offsets, error);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -86,9 +87,9 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/true>(input, /*output_pieces=*/nullptr,
-                                           output_ids, output_start_offsets,
-                                           output_end_offsets);
+                     /*kGetOffsets=*/true>(
+        input, /*output_pieces=*/nullptr, output_ids, output_start_offsets,
+        output_end_offsets, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -102,10 +103,10 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/false>(input, /*output_pieces=*/nullptr,
-                                            output_ids,
-                                            /*output_start_offsets=*/nullptr,
-                                            /*output_end_offsets=*/nullptr);
+                     /*kGetOffsets=*/false>(
+        input, /*output_pieces=*/nullptr, output_ids,
+        /*output_start_offsets=*/nullptr,
+        /*output_end_offsets=*/nullptr, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/false>(
@@ -186,20 +187,28 @@ template <bool kGetPieces, bool kGetIds, bool kGetOffsets>
 void FastWordpieceTokenizer::TokenizeTextImpl(
     absl::string_view input_text, std::vector<std::string>* output_pieces,
     std::vector<int>* output_ids, std::vector<int>* output_start_offsets,
-    std::vector<int>* output_end_offsets) const {
+    std::vector<int>* output_end_offsets, bool* error) const {
   static_assert(kGetPieces || kGetIds,
                 "At least one of `kGetPieces` and `kGetIds` should be true.");
   if (input_text.empty()) {
     return;
   }
   const int input_size = input_text.size();
+  int prev_pos = -1;
   int next_pos = 0;
   int cur_pos = 0;
   int original_num_tokens =
       GetCurrentOutputSize<kGetPieces>(output_pieces, output_ids);
   UChar32 prev_unicode_char;
   UChar32 cur_unicode_char;
   while (cur_pos < input_size) {
+    // Prevent looping without progress in cur_pos.
+    if (prev_pos == cur_pos && error != nullptr) {
+      *error = true;
+      return;
+    }
+    prev_pos = cur_pos;
+
     int cur_offset_in_input_word = 0;
     // Tokenize the word starting at the current position.
     auto cur_node = trie_->CreateTraversalCursorPointToRoot();
@@ -210,7 +219,15 @@ void FastWordpieceTokenizer::TokenizeTextImpl(
     // 1. it steps over the input boundary, or
     // 2. the length of the current word reaches 'max_bytes_per_token', or
     // 3. it sees a whitespace / punctuation / unknown character.
+    int prev_pos_inner = -1;
     while (cur_pos < input_size) {
+      // Prevent looping without progress in cur_pos.
+      if (prev_pos_inner == cur_pos && error != nullptr) {
+        *error = true;
+        return;
+      }
+      prev_pos_inner = cur_pos;
+
       prev_unicode_char = cur_unicode_char;
       next_pos = cur_pos;
       U8_NEXT(input_text, next_pos, input_text.length(), cur_unicode_char);
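
The core of the fix is the pair of progress guards added to TokenizeTextImpl above: at the top of both the outer (per-word) and inner (per-character) loops, the current position is compared against the position from the previous iteration, and if it has not advanced, the tokenizer sets *error and returns instead of spinning forever. Below is a minimal standalone sketch of that pattern; the TokenizeGuarded name and the DecodeStep callback are hypothetical stand-ins for the real trie-walking logic, and only the prev_pos/cur_pos check mirrors the commit.

#include <functional>
#include <string_view>
#include <vector>

// One decoding step: consumes input starting at `pos`, appends any tokens to
// `out`, and returns the new position. A buggy step or unexpected input could
// leave the position unchanged, which previously caused an infinite loop.
using DecodeStep =
    std::function<int(std::string_view text, int pos, std::vector<int>* out)>;

// Tokenizes `text` one step at a time; if a step fails to advance the
// position and `error` is non-null, sets *error and returns instead of
// looping forever.
void TokenizeGuarded(std::string_view text, const DecodeStep& step,
                     std::vector<int>* out, bool* error) {
  const int input_size = static_cast<int>(text.size());
  int prev_pos = -1;
  int cur_pos = 0;
  while (cur_pos < input_size) {
    // Prevent looping without progress in cur_pos (same check as the commit).
    if (prev_pos == cur_pos && error != nullptr) {
      *error = true;
      return;
    }
    prev_pos = cur_pos;
    cur_pos = step(text, cur_pos, out);
  }
}

Note that, as written in the diff, the guard only fires when the caller supplies a non-null error pointer; the overloads above that pass /*error=*/nullptr keep their previous behavior.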

tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h

Lines changed: 5 additions & 2 deletions
@@ -72,13 +72,15 @@ class FastWordpieceTokenizer {
   //     text, in utf-8 bytes.
   //  * input_word_offset_in_text: The relative offset of the input word in
   //     the whole text. Only used when not using end-to-end tokenizer.
+  //  * error: If not null, this will be set to true if the tokenizer failed
+  //     to make progress in decoding the input.
   // Note: the start offsets are inclusive and the end offsets are exclusive.
   void Tokenize(absl::string_view input,
                 std::vector<std::string>* output_pieces,
                 std::vector<int>* output_ids,
                 std::vector<int>* output_start_offsets,
                 std::vector<int>* output_end_offsets,
-                int input_word_offset_in_text = 0) const;
+                int input_word_offset_in_text = 0, bool* error = nullptr) const;
 
   // An override not returning `output_pieces`.
   void Tokenize(absl::string_view input, std::vector<int>* output_ids,
@@ -125,7 +127,8 @@ class FastWordpieceTokenizer {
                         std::vector<std::string>* output_pieces,
                         std::vector<int>* output_ids,
                         std::vector<int>* output_start_offsets,
-                        std::vector<int>* output_end_offsets) const;
+                        std::vector<int>* output_end_offsets,
+                        bool* error) const;
 
   // Try following the failure link to make the transition when trie matching
   // fails.
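
Because the new error out-parameter is optional and defaults to nullptr, existing callers of Tokenize compile unchanged. A caller that wants the new protection might look like the sketch below; the wrapper function and the namespace are assumptions for illustration, while the Tokenize signature is the one declared above.

#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h"

// Hypothetical wrapper: tokenizes `text` and reports whether the tokenizer
// managed to make progress. The tokenizer instance is assumed to be built
// elsewhere from a valid config.
bool TokenizeOrReportError(
    const tensorflow::text::FastWordpieceTokenizer& tokenizer,
    absl::string_view text, std::vector<std::string>* pieces,
    std::vector<int>* ids, std::vector<int>* starts, std::vector<int>* ends) {
  bool error = false;
  tokenizer.Tokenize(text, pieces, ids, starts, ends,
                     /*input_word_offset_in_text=*/0, &error);
  return !error;  // false: the tokenizer could not decode the input.
}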

tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h

Lines changed: 7 additions & 1 deletion
@@ -163,8 +163,14 @@ absl::Status FastWordpieceTokenizeWithOffsetsOp<Rt>::Invoke(
   for (int i = 0; i < values_vec.Dim(0); ++i) {
     // Tokenize into subwords and record the offset locations.
     const int original_num_wordpieces = subwords.size();
+    bool error = false;
     fast_wordpiece_tokenizer->Tokenize(values_vec(i), &subwords, &subword_ids,
-                                       &begin_offset, &end_offset);
+                                       &begin_offset, &end_offset,
+                                       /*input_word_offset_in_text=*/0, &error);
+    if (error) {
+      return absl::InternalError(
+          "Failed to make any progress in tokenizing the input text.");
+    }
     const int delta_num_wordpieces = subwords.size() - original_num_wordpieces;
 
     // Record the row splits.
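
The op kernel is where the low-level bool flag becomes a user-visible failure: the flag is checked right after the Tokenize call and converted into an absl::InternalError, so the op fails cleanly instead of hanging. If the same conversion were needed at more than one call site, it could be factored into a small helper along the lines of the sketch below; the helper name is hypothetical, while the error message matches the one in the diff.

#include "absl/status/status.h"

// Hypothetical helper mirroring the op-level handling above: converts the
// tokenizer's error out-parameter into a Status that the op machinery can
// propagate back to the caller.
inline absl::Status StatusFromTokenizeError(bool error) {
  if (error) {
    return absl::InternalError(
        "Failed to make any progress in tokenizing the input text.");
  }
  return absl::OkStatus();
}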
