Skip to content

Commit 41599f1

Browse files
Mitigates a critical Out-of-Bounds (OOB) read vulnerability by transitioning internal trie references to absl::Span and validating arra...
PiperOrigin-RevId: 914999991
1 parent eb6b2c0 commit 41599f1

8 files changed

Lines changed: 90 additions & 22 deletions

tensorflow_text/core/kernels/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ tf_cc_library(
215215
"@com_google_absl//absl/memory",
216216
"@com_google_absl//absl/status",
217217
"@com_google_absl//absl/strings",
218+
"@com_google_absl//absl/types:span",
218219
"@icu//:common",
219220
# lite/kernels/shim:status_macros tensorflow dep,
220221
],
@@ -351,7 +352,10 @@ cc_library(
351352
"darts_clone_trie_wrapper.h",
352353
],
353354
deps = [
355+
"@com_google_absl//absl/base:core_headers",
356+
"@com_google_absl//absl/status",
354357
"@com_google_absl//absl/status:statusor",
358+
"@com_google_absl//absl/types:span",
355359
],
356360
)
357361

@@ -363,6 +367,7 @@ cc_test(
363367
":darts_clone_trie_builder",
364368
":darts_clone_trie_wrapper",
365369
"@com_google_absl//absl/status",
370+
"@com_google_absl//absl/types:span",
366371
"@com_google_googletest//:gtest_main",
367372
],
368373
)
@@ -399,6 +404,7 @@ tf_cc_library(
399404
"@com_google_absl//absl/status",
400405
"@com_google_absl//absl/status:statusor",
401406
"@com_google_absl//absl/strings",
407+
"@com_google_absl//absl/types:span",
402408
"@icu//:nfkc",
403409
# lite/kernels/shim:status_macros tensorflow dep,
404410
],

tensorflow_text/core/kernels/darts_clone_trie_test.cc

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <gmock/gmock.h>
1616
#include <gtest/gtest.h>
17+
#include "absl/types/span.h"
1718
#include "tensorflow_text/core/kernels/darts_clone_trie_builder.h"
1819
#include "tensorflow_text/core/kernels/darts_clone_trie_wrapper.h"
1920

@@ -31,7 +32,7 @@ TEST(DartsCloneTrieTest, CreateCursorPointToRootAndTryTraverseOneStep) {
3132
ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> trie_array,
3233
BuildDartsCloneTrie(vocab_tokens));
3334
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie,
34-
DartsCloneTrieWrapper::Create(trie_array.data()));
35+
DartsCloneTrieWrapper::Create(trie_array));
3536

3637
DartsCloneTrieWrapper::TraversalCursor cursor;
3738
int data;
@@ -56,7 +57,7 @@ TEST(DartsCloneTrieTest, CreateCursorAndTryTraverseSeveralSteps) {
5657
ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> trie_array,
5758
BuildDartsCloneTrie(vocab_tokens));
5859
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie,
59-
DartsCloneTrieWrapper::Create(trie_array.data()));
60+
DartsCloneTrieWrapper::Create(trie_array));
6061

6162
DartsCloneTrieWrapper::TraversalCursor cursor;
6263
int data;
@@ -76,7 +77,7 @@ TEST(DartsCloneTrieTest, TraversePathNotExisted) {
7677
ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> trie_array,
7778
BuildDartsCloneTrie(vocab_tokens));
7879
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie,
79-
DartsCloneTrieWrapper::Create(trie_array.data()));
80+
DartsCloneTrieWrapper::Create(trie_array));
8081

8182
DartsCloneTrieWrapper::TraversalCursor cursor;
8283

@@ -94,7 +95,7 @@ TEST(DartsCloneTrieTest, TraverseOnUtf8Path) {
9495
ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> trie_array,
9596
BuildDartsCloneTrie(vocab_tokens));
9697
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie,
97-
DartsCloneTrieWrapper::Create(trie_array.data()));
98+
DartsCloneTrieWrapper::Create(trie_array));
9899

99100
DartsCloneTrieWrapper::TraversalCursor cursor;
100101
int data;
@@ -115,7 +116,7 @@ TEST(DartsCloneTrieTest, TraverseOnPartialUtf8Path) {
115116
ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> trie_array,
116117
BuildDartsCloneTrie(vocab_tokens));
117118
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie,
118-
DartsCloneTrieWrapper::Create(trie_array.data()));
119+
DartsCloneTrieWrapper::Create(trie_array));
119120

120121
DartsCloneTrieWrapper::TraversalCursor cursor;
121122
int data;
@@ -135,7 +136,7 @@ TEST(DartsCloneTrieTest, TraverseOnUtf8PathNotExisted) {
135136
ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> trie_array,
136137
BuildDartsCloneTrie(vocab_tokens));
137138
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie,
138-
DartsCloneTrieWrapper::Create(trie_array.data()));
139+
DartsCloneTrieWrapper::Create(trie_array));
139140

140141
DartsCloneTrieWrapper::TraversalCursor cursor;
141142

@@ -183,6 +184,32 @@ TEST(DartsCloneTrieBuildError, NegativeValues) {
183184
StatusIs(util::error::INVALID_ARGUMENT));
184185
}
185186

187+
TEST(DartsCloneTrieTest, OutOfBoundsAccessIsRejected) {
188+
std::vector<std::string> vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"};
189+
ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> trie_array,
190+
BuildDartsCloneTrie(vocab_tokens));
191+
// Wrap using a constrained span to emulate an out-of-bounds access attempt.
192+
auto span = absl::MakeSpan(trie_array.data(), 1);
193+
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie,
194+
DartsCloneTrieWrapper::Create(span));
195+
196+
DartsCloneTrieWrapper::TraversalCursor cursor =
197+
trie.CreateTraversalCursorPointToRoot();
198+
EXPECT_FALSE(trie.TryTraverseOneStep(cursor, 'd'));
199+
}
200+
201+
TEST(DartsCloneTrieTest, LegacyRawPointerCreateWorks) {
202+
std::vector<std::string> vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"};
203+
ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> trie_array,
204+
BuildDartsCloneTrie(vocab_tokens));
205+
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie,
206+
DartsCloneTrieWrapper::Create(trie_array.data()));
207+
208+
DartsCloneTrieWrapper::TraversalCursor cursor =
209+
trie.CreateTraversalCursorPointToRoot();
210+
EXPECT_TRUE(trie.TryTraverseOneStep(cursor, 'd'));
211+
}
212+
186213
} // namespace trie_utils
187214
} // namespace text
188215
} // namespace tensorflow

tensorflow_text/core/kernels/darts_clone_trie_wrapper.h

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
#include <stdint.h>
3131
#include <string.h>
3232

33+
#include <limits>
34+
35+
#include "absl/status/status.h"
3336
#include "absl/status/statusor.h"
37+
#include "absl/types/span.h"
3438

3539
namespace tensorflow {
3640
namespace text {
@@ -51,16 +55,27 @@ class DartsCloneTrieWrapper {
5155
uint32_t unit = 0;
5256
};
5357

54-
// Constructs an instance by passing in the pointer to the trie array data.
58+
// Constructs an instance by passing in the span of the trie array data.
5559
// The caller needs to make sure that 'trie_array' points to a valid structure
5660
// returned by darts_clone trie builder. The caller also needs to maintain the
5761
// availability of 'trie_array' throughout the lifetime of this instance.
62+
static absl::StatusOr<DartsCloneTrieWrapper> Create(
63+
absl::Span<const uint32_t> trie_array) {
64+
if (trie_array.empty() || trie_array.data() == nullptr) {
65+
return absl::InvalidArgumentError("trie_array is empty or nullptr.");
66+
}
67+
return DartsCloneTrieWrapper(trie_array);
68+
}
69+
70+
// Deprecated: Please use the absl::Span-based Create method instead.
71+
// This legacy factory overload creates a wrapper without bounds verification.
5872
static absl::StatusOr<DartsCloneTrieWrapper> Create(
5973
const uint32_t* trie_array) {
6074
if (trie_array == nullptr) {
6175
return absl::InvalidArgumentError("trie_array is nullptr.");
6276
}
63-
return DartsCloneTrieWrapper(trie_array);
77+
return DartsCloneTrieWrapper(
78+
absl::MakeSpan(trie_array, std::numeric_limits<size_t>::max()));
6479
}
6580

6681
// Creates a cursor pointing to the root.
@@ -70,20 +85,28 @@ class DartsCloneTrieWrapper {
7085

7186
// Creates a cursor pointing to the 'node_id'.
7287
TraversalCursor CreateTraversalCursor(uint32_t node_id) {
88+
if (node_id >= trie_array_.size()) {
89+
return {0, 0};
90+
}
7391
return {node_id, trie_array_[node_id]};
7492
}
7593

7694
// Sets the cursor to point to 'node_id'.
7795
void SetTraversalCursor(TraversalCursor& cursor, uint32_t node_id) {
78-
cursor.node_id = node_id;
79-
cursor.unit = trie_array_[node_id];
96+
if (node_id < trie_array_.size()) {
97+
cursor.node_id = node_id;
98+
cursor.unit = trie_array_[node_id];
99+
}
80100
}
81101

82102
// Traverses one step from 'cursor' following 'ch'. If successful (i.e., there
83103
// exists such an edge), moves 'cursor' to the new node and returns true.
84104
// Otherwise, does nothing (i.e., 'cursor' is not changed) and returns false.
85105
bool TryTraverseOneStep(TraversalCursor& cursor, unsigned char ch) const {
86106
const uint32_t next_node_id = cursor.node_id ^ offset(cursor.unit) ^ ch;
107+
if (next_node_id >= trie_array_.size()) {
108+
return false;
109+
}
87110
const uint32_t next_node_unit = trie_array_[next_node_id];
88111
if (label(next_node_unit) != ch) {
89112
return false;
@@ -108,15 +131,18 @@ class DartsCloneTrieWrapper {
108131
if (!has_leaf(cursor.unit)) {
109132
return false;
110133
}
111-
const uint32_t value_unit =
112-
trie_array_[cursor.node_id ^ offset(cursor.unit)];
134+
const uint32_t value_node_id = cursor.node_id ^ offset(cursor.unit);
135+
if (value_node_id >= trie_array_.size()) {
136+
return false;
137+
}
138+
const uint32_t value_unit = trie_array_[value_node_id];
113139
out_data = value(value_unit);
114140
return true;
115141
}
116142

117143
private:
118144
// Use Create() instead of the constructor.
119-
explicit DartsCloneTrieWrapper(const uint32_t* trie_array)
145+
explicit DartsCloneTrieWrapper(absl::Span<const uint32_t> trie_array)
120146
: trie_array_(trie_array) {}
121147

122148
// The actual implementation of TryTraverseSeveralSteps.
@@ -127,6 +153,9 @@ class DartsCloneTrieWrapper {
127153
for (; size > 0; --size, ++ptr) {
128154
const unsigned char ch = static_cast<const unsigned char>(*ptr);
129155
cur_id ^= offset(cur_unit) ^ ch;
156+
if (cur_id >= trie_array_.size()) {
157+
return false;
158+
}
130159
cur_unit = trie_array_[cur_id];
131160
if (label(cur_unit) != ch) {
132161
return false;
@@ -157,8 +186,8 @@ class DartsCloneTrieWrapper {
157186
return static_cast<int>(unit & 0x7fffffff);
158187
}
159188

160-
// The pointer to the darts trie array.
161-
const uint32_t* trie_array_;
189+
// The darts trie array, held as a span so that accesses can be bounds-checked.
190+
absl::Span<const uint32_t> trie_array_;
162191
};
163192

164193
} // namespace trie_utils

tensorflow_text/core/kernels/fast_bert_normalizer.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "absl/base/optimization.h"
2323
#include "absl/status/status.h"
2424
#include "absl/strings/string_view.h"
25+
#include "absl/types/span.h"
2526
#include "icu4c/source/common/unicode/utf8.h"
2627
#include "tensorflow/lite/kernels/shim/status_macros.h"
2728
#include "tensorflow_text/core/kernels/darts_clone_trie_wrapper.h"
@@ -85,12 +86,13 @@ class FastBertNormalizer {
8586
// which is not owned by this instance and should be kept alive through the
8687
// lifetime of the instance.
8788
static absl::StatusOr<FastBertNormalizer> Create(
88-
const uint32_t* trie_data, int data_for_codepoint_zero,
89+
absl::Span<const uint32_t> trie_data, int data_for_codepoint_zero,
8990
const char* normalized_string_pool,
9091
size_t normalized_string_pool_size = static_cast<size_t>(-1)) {
91-
if (trie_data == nullptr || normalized_string_pool == nullptr) {
92+
if (trie_data.empty() || trie_data.data() == nullptr ||
93+
normalized_string_pool == nullptr) {
9294
return absl::InvalidArgumentError(
93-
"trie_data or normalized_string_pool is null");
95+
"trie_data or normalized_string_pool is null or empty");
9496
}
9597
FastBertNormalizer result;
9698
SH_ASSIGN_OR_RETURN(auto trie,
@@ -123,7 +125,9 @@ class FastBertNormalizer {
123125
"FastBertNormalizerModel or its required fields are null");
124126
}
125127
return Create(
126-
model->trie_array()->data(), model->data_for_codepoint_zero(),
128+
absl::MakeSpan(model->trie_array()->data(),
129+
model->trie_array()->size()),
130+
model->data_for_codepoint_zero(),
127131
reinterpret_cast<const char*>(model->normalized_string_pool()->data()),
128132
model->normalized_string_pool()->size());
129133
}

tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ FastBertNormalizerFactory::FastBertNormalizerFactory(
229229
return;
230230
}
231231
auto char_set_recognizer_mapper = FastBertNormalizer::Create(
232-
trie_data_.data(), data_for_codepoint_zero_, mapped_value_pool_.data());
232+
trie_data_, data_for_codepoint_zero_, mapped_value_pool_.data());
233233
if (!char_set_recognizer_mapper.ok()) {
234234
// Should never happen since the same code must have passed the unit tests.
235235
LOG(ERROR) << "Unexpected error: Failed to initialize "

tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "absl/strings/str_cat.h"
2525
#include "absl/strings/str_join.h"
2626
#include "absl/strings/string_view.h"
27+
#include "absl/types/span.h"
2728
#include "icu4c/source/common/unicode/uchar.h"
2829
#include "icu4c/source/common/unicode/utf8.h"
2930
#include "tensorflow/lite/kernels/shim/status_macros.h"
@@ -56,7 +57,8 @@ FastWordpieceTokenizer::Create(const void* config_flatbuffer) {
5657
"FastWordpieceTokenizerConfig or its trie_array is null.");
5758
}
5859
auto trie_or = trie_utils::DartsCloneTrieWrapper::Create(
59-
tokenizer.config_->trie_array()->data());
60+
absl::MakeSpan(tokenizer.config_->trie_array()->data(),
61+
tokenizer.config_->trie_array()->size()));
6062
if (!trie_or.ok()) {
6163
return absl::InvalidArgumentError(
6264
"Failed to create DartsCloneTrieWrapper from "

tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ absl::Status FastWordpieceBuilder::ConstructTrie(
434434
trie_utils::BuildDartsCloneTrie(keys, values));
435435
SH_ASSIGN_OR_RETURN(
436436
trie_utils::DartsCloneTrieWrapper trie,
437-
trie_utils::DartsCloneTrieWrapper::Create(trie_array_.data()));
437+
trie_utils::DartsCloneTrieWrapper::Create(trie_array_));
438438
trie_.emplace(std::move(trie));
439439

440440
if (trie_array_.size() >
Binary file not shown.

0 commit comments

Comments
 (0)