Skip to content

Commit 96c4ee9

Browse files
Harden FastWordpieceTokenizer, FastBertNormalizer, and PhraseTokenizer with
bounds verification and initialization checks: * Prevent null pointer dereferences by validating required FlatBuffer fields. * Prevent OOB reads during detokenization by checking token IDs against vocab size. * Secure FastBertNormalizer against adversarial heap slicing by tracking string pool limits. PiperOrigin-RevId: 912596984
1 parent 01fb107 commit 96c4ee9

5 files changed

Lines changed: 77 additions & 2 deletions

File tree

tensorflow_text/core/kernels/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ tf_cc_library(
211211
":darts_clone_trie_builder",
212212
":darts_clone_trie_wrapper",
213213
":fast_bert_normalizer_model",
214+
"@com_google_absl//absl/base:core_headers",
214215
"@com_google_absl//absl/memory",
215216
"@com_google_absl//absl/status",
216217
"@com_google_absl//absl/strings",
@@ -1254,9 +1255,11 @@ cc_library(
12541255
":string_vocab",
12551256
":whitespace_tokenizer",
12561257
":whitespace_tokenizer_config_builder",
1258+
"@com_google_absl//absl/base:core_headers",
12571259
"@com_google_absl//absl/container:flat_hash_map",
12581260
"@com_google_absl//absl/container:flat_hash_set",
12591261
"@com_google_absl//absl/random",
1262+
"@com_google_absl//absl/status",
12601263
"@com_google_absl//absl/status:statusor",
12611264
"@com_google_absl//absl/strings",
12621265
# lite/kernels/shim:status_macros tensorflow dep,

tensorflow_text/core/kernels/fast_bert_normalizer.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_
1616
#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_
1717

18+
#include <cstddef>
1819
#include <cstdint>
1920
#include <vector>
2021

22+
#include "absl/base/optimization.h"
23+
#include "absl/status/status.h"
2124
#include "absl/strings/string_view.h"
2225
#include "icu4c/source/common/unicode/utf8.h"
2326
#include "tensorflow/lite/kernels/shim/status_macros.h"
@@ -83,7 +86,12 @@ class FastBertNormalizer {
8386
// lifetime of the instance.
8487
static absl::StatusOr<FastBertNormalizer> Create(
8588
const uint32_t* trie_data, int data_for_codepoint_zero,
86-
const char* normalized_string_pool) {
89+
const char* normalized_string_pool,
90+
size_t normalized_string_pool_size = static_cast<size_t>(-1)) {
91+
if (trie_data == nullptr || normalized_string_pool == nullptr) {
92+
return absl::InvalidArgumentError(
93+
"trie_data or normalized_string_pool is null");
94+
}
8795
FastBertNormalizer result;
8896
SH_ASSIGN_OR_RETURN(auto trie,
8997
trie_utils::DartsCloneTrieWrapper::Create(trie_data));
@@ -92,6 +100,7 @@ class FastBertNormalizer {
92100
result.data_for_codepoint_zero_ = data_for_codepoint_zero;
93101
result.normalized_string_pool_ =
94102
reinterpret_cast<const char*>(normalized_string_pool);
103+
result.normalized_string_pool_size_ = normalized_string_pool_size;
95104
return result;
96105
}
97106

@@ -103,11 +112,20 @@ class FastBertNormalizer {
103112
// through the lifetime of the instance.
104113
static absl::StatusOr<FastBertNormalizer> Create(
105114
const void* model_flatbuffer) {
115+
if (model_flatbuffer == nullptr) {
116+
return absl::InvalidArgumentError("model_flatbuffer is null");
117+
}
106118
// `GetFastBertNormalizerModel()` is autogenerated by flatbuffer.
107119
auto model = GetFastBertNormalizerModel(model_flatbuffer);
120+
if (model == nullptr || model->trie_array() == nullptr ||
121+
model->normalized_string_pool() == nullptr) {
122+
return absl::InvalidArgumentError(
123+
"FastBertNormalizerModel or its required fields are null");
124+
}
108125
return Create(
109126
model->trie_array()->data(), model->data_for_codepoint_zero(),
110-
reinterpret_cast<const char*>(model->normalized_string_pool()->data()));
127+
reinterpret_cast<const char*>(model->normalized_string_pool()->data()),
128+
model->normalized_string_pool()->size());
111129
}
112130

113131
// Normalizes the input based on config `lower_case_nfd_strip_accents`.
@@ -290,6 +308,12 @@ class FastBertNormalizer {
290308
}
291309
const int offset = (data & text_norm::kNormalizedStringOffsetMask) >>
292310
text_norm::kBitsToEncodeUtf8LengthOfNormalizedString;
311+
if (ABSL_PREDICT_FALSE(
312+
offset < 0 ||
313+
(normalized_string_pool_size_ != static_cast<size_t>(-1) &&
314+
offset + len > normalized_string_pool_size_))) {
315+
return "";
316+
}
293317
return absl::string_view(normalized_string_pool_ + offset, len);
294318
}
295319

@@ -331,6 +355,9 @@ class FastBertNormalizer {
331355
// The string pool of normalized strings. Each normalized string is a
332356
// substring denoted by (offset and length).
333357
const char* normalized_string_pool_;
358+
359+
// The size of normalized_string_pool_ if known, or -1.
360+
size_t normalized_string_pool_size_ = static_cast<size_t>(-1);
334361
};
335362

336363
} // namespace text

tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
#include <memory>
1818

1919
#include "absl/base/attributes.h"
20+
#include "absl/base/optimization.h"
2021
#include "absl/status/status.h"
2122
#include "absl/status/statusor.h"
2223
#include "absl/strings/match.h"
24+
#include "absl/strings/str_cat.h"
2325
#include "absl/strings/str_join.h"
2426
#include "absl/strings/string_view.h"
2527
#include "icu4c/source/common/unicode/uchar.h"
@@ -48,6 +50,11 @@ FastWordpieceTokenizer::Create(const void* config_flatbuffer) {
4850
FastWordpieceTokenizer tokenizer;
4951
// `GetFastWordpieceTokenizerConfig()` is autogenerated by flatbuffer.
5052
tokenizer.config_ = GetFastWordpieceTokenizerConfig(config_flatbuffer);
53+
if (tokenizer.config_ == nullptr ||
54+
tokenizer.config_->trie_array() == nullptr) {
55+
return absl::InvalidArgumentError(
56+
"FastWordpieceTokenizerConfig or its trie_array is null.");
57+
}
5158
auto trie_or = trie_utils::DartsCloneTrieWrapper::Create(
5259
tokenizer.config_->trie_array()->data());
5360
if (!trie_or.ok()) {
@@ -127,8 +134,23 @@ FastWordpieceTokenizer::DetokenizeToTokens(
127134
"true in the config flatbuffer. Please rebuild the model flatbuffer "
128135
"by setting support_detokenization=true.");
129136
}
137+
if (config_->vocab_array() == nullptr ||
138+
config_->vocab_is_suffix_array() == nullptr) {
139+
return absl::InternalError(
140+
"Missing vocab_array or vocab_is_suffix_array in config.");
141+
}
142+
const int vocab_size = config_->vocab_array()->size();
143+
const int is_suffix_size = config_->vocab_is_suffix_array()->size();
130144
for (int id : input) {
145+
if (ABSL_PREDICT_FALSE(id < 0 || id >= vocab_size ||
146+
id >= is_suffix_size)) {
147+
return absl::OutOfRangeError(
148+
absl::StrCat("Token ID out of bounds: ", id));
149+
}
131150
auto vocab = config_->vocab_array()->Get(id);
151+
if (ABSL_PREDICT_FALSE(vocab == nullptr)) {
152+
return absl::InternalError("Null vocab string in vocab_array.");
153+
}
132154
auto is_suffix = config_->vocab_is_suffix_array()->Get(id);
133155
if (!subwords.empty() && !is_suffix) {
134156
// When current subword is not a suffix token, it marks the start of a new
@@ -140,6 +162,9 @@ FastWordpieceTokenizer::DetokenizeToTokens(
140162
// Special case: when a suffix token e.g. "##a" appears at the start of the
141163
// input ids, we preserve the suffix_indicator.
142164
if (subwords.empty() && is_suffix) {
165+
if (ABSL_PREDICT_FALSE(config_->suffix_indicator() == nullptr)) {
166+
return absl::InternalError("Missing suffix_indicator in config.");
167+
}
143168
subwords.emplace_back(config_->suffix_indicator()->string_view());
144169
}
145170
subwords.emplace_back(vocab->string_view());

tensorflow_text/core/kernels/phrase_tokenizer.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
#include <string>
2121
#include <vector>
2222

23+
#include "absl/base/optimization.h"
24+
#include "absl/status/status.h"
2325
#include "absl/strings/match.h"
26+
#include "absl/strings/str_cat.h"
2427
#include "absl/strings/str_join.h"
2528
#include "absl/strings/string_view.h"
2629
#include "tensorflow/lite/kernels/shim/status_macros.h"
@@ -34,6 +37,12 @@ namespace text {
3437
PhraseTokenizer tokenizer;
3538
// `GetPhraseTokenizerConfig()` is autogenerated by flatbuffer.
3639
tokenizer.phrase_config_ = GetPhraseTokenizerConfig(config_flatbuffer);
40+
if (tokenizer.phrase_config_ == nullptr ||
41+
tokenizer.phrase_config_->vocab_trie() == nullptr ||
42+
tokenizer.phrase_config_->whitespace_config() == nullptr) {
43+
return absl::InvalidArgumentError(
44+
"PhraseTokenizerConfig or required fields are null.");
45+
}
3746
tokenizer.trie_ = absl::make_unique<sentencepiece::DoubleArrayTrie>(
3847
tokenizer.phrase_config_->vocab_trie()->nodes());
3948
tokenizer.prob_ = static_cast<float>(tokenizer.phrase_config_->prob()) / 100;
@@ -174,8 +183,19 @@ absl::StatusOr<std::vector<std::string>> PhraseTokenizer::DetokenizeToTokens(
174183
"true in the config flatbuffer. Please rebuild the model flatbuffer "
175184
"by setting support_detokenization=true.");
176185
}
186+
if (phrase_config_->vocab_array() == nullptr) {
187+
return absl::InternalError("Missing vocab_array in config.");
188+
}
189+
const int vocab_size = phrase_config_->vocab_array()->size();
177190
for (int id : input) {
191+
if (ABSL_PREDICT_FALSE(id < 0 || id >= vocab_size)) {
192+
return absl::OutOfRangeError(
193+
absl::StrCat("Token ID out of bounds: ", id));
194+
}
178195
auto vocab = phrase_config_->vocab_array()->Get(id);
196+
if (ABSL_PREDICT_FALSE(vocab == nullptr)) {
197+
return absl::InternalError("Null vocab string in vocab_array.");
198+
}
179199
output_tokens.emplace_back(vocab->string_view());
180200
}
181201
return output_tokens;
Binary file not shown.

0 commit comments

Comments
 (0)