Skip to content

Commit 026c4d1

Browse files
committed
[Feature] Add UTF-8 support for ngram_search functions
This patch adds proper UTF-8 character-based n-gram computation for ngram_search and ngram_search_case_insensitive functions. Previously, n-grams were computed byte-by-byte, which produced incorrect results for non-ASCII text (Cyrillic, Chinese, etc.). Now n-grams are computed based on UTF-8 characters using the existing UTF8_BYTE_LENGTH_TABLE. Changes: - Add session variable `ngram_search_support_utf8` (default: false) - Add utf8_tolower() utility function using ICU for proper Unicode case folding - Fix ngram_search to iterate by UTF-8 characters when enabled - Fix bloom filter index writer to use UTF-8 case folding - Fix bloom filter query to use UTF-8 case folding Note: Bloom filter index already used UTF-8 for n-gram extraction, but case-insensitive mode used ASCII tolower. This is now fixed.
1 parent 9a6b18c commit 026c4d1

File tree

10 files changed

+299
-54
lines changed

10 files changed

+299
-54
lines changed

be/src/exprs/function_call_expr.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,9 @@ bool VectorizedFunctionCallExpr::ngram_bloom_filter(ExprContext* context, const
246246

247247
// for case_insensitive, we need to convert needle to lower case
248248
if (!reader_options.index_case_sensitive) {
249-
std::transform(needle.begin(), needle.end(), needle.begin(),
250-
[](unsigned char c) { return std::tolower(c); });
249+
std::string lower_needle;
250+
utf8_tolower(needle, lower_needle);
251+
needle = std::move(lower_needle);
251252
}
252253

253254
if (!simdjson::validate_utf8(needle.data(), needle.size())) {

be/src/exprs/ngram.cpp

Lines changed: 162 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
#include "exprs/function_context.h"
1919
#include "exprs/string_functions.h"
2020
#include "gutil/strings/fastmem.h"
21+
#include "runtime/runtime_state.h"
22+
#include "util/utf8.h"
23+
2124
namespace starrocks {
2225
static constexpr size_t MAX_STRING_SIZE = 1 << 15;
2326
// uint16[2^16] can almost fit into L2
@@ -41,6 +44,9 @@ struct Ngramstate {
4144

4245
float result = -1;
4346

47+
// Flag to indicate whether UTF-8 mode is enabled (set in prepare from template parameter)
48+
bool use_utf8 = false;
49+
4450
std::vector<NgramHash>* get_or_create_driver_hashmap() {
4551
std::thread::id current_thread_id = std::this_thread::get_id();
4652

@@ -85,12 +91,21 @@ class NgramFunctionImpl {
8591
return Status::NotSupported("ngram search's third parameter must be a positive number");
8692
}
8793

94+
auto state = reinterpret_cast<Ngramstate*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
95+
96+
// For UTF-8 mode, check character count instead of byte count
97+
size_t needle_char_count;
98+
if constexpr (use_utf_8) {
99+
needle_char_count = utf8_len(needle.get_data(), needle.get_data() + needle.get_size());
100+
} else {
101+
needle_char_count = needle.get_size();
102+
}
103+
88104
// needle is too small so we can not get even single Ngram, so they are not similar at all
89-
if (needle.get_size() < gram_num) {
105+
if (needle_char_count < gram_num) {
90106
return ColumnHelper::create_const_column<TYPE_DOUBLE>(0, haystack_column->size());
91107
}
92108

93-
auto state = reinterpret_cast<Ngramstate*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
94109
std::vector<NgramHash>* map = state->get_or_create_driver_hashmap();
95110
if (haystack_column->is_constant()) {
96111
if (context->is_constant_column(0)) {
@@ -119,6 +134,7 @@ class NgramFunctionImpl {
119134
}
120135

121136
auto* state = new Ngramstate(MAP_SIZE);
137+
state->use_utf8 = use_utf_8;
122138

123139
context->set_function_state(scope, state);
124140

@@ -148,26 +164,67 @@ class NgramFunctionImpl {
148164
}
149165

150166
private:
167+
// Get UTF-8 character positions for a string
168+
static void get_utf8_positions(const char* data, size_t len, std::vector<size_t>& positions) {
169+
positions.clear();
170+
for (size_t i = 0; i < len;) {
171+
positions.push_back(i);
172+
i += UTF8_BYTE_LENGTH_TABLE[static_cast<uint8_t>(data[i])];
173+
}
174+
}
175+
176+
// UTF-8 aware tolower - uses shared implementation from util/utf8.h
177+
static void tolower_utf8(const Slice& str, std::string& buf) {
178+
utf8_tolower(str.get_data(), str.get_size(), buf);
179+
}
180+
151181
// for every gram of needle, we calculate its' hash value and store its' frequency in map, and return the number of gram in needle
152182
size_t static calculateMapWithNeedle(std::vector<NgramHash>& map, const Slice& needle, size_t gram_num) {
153-
size_t needle_length = needle.get_size();
154-
NgramHash cur_hash;
155-
size_t i;
156-
Slice cur_needle(needle.get_data(), needle_length);
157-
const char* cur_char_ptr;
183+
Slice cur_needle(needle.get_data(), needle.get_size());
158184
std::string buf;
159185
if constexpr (case_insensitive) {
160-
tolower(needle, buf);
186+
if constexpr (use_utf_8) {
187+
tolower_utf8(needle, buf);
188+
} else {
189+
buf.assign(needle.get_data(), needle.get_size());
190+
std::transform(buf.begin(), buf.end(), buf.begin(), [](unsigned char c) { return std::tolower(c); });
191+
}
161192
cur_needle = Slice(buf.c_str(), buf.size());
162193
}
163-
cur_char_ptr = cur_needle.get_data();
164194

165-
for (i = 0; i + gram_num <= needle_length; i++) {
166-
cur_hash = getAsciiHash(cur_char_ptr + i, gram_num);
167-
map[cur_hash]++;
168-
}
195+
const char* data = cur_needle.get_data();
196+
size_t len = cur_needle.get_size();
197+
198+
if constexpr (use_utf_8) {
199+
// UTF-8 mode: iterate by characters
200+
std::vector<size_t> positions;
201+
get_utf8_positions(data, len, positions);
202+
203+
size_t num_chars = positions.size();
204+
if (num_chars < gram_num) {
205+
return 0;
206+
}
169207

170-
return i;
208+
size_t gram_count = 0;
209+
for (size_t i = 0; i + gram_num <= num_chars; i++) {
210+
size_t start = positions[i];
211+
size_t end = (i + gram_num < num_chars) ? positions[i + gram_num] : len;
212+
size_t ngram_bytes = end - start;
213+
214+
NgramHash cur_hash = crc_hash_32(data + start, ngram_bytes, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
215+
map[cur_hash]++;
216+
gram_count++;
217+
}
218+
return gram_count;
219+
} else {
220+
// ASCII mode: iterate by bytes (original behavior)
221+
size_t i;
222+
for (i = 0; i + gram_num <= len; i++) {
223+
NgramHash cur_hash = crc_hash_32(data + i, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
224+
map[cur_hash]++;
225+
}
226+
return i;
227+
}
171228
}
172229

173230
ColumnPtr static haystack_vector_and_needle_const(const ColumnPtr& haystack_column, std::vector<NgramHash>& map,
@@ -176,17 +233,17 @@ class NgramFunctionImpl {
176233

177234
NullColumnPtr res_null = nullptr;
178235
ColumnPtr haystackPtr = nullptr;
179-
// used in case_insensitive
180-
StatusOr<ColumnPtr> lower;
181236
if (haystack_column->is_nullable()) {
182237
auto haystack_nullable = ColumnHelper::as_column<NullableColumn>(haystack_column);
183238
res_null = haystack_nullable->null_column();
184239
haystackPtr = haystack_nullable->data_column();
185240
} else {
186241
haystackPtr = haystack_column;
187242
}
188-
if constexpr (case_insensitive) {
189-
// @TODO if ngram supports utf8 in the future, we should use antoher implementation.
243+
244+
// For case-insensitive ASCII mode, use the fast StringCaseToggleFunction
245+
// For UTF-8 mode, we handle case conversion per-string in calculateDistanceWithHaystack
246+
if constexpr (case_insensitive && !use_utf_8) {
190247
haystackPtr = StringCaseToggleFunction<false>::evaluate<TYPE_VARCHAR, TYPE_VARCHAR>(haystackPtr);
191248
}
192249

@@ -235,7 +292,12 @@ class NgramFunctionImpl {
235292

236293
std::string buf;
237294
if constexpr (case_insensitive) {
238-
tolower(haystack, buf);
295+
if constexpr (use_utf_8) {
296+
tolower_utf8(haystack, buf);
297+
} else {
298+
buf.assign(haystack.get_data(), haystack.get_size());
299+
std::transform(buf.begin(), buf.end(), buf.begin(), [](unsigned char c) { return std::tolower(c); });
300+
}
239301
cur_haystack = Slice(buf.c_str(), buf.size());
240302
}
241303

@@ -250,67 +312,117 @@ class NgramFunctionImpl {
250312
return result;
251313
}
252314

253-
// traverse haystacks every gram, find whether this gram is in needle or not using gram's hash
315+
// traverse haystack's every gram, find whether this gram is in needle or not using gram's hash
254316
// 16bit hash value may cause hash collision, but because we just calculate the similarity of two string
255317
// so don't need to be so accurate.
256318
template <bool need_recovery_map>
257319
size_t static calculateDistanceWithHaystack(std::vector<NgramHash>& map, const Slice& haystack,
258320
[[maybe_unused]] std::vector<NgramHash>& map_restore_helper,
259321
size_t needle_gram_count, size_t gram_num) {
260-
size_t haystack_length = haystack.get_size();
261-
NgramHash cur_hash;
262-
size_t i;
263-
const char* ptr = haystack.get_data();
264-
265-
for (i = 0; i + gram_num <= haystack_length; i++) {
266-
cur_hash = getAsciiHash(ptr + i, gram_num);
267-
// if this gram is in needle
268-
if (map[cur_hash] > 0) {
269-
needle_gram_count--;
270-
map[cur_hash]--;
271-
if constexpr (need_recovery_map) {
272-
map_restore_helper[i] = cur_hash;
322+
// For UTF-8 case-insensitive mode in vector processing, we need to convert here
323+
std::string lower_buf;
324+
Slice cur_haystack = haystack;
325+
if constexpr (case_insensitive && use_utf_8) {
326+
tolower_utf8(haystack, lower_buf);
327+
cur_haystack = Slice(lower_buf.c_str(), lower_buf.size());
328+
}
329+
330+
const char* data = cur_haystack.get_data();
331+
size_t len = cur_haystack.get_size();
332+
333+
if constexpr (use_utf_8) {
334+
// UTF-8 mode: iterate by characters
335+
std::vector<size_t> positions;
336+
get_utf8_positions(data, len, positions);
337+
338+
size_t num_chars = positions.size();
339+
if (num_chars < gram_num) {
340+
return needle_gram_count;
341+
}
342+
343+
// For UTF-8 mode, we use positions as indices in map_restore_helper
344+
size_t gram_idx = 0;
345+
for (size_t i = 0; i + gram_num <= num_chars; i++, gram_idx++) {
346+
size_t start = positions[i];
347+
size_t end = (i + gram_num < num_chars) ? positions[i + gram_num] : len;
348+
size_t ngram_bytes = end - start;
349+
350+
NgramHash cur_hash = crc_hash_32(data + start, ngram_bytes, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
351+
352+
if (map[cur_hash] > 0) {
353+
needle_gram_count--;
354+
map[cur_hash]--;
355+
if constexpr (need_recovery_map) {
356+
map_restore_helper[gram_idx] = cur_hash;
357+
}
358+
}
359+
}
360+
361+
if constexpr (need_recovery_map) {
362+
for (size_t j = 0; j < gram_idx; j++) {
363+
if (map_restore_helper[j]) {
364+
map[map_restore_helper[j]]++;
365+
map_restore_helper[j] = 0;
366+
}
367+
}
368+
}
369+
} else {
370+
// ASCII mode: iterate by bytes (original behavior)
371+
size_t i;
372+
for (i = 0; i + gram_num <= len; i++) {
373+
NgramHash cur_hash = crc_hash_32(data + i, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
374+
if (map[cur_hash] > 0) {
375+
needle_gram_count--;
376+
map[cur_hash]--;
377+
if constexpr (need_recovery_map) {
378+
map_restore_helper[i] = cur_hash;
379+
}
273380
}
274381
}
275-
}
276382

277-
if constexpr (need_recovery_map) {
278-
for (int j = 0; j < i; j++) {
279-
if (map_restore_helper[j]) {
280-
map[map_restore_helper[j]]++;
281-
// reset map_restore_helper
282-
map_restore_helper[j] = 0;
383+
if constexpr (need_recovery_map) {
384+
for (size_t j = 0; j < i; j++) {
385+
if (map_restore_helper[j]) {
386+
map[map_restore_helper[j]]++;
387+
map_restore_helper[j] = 0;
388+
}
283389
}
284390
}
285391
}
286392

287393
return needle_gram_count;
288394
}
289-
290-
void inline static tolower(const Slice& str, std::string& buf) {
291-
buf.assign(str.get_data(), str.get_size());
292-
std::transform(buf.begin(), buf.end(), buf.begin(), [](unsigned char c) { return std::tolower(c); });
293-
}
294-
295-
static NgramHash getAsciiHash(const Gram* ch, size_t gram_num) {
296-
return crc_hash_32(ch, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
297-
}
298395
};
299396

397+
// Wrapper functions that check the UTF-8 flag at runtime and dispatch to the correct implementation
300398
StatusOr<ColumnPtr> StringFunctions::ngram_search(FunctionContext* context, const Columns& columns) {
399+
auto state = reinterpret_cast<Ngramstate*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
400+
if (state && state->use_utf8) {
401+
return NgramFunctionImpl<false, true, char>::ngram_search_impl(context, columns);
402+
}
301403
return NgramFunctionImpl<false, false, char>::ngram_search_impl(context, columns);
302404
}
303405

304406
StatusOr<ColumnPtr> StringFunctions::ngram_search_case_insensitive(FunctionContext* context, const Columns& columns) {
407+
auto state = reinterpret_cast<Ngramstate*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
408+
if (state && state->use_utf8) {
409+
return NgramFunctionImpl<true, true, char>::ngram_search_impl(context, columns);
410+
}
305411
return NgramFunctionImpl<true, false, char>::ngram_search_impl(context, columns);
306412
}
307413

308414
Status StringFunctions::ngram_search_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
415+
if (context->state() && context->state()->ngram_search_support_utf8()) {
416+
return NgramFunctionImpl<false, true, char>::ngram_search_prepare_impl(context, scope);
417+
}
309418
return NgramFunctionImpl<false, false, char>::ngram_search_prepare_impl(context, scope);
310419
}
311420

312421
Status StringFunctions::ngram_search_case_insensitive_prepare(FunctionContext* context,
313422
FunctionContext::FunctionStateScope scope) {
423+
if (context->state() && context->state()->ngram_search_support_utf8()) {
424+
return NgramFunctionImpl<true, true, char>::ngram_search_prepare_impl(context, scope);
425+
}
314426
return NgramFunctionImpl<true, false, char>::ngram_search_prepare_impl(context, scope);
315427
}
316428

@@ -325,4 +437,4 @@ Status StringFunctions::ngram_search_close(FunctionContext* context, FunctionCon
325437
return Status::OK();
326438
}
327439

328-
} // namespace starrocks
440+
} // namespace starrocks

be/src/runtime/runtime_state.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,10 @@ class RuntimeState {
579579
return _query_options.__isset.lower_upper_support_utf8 && _query_options.lower_upper_support_utf8;
580580
}
581581

582+
bool ngram_search_support_utf8() const {
583+
return _query_options.__isset.ngram_search_support_utf8 && _query_options.ngram_search_support_utf8;
584+
}
585+
582586
bool enable_global_late_materialization() const {
583587
return _query_options.__isset.enable_global_late_materialization &&
584588
_query_options.enable_global_late_materialization;

be/src/storage/rowset/bloom_filter_index_writer.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,11 @@ class NgramBloomFilterIndexWriterImpl<field_type, std::enable_if_t<is_slice_type
226226
if (this->_bf_options.case_sensitive) {
227227
_values.insert(get_value<field_type>(&cur_ngram, this->_typeinfo, &this->_pool));
228228
} else {
229-
// todo::exist two copy of ngram, need to optimize
229+
// TODO: exist two copies of ngram, need to optimize
230+
// Use UTF-8 aware tolower for proper Unicode case folding
230231
std::string lower_ngram;
231-
Slice lower_ngram_slice = cur_ngram.tolower(lower_ngram);
232+
utf8_tolower(cur_ngram.get_data(), cur_ngram.get_size(), lower_ngram);
233+
Slice lower_ngram_slice(lower_ngram.data(), lower_ngram.size());
232234
_values.insert(get_value<field_type>(&lower_ngram_slice, this->_typeinfo, &this->_pool));
233235
}
234236
}

be/src/util/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ set(UTIL_FILES
8787
sm3.cpp
8888
frame_of_reference_coding.cpp
8989
utf8_check.cpp
90+
utf8.cpp
9091
path_util.cpp
9192
monotime.cpp
9293
thread.cpp

0 commit comments

Comments
 (0)