@@ -65,12 +65,13 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       std::vector<int>* output_ids,
                                       std::vector<int>* output_start_offsets,
                                       std::vector<int>* output_end_offsets,
-                                      int input_word_offset_in_text) const {
+                                      int input_word_offset_in_text,
+                                      bool* error) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                      /*kGetOffsets=*/true>(input, output_pieces, output_ids,
                                            output_start_offsets,
-                                           output_end_offsets);
+                                           output_end_offsets, error);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -86,9 +87,9 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/true>(input, /*output_pieces=*/nullptr,
-                                           output_ids, output_start_offsets,
-                                           output_end_offsets);
+                     /*kGetOffsets=*/true>(
+        input, /*output_pieces=*/nullptr, output_ids, output_start_offsets,
+        output_end_offsets, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -102,10 +103,10 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/false>(input, /*output_pieces=*/nullptr,
-                                            output_ids,
-                                            /*output_start_offsets=*/nullptr,
-                                            /*output_end_offsets=*/nullptr);
+                     /*kGetOffsets=*/false>(
+        input, /*output_pieces=*/nullptr, output_ids,
+        /*output_start_offsets=*/nullptr,
+        /*output_end_offsets=*/nullptr, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/false>(
@@ -186,20 +187,28 @@ template <bool kGetPieces, bool kGetIds, bool kGetOffsets>
 void FastWordpieceTokenizer::TokenizeTextImpl(
     absl::string_view input_text, std::vector<std::string>* output_pieces,
     std::vector<int>* output_ids, std::vector<int>* output_start_offsets,
-    std::vector<int>* output_end_offsets) const {
+    std::vector<int>* output_end_offsets, bool* error) const {
   static_assert(kGetPieces || kGetIds,
                 "At least one of `kGetPieces` and `kGetIds` should be true.");
   if (input_text.empty()) {
     return;
   }
   const int input_size = input_text.size();
+  int prev_pos = -1;
   int next_pos = 0;
   int cur_pos = 0;
   int original_num_tokens =
       GetCurrentOutputSize<kGetPieces>(output_pieces, output_ids);
   UChar32 prev_unicode_char;
   UChar32 cur_unicode_char;
   while (cur_pos < input_size) {
+    // Prevent looping without progress in cur_pos.
+    if (prev_pos == cur_pos && error != nullptr) {
+      *error = true;
+      return;
+    }
+    prev_pos = cur_pos;
+
     int cur_offset_in_input_word = 0;
     // Tokenize the word starting at the current position.
     auto cur_node = trie_->CreateTraversalCursorPointToRoot();
@@ -210,7 +219,15 @@ void FastWordpieceTokenizer::TokenizeTextImpl(
     // 1. it steps over the input boundary, or
     // 2. the length of the current word reaches 'max_bytes_per_token', or
     // 3. it sees a whitespace / punctuation / unknown character.
+    int prev_pos_inner = -1;
     while (cur_pos < input_size) {
+      // Prevent looping without progress in cur_pos.
+      if (prev_pos_inner == cur_pos && error != nullptr) {
+        *error = true;
+        return;
+      }
+      prev_pos_inner = cur_pos;
+
       prev_unicode_char = cur_unicode_char;
       next_pos = cur_pos;
       U8_NEXT(input_text, next_pos, input_text.length(), cur_unicode_char);
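
Note on the new out-parameter: `error` is optional (the id-only overloads above pass `/*error=*/nullptr`) and is only set on the end-to-end `TokenizeTextImpl` path, when `cur_pos` stops advancing. A minimal caller-side sketch of checking the flag; only the `Tokenize` signature comes from this diff, and the tokenizer construction is assumed, not shown here:

// Hedged sketch: assumes `tokenizer` is an already-constructed
// FastWordpieceTokenizer with an end-to-end config (construction is
// outside the scope of this diff).
std::vector<std::string> pieces;
std::vector<int> ids, starts, ends;
bool error = false;
tokenizer.Tokenize("some input text", &pieces, &ids, &starts, &ends,
                   /*input_word_offset_in_text=*/0, &error);
if (error) {
  // The tokenizer detected that cur_pos stopped advancing (e.g. a
  // malformed trie/config); treat the output vectors as incomplete
  // rather than as a valid tokenization.
}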