1818#include " exprs/function_context.h"
1919#include " exprs/string_functions.h"
2020#include " gutil/strings/fastmem.h"
21+ #include " runtime/runtime_state.h"
22+ #include " util/utf8.h"
23+
2124namespace starrocks {
2225static constexpr size_t MAX_STRING_SIZE = 1 << 15 ;
2326// uint16[2^16] can almost fit into L2
@@ -41,6 +44,9 @@ struct Ngramstate {
4144
4245 float result = -1 ;
4346
47+ // Flag to indicate whether UTF-8 mode is enabled (set in prepare from template parameter)
48+ bool use_utf8 = false ;
49+
4450 std::vector<NgramHash>* get_or_create_driver_hashmap () {
4551 std::thread::id current_thread_id = std::this_thread::get_id ();
4652
@@ -85,12 +91,21 @@ class NgramFunctionImpl {
8591 return Status::NotSupported (" ngram search's third parameter must be a positive number" );
8692 }
8793
94+ auto state = reinterpret_cast <Ngramstate*>(context->get_function_state (FunctionContext::FRAGMENT_LOCAL));
95+
96+ // For UTF-8 mode, check character count instead of byte count
97+ size_t needle_char_count;
98+ if constexpr (use_utf_8) {
99+ needle_char_count = utf8_len (needle.get_data (), needle.get_data () + needle.get_size ());
100+ } else {
101+ needle_char_count = needle.get_size ();
102+ }
103+
88104 // needle is too small so we can not get even single Ngram, so they are not similar at all
89- if (needle. get_size () < gram_num) {
105+ if (needle_char_count < gram_num) {
90106 return ColumnHelper::create_const_column<TYPE_DOUBLE>(0 , haystack_column->size ());
91107 }
92108
93- auto state = reinterpret_cast <Ngramstate*>(context->get_function_state (FunctionContext::FRAGMENT_LOCAL));
94109 std::vector<NgramHash>* map = state->get_or_create_driver_hashmap ();
95110 if (haystack_column->is_constant ()) {
96111 if (context->is_constant_column (0 )) {
@@ -119,6 +134,7 @@ class NgramFunctionImpl {
119134 }
120135
121136 auto * state = new Ngramstate (MAP_SIZE);
137+ state->use_utf8 = use_utf_8;
122138
123139 context->set_function_state (scope, state);
124140
@@ -148,26 +164,67 @@ class NgramFunctionImpl {
148164 }
149165
150166private:
167+ // Get UTF-8 character positions for a string
168+ static void get_utf8_positions (const char * data, size_t len, std::vector<size_t >& positions) {
169+ positions.clear ();
170+ for (size_t i = 0 ; i < len;) {
171+ positions.push_back (i);
172+ i += UTF8_BYTE_LENGTH_TABLE[static_cast <uint8_t >(data[i])];
173+ }
174+ }
175+
176+ // UTF-8 aware tolower - uses shared implementation from util/utf8.h
177+ static void tolower_utf8 (const Slice& str, std::string& buf) {
178+ utf8_tolower (str.get_data (), str.get_size (), buf);
179+ }
180+
151181 // for every gram of needle, we calculate its' hash value and store its' frequency in map, and return the number of gram in needle
152182 size_t static calculateMapWithNeedle (std::vector<NgramHash>& map, const Slice& needle, size_t gram_num) {
153- size_t needle_length = needle.get_size ();
154- NgramHash cur_hash;
155- size_t i;
156- Slice cur_needle (needle.get_data (), needle_length);
157- const char * cur_char_ptr;
183+ Slice cur_needle (needle.get_data (), needle.get_size ());
158184 std::string buf;
159185 if constexpr (case_insensitive) {
160- tolower (needle, buf);
186+ if constexpr (use_utf_8) {
187+ tolower_utf8 (needle, buf);
188+ } else {
189+ buf.assign (needle.get_data (), needle.get_size ());
190+ std::transform (buf.begin (), buf.end (), buf.begin (), [](unsigned char c) { return std::tolower (c); });
191+ }
161192 cur_needle = Slice (buf.c_str (), buf.size ());
162193 }
163- cur_char_ptr = cur_needle.get_data ();
164194
165- for (i = 0 ; i + gram_num <= needle_length; i++) {
166- cur_hash = getAsciiHash (cur_char_ptr + i, gram_num);
167- map[cur_hash]++;
168- }
195+ const char * data = cur_needle.get_data ();
196+ size_t len = cur_needle.get_size ();
197+
198+ if constexpr (use_utf_8) {
199+ // UTF-8 mode: iterate by characters
200+ std::vector<size_t > positions;
201+ get_utf8_positions (data, len, positions);
202+
203+ size_t num_chars = positions.size ();
204+ if (num_chars < gram_num) {
205+ return 0 ;
206+ }
169207
170- return i;
208+ size_t gram_count = 0 ;
209+ for (size_t i = 0 ; i + gram_num <= num_chars; i++) {
210+ size_t start = positions[i];
211+ size_t end = (i + gram_num < num_chars) ? positions[i + gram_num] : len;
212+ size_t ngram_bytes = end - start;
213+
214+ NgramHash cur_hash = crc_hash_32 (data + start, ngram_bytes, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu );
215+ map[cur_hash]++;
216+ gram_count++;
217+ }
218+ return gram_count;
219+ } else {
220+ // ASCII mode: iterate by bytes (original behavior)
221+ size_t i;
222+ for (i = 0 ; i + gram_num <= len; i++) {
223+ NgramHash cur_hash = crc_hash_32 (data + i, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu );
224+ map[cur_hash]++;
225+ }
226+ return i;
227+ }
171228 }
172229
173230 ColumnPtr static haystack_vector_and_needle_const (const ColumnPtr& haystack_column, std::vector<NgramHash>& map,
@@ -176,17 +233,17 @@ class NgramFunctionImpl {
176233
177234 NullColumnPtr res_null = nullptr ;
178235 ColumnPtr haystackPtr = nullptr ;
179- // used in case_insensitive
180- StatusOr<ColumnPtr> lower;
181236 if (haystack_column->is_nullable ()) {
182237 auto haystack_nullable = ColumnHelper::as_column<NullableColumn>(haystack_column);
183238 res_null = haystack_nullable->null_column ();
184239 haystackPtr = haystack_nullable->data_column ();
185240 } else {
186241 haystackPtr = haystack_column;
187242 }
188- if constexpr (case_insensitive) {
189- // @TODO if ngram supports utf8 in the future, we should use antoher implementation.
243+
244+ // For case-insensitive ASCII mode, use the fast StringCaseToggleFunction
245+ // For UTF-8 mode, we handle case conversion per-string in calculateDistanceWithHaystack
246+ if constexpr (case_insensitive && !use_utf_8) {
190247 haystackPtr = StringCaseToggleFunction<false >::evaluate<TYPE_VARCHAR, TYPE_VARCHAR>(haystackPtr);
191248 }
192249
@@ -235,7 +292,12 @@ class NgramFunctionImpl {
235292
236293 std::string buf;
237294 if constexpr (case_insensitive) {
238- tolower (haystack, buf);
295+ if constexpr (use_utf_8) {
296+ tolower_utf8 (haystack, buf);
297+ } else {
298+ buf.assign (haystack.get_data (), haystack.get_size ());
299+ std::transform (buf.begin (), buf.end (), buf.begin (), [](unsigned char c) { return std::tolower (c); });
300+ }
239301 cur_haystack = Slice (buf.c_str (), buf.size ());
240302 }
241303
@@ -250,67 +312,117 @@ class NgramFunctionImpl {
250312 return result;
251313 }
252314
253- // traverse haystack‘ s every gram, find whether this gram is in needle or not using gram's hash
315+ // traverse haystack' s every gram, find whether this gram is in needle or not using gram's hash
254316 // 16bit hash value may cause hash collision, but because we just calculate the similarity of two string
255317 // so don't need to be so accurate.
256318 template <bool need_recovery_map>
257319 size_t static calculateDistanceWithHaystack (std::vector<NgramHash>& map, const Slice& haystack,
258320 [[maybe_unused]] std::vector<NgramHash>& map_restore_helper,
259321 size_t needle_gram_count, size_t gram_num) {
260- size_t haystack_length = haystack.get_size ();
261- NgramHash cur_hash;
262- size_t i;
263- const char * ptr = haystack.get_data ();
264-
265- for (i = 0 ; i + gram_num <= haystack_length; i++) {
266- cur_hash = getAsciiHash (ptr + i, gram_num);
267- // if this gram is in needle
268- if (map[cur_hash] > 0 ) {
269- needle_gram_count--;
270- map[cur_hash]--;
271- if constexpr (need_recovery_map) {
272- map_restore_helper[i] = cur_hash;
322+ // For UTF-8 case-insensitive mode in vector processing, we need to convert here
323+ std::string lower_buf;
324+ Slice cur_haystack = haystack;
325+ if constexpr (case_insensitive && use_utf_8) {
326+ tolower_utf8 (haystack, lower_buf);
327+ cur_haystack = Slice (lower_buf.c_str (), lower_buf.size ());
328+ }
329+
330+ const char * data = cur_haystack.get_data ();
331+ size_t len = cur_haystack.get_size ();
332+
333+ if constexpr (use_utf_8) {
334+ // UTF-8 mode: iterate by characters
335+ std::vector<size_t > positions;
336+ get_utf8_positions (data, len, positions);
337+
338+ size_t num_chars = positions.size ();
339+ if (num_chars < gram_num) {
340+ return needle_gram_count;
341+ }
342+
343+ // For UTF-8 mode, we use positions as indices in map_restore_helper
344+ size_t gram_idx = 0 ;
345+ for (size_t i = 0 ; i + gram_num <= num_chars; i++, gram_idx++) {
346+ size_t start = positions[i];
347+ size_t end = (i + gram_num < num_chars) ? positions[i + gram_num] : len;
348+ size_t ngram_bytes = end - start;
349+
350+ NgramHash cur_hash = crc_hash_32 (data + start, ngram_bytes, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu );
351+
352+ if (map[cur_hash] > 0 ) {
353+ needle_gram_count--;
354+ map[cur_hash]--;
355+ if constexpr (need_recovery_map) {
356+ map_restore_helper[gram_idx] = cur_hash;
357+ }
358+ }
359+ }
360+
361+ if constexpr (need_recovery_map) {
362+ for (size_t j = 0 ; j < gram_idx; j++) {
363+ if (map_restore_helper[j]) {
364+ map[map_restore_helper[j]]++;
365+ map_restore_helper[j] = 0 ;
366+ }
367+ }
368+ }
369+ } else {
370+ // ASCII mode: iterate by bytes (original behavior)
371+ size_t i;
372+ for (i = 0 ; i + gram_num <= len; i++) {
373+ NgramHash cur_hash = crc_hash_32 (data + i, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu );
374+ if (map[cur_hash] > 0 ) {
375+ needle_gram_count--;
376+ map[cur_hash]--;
377+ if constexpr (need_recovery_map) {
378+ map_restore_helper[i] = cur_hash;
379+ }
273380 }
274381 }
275- }
276382
277- if constexpr (need_recovery_map) {
278- for (int j = 0 ; j < i; j++) {
279- if (map_restore_helper[j]) {
280- map[map_restore_helper[j]]++;
281- // reset map_restore_helper
282- map_restore_helper[j] = 0 ;
383+ if constexpr (need_recovery_map) {
384+ for (size_t j = 0 ; j < i; j++) {
385+ if (map_restore_helper[j]) {
386+ map[map_restore_helper[j]]++;
387+ map_restore_helper[j] = 0 ;
388+ }
283389 }
284390 }
285391 }
286392
287393 return needle_gram_count;
288394 }
289-
290- void inline static tolower (const Slice& str, std::string& buf) {
291- buf.assign (str.get_data (), str.get_size ());
292- std::transform (buf.begin (), buf.end (), buf.begin (), [](unsigned char c) { return std::tolower (c); });
293- }
294-
295- static NgramHash getAsciiHash (const Gram* ch, size_t gram_num) {
296- return crc_hash_32 (ch, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu );
297- }
298395};
299396
397+ // Wrapper functions that check the UTF-8 flag at runtime and dispatch to the correct implementation
300398StatusOr<ColumnPtr> StringFunctions::ngram_search (FunctionContext* context, const Columns& columns) {
399+ auto state = reinterpret_cast <Ngramstate*>(context->get_function_state (FunctionContext::FRAGMENT_LOCAL));
400+ if (state && state->use_utf8 ) {
401+ return NgramFunctionImpl<false , true , char >::ngram_search_impl (context, columns);
402+ }
301403 return NgramFunctionImpl<false , false , char >::ngram_search_impl (context, columns);
302404}
303405
304406StatusOr<ColumnPtr> StringFunctions::ngram_search_case_insensitive (FunctionContext* context, const Columns& columns) {
407+ auto state = reinterpret_cast <Ngramstate*>(context->get_function_state (FunctionContext::FRAGMENT_LOCAL));
408+ if (state && state->use_utf8 ) {
409+ return NgramFunctionImpl<true , true , char >::ngram_search_impl (context, columns);
410+ }
305411 return NgramFunctionImpl<true , false , char >::ngram_search_impl (context, columns);
306412}
307413
308414Status StringFunctions::ngram_search_prepare (FunctionContext* context, FunctionContext::FunctionStateScope scope) {
415+ if (context->state () && context->state ()->ngram_search_support_utf8 ()) {
416+ return NgramFunctionImpl<false , true , char >::ngram_search_prepare_impl (context, scope);
417+ }
309418 return NgramFunctionImpl<false , false , char >::ngram_search_prepare_impl (context, scope);
310419}
311420
312421Status StringFunctions::ngram_search_case_insensitive_prepare (FunctionContext* context,
313422 FunctionContext::FunctionStateScope scope) {
423+ if (context->state () && context->state ()->ngram_search_support_utf8 ()) {
424+ return NgramFunctionImpl<true , true , char >::ngram_search_prepare_impl (context, scope);
425+ }
314426 return NgramFunctionImpl<true , false , char >::ngram_search_prepare_impl (context, scope);
315427}
316428
@@ -325,4 +437,4 @@ Status StringFunctions::ngram_search_close(FunctionContext* context, FunctionCon
325437 return Status::OK ();
326438}
327439
328- } // namespace starrocks
440+ } // namespace starrocks
0 commit comments