Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions be/src/exprs/function_call_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,25 @@ bool VectorizedFunctionCallExpr::ngram_bloom_filter(ExprContext* context, const
const auto& needle_column = fn_ctx->get_constant_column(1);
std::string needle = ColumnHelper::get_const_value<TYPE_VARCHAR>(needle_column).to_string();

if (!simdjson::validate_utf8(needle.data(), needle.size())) {
index_useful = false;
ngram_state->initialized = true;
ngram_state->index_useful = index_useful;
return true;
}

// for case_insensitive, we need to convert needle to lower case
if (!reader_options.index_case_sensitive) {
std::transform(needle.begin(), needle.end(), needle.begin(),
[](unsigned char c) { return std::tolower(c); });
std::string lower_needle;
if (validate_ascii_fast(needle.data(), needle.size())) {
Slice(needle).tolower(lower_needle);
} else {
utf8_tolower(needle, lower_needle);
}
needle = std::move(lower_needle);
}

if (!simdjson::validate_utf8(needle.data(), needle.size())) {
index_useful = false;
} else if (_fn_desc->name == "LIKE") {
if (_fn_desc->name == "LIKE") {
index_useful = split_like_string_to_ngram(needle, reader_options, ngram_set);
} else {
index_useful = split_normal_string_to_ngram(needle, fn_ctx, reader_options, ngram_set, _fn_desc->name);
Expand Down
216 changes: 166 additions & 50 deletions be/src/exprs/ngram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include "exprs/function_context.h"
#include "exprs/string_functions.h"
#include "gutil/strings/fastmem.h"
#include "runtime/runtime_state.h"
#include "util/utf8.h"

namespace starrocks {
static constexpr size_t MAX_STRING_SIZE = 1 << 15;
// uint16[2^16] can almost fit into L2
Expand All @@ -41,6 +44,9 @@ struct Ngramstate {

float result = -1;

// Flag to indicate whether UTF-8 mode is enabled (set in prepare from template parameter)
bool use_utf8 = false;

std::vector<NgramHash>* get_or_create_driver_hashmap() {
std::thread::id current_thread_id = std::this_thread::get_id();

Expand Down Expand Up @@ -85,12 +91,21 @@ class NgramFunctionImpl {
return Status::NotSupported("ngram search's third parameter must be a positive number");
}

auto state = reinterpret_cast<Ngramstate*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));

// For UTF-8 mode, check character count instead of byte count
size_t needle_char_count;
if constexpr (use_utf_8) {
needle_char_count = utf8_len(needle.get_data(), needle.get_data() + needle.get_size());
} else {
needle_char_count = needle.get_size();
}

// needle is too small so we can not get even single Ngram, so they are not similar at all
if (needle.get_size() < gram_num) {
if (needle_char_count < gram_num) {
return ColumnHelper::create_const_column<TYPE_DOUBLE>(0, haystack_column->size());
}

auto state = reinterpret_cast<Ngramstate*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
std::vector<NgramHash>* map = state->get_or_create_driver_hashmap();
if (haystack_column->is_constant()) {
if (context->is_constant_column(0)) {
Expand Down Expand Up @@ -119,6 +134,7 @@ class NgramFunctionImpl {
}

auto* state = new Ngramstate(MAP_SIZE);
state->use_utf8 = use_utf_8;

context->set_function_state(scope, state);

Expand Down Expand Up @@ -148,26 +164,71 @@ class NgramFunctionImpl {
}

private:
// Get UTF-8 character positions for a string
static void get_utf8_positions(const char* data, size_t len, std::vector<size_t>& positions) {
positions.clear();
for (size_t i = 0; i < len;) {
positions.push_back(i);
i += UTF8_BYTE_LENGTH_TABLE[static_cast<uint8_t>(data[i])];
}
}

// UTF-8 aware tolower - uses shared implementation from util/utf8.h
static void tolower_utf8(const Slice& str, std::string& buf) {
if (validate_ascii_fast(str.get_data(), str.get_size())) {
Slice(str.get_data(), str.get_size()).tolower(buf);
} else {
utf8_tolower(str.get_data(), str.get_size(), buf);
}
}

// for every gram of needle, we calculate its' hash value and store its' frequency in map, and return the number of gram in needle
size_t static calculateMapWithNeedle(std::vector<NgramHash>& map, const Slice& needle, size_t gram_num) {
size_t needle_length = needle.get_size();
NgramHash cur_hash;
size_t i;
Slice cur_needle(needle.get_data(), needle_length);
const char* cur_char_ptr;
Slice cur_needle(needle.get_data(), needle.get_size());
std::string buf;
if constexpr (case_insensitive) {
tolower(needle, buf);
if constexpr (use_utf_8) {
tolower_utf8(needle, buf);
} else {
buf.assign(needle.get_data(), needle.get_size());
std::transform(buf.begin(), buf.end(), buf.begin(), [](unsigned char c) { return std::tolower(c); });
}
cur_needle = Slice(buf.c_str(), buf.size());
}
cur_char_ptr = cur_needle.get_data();

for (i = 0; i + gram_num <= needle_length; i++) {
cur_hash = getAsciiHash(cur_char_ptr + i, gram_num);
map[cur_hash]++;
}
const char* data = cur_needle.get_data();
size_t len = cur_needle.get_size();

if constexpr (use_utf_8) {
// UTF-8 mode: iterate by characters
std::vector<size_t> positions;
get_utf8_positions(data, len, positions);

return i;
size_t num_chars = positions.size();
if (num_chars < gram_num) {
return 0;
}

size_t gram_count = 0;
for (size_t i = 0; i + gram_num <= num_chars; i++) {
size_t start = positions[i];
size_t end = (i + gram_num < num_chars) ? positions[i + gram_num] : len;
size_t ngram_bytes = end - start;

NgramHash cur_hash = crc_hash_32(data + start, ngram_bytes, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
map[cur_hash]++;
gram_count++;
}
return gram_count;
} else {
// ASCII mode: iterate by bytes (original behavior)
size_t i;
for (i = 0; i + gram_num <= len; i++) {
NgramHash cur_hash = crc_hash_32(data + i, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
map[cur_hash]++;
}
return i;
}
}

ColumnPtr static haystack_vector_and_needle_const(const ColumnPtr& haystack_column, std::vector<NgramHash>& map,
Expand All @@ -176,17 +237,17 @@ class NgramFunctionImpl {

NullColumnPtr res_null = nullptr;
ColumnPtr haystackPtr = nullptr;
// used in case_insensitive
StatusOr<ColumnPtr> lower;
if (haystack_column->is_nullable()) {
auto haystack_nullable = ColumnHelper::as_column<NullableColumn>(haystack_column);
res_null = haystack_nullable->null_column();
haystackPtr = haystack_nullable->data_column();
} else {
haystackPtr = haystack_column;
}
if constexpr (case_insensitive) {
// @TODO if ngram supports utf8 in the future, we should use antoher implementation.

// For case-insensitive ASCII mode, use the fast StringCaseToggleFunction
// For UTF-8 mode, we handle case conversion per-string in calculateDistanceWithHaystack
if constexpr (case_insensitive && !use_utf_8) {
haystackPtr = StringCaseToggleFunction<false>::evaluate<TYPE_VARCHAR, TYPE_VARCHAR>(haystackPtr);
}

Expand Down Expand Up @@ -235,7 +296,12 @@ class NgramFunctionImpl {

std::string buf;
if constexpr (case_insensitive) {
tolower(haystack, buf);
if constexpr (use_utf_8) {
tolower_utf8(haystack, buf);
} else {
buf.assign(haystack.get_data(), haystack.get_size());
std::transform(buf.begin(), buf.end(), buf.begin(), [](unsigned char c) { return std::tolower(c); });
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant UTF-8 lowercase conversion in const+const code path

In UTF-8 case-insensitive mode, when both haystack and needle are constant, haystack_const_and_needle_const calls tolower_utf8 on the haystack, then passes the result to calculateDistanceWithHaystack, which calls tolower_utf8 again on the already-lowercased input. This double conversion is wasteful since ICU calls have overhead. The haystack_vector_and_needle_const function correctly avoids pre-conversion for UTF-8 mode (relying on calculateDistanceWithHaystack to handle it), but haystack_const_and_needle_const doesn't follow the same pattern.

Additional Locations (1)

Fix in Cursor Fix in Web

cur_haystack = Slice(buf.c_str(), buf.size());
}

Expand All @@ -250,67 +316,117 @@ class NgramFunctionImpl {
return result;
}

// traverse haystacks every gram, find whether this gram is in needle or not using gram's hash
// traverse haystack's every gram, find whether this gram is in needle or not using gram's hash
// 16bit hash value may cause hash collision, but because we just calculate the similarity of two string
// so don't need to be so accurate.
template <bool need_recovery_map>
size_t static calculateDistanceWithHaystack(std::vector<NgramHash>& map, const Slice& haystack,
[[maybe_unused]] std::vector<NgramHash>& map_restore_helper,
size_t needle_gram_count, size_t gram_num) {
size_t haystack_length = haystack.get_size();
NgramHash cur_hash;
size_t i;
const char* ptr = haystack.get_data();

for (i = 0; i + gram_num <= haystack_length; i++) {
cur_hash = getAsciiHash(ptr + i, gram_num);
// if this gram is in needle
if (map[cur_hash] > 0) {
needle_gram_count--;
map[cur_hash]--;
if constexpr (need_recovery_map) {
map_restore_helper[i] = cur_hash;
// For UTF-8 case-insensitive mode in vector processing, we need to convert here
std::string lower_buf;
Slice cur_haystack = haystack;
if constexpr (case_insensitive && use_utf_8) {
tolower_utf8(haystack, lower_buf);
cur_haystack = Slice(lower_buf.c_str(), lower_buf.size());
}

const char* data = cur_haystack.get_data();
size_t len = cur_haystack.get_size();

if constexpr (use_utf_8) {
// UTF-8 mode: iterate by characters
std::vector<size_t> positions;
get_utf8_positions(data, len, positions);

size_t num_chars = positions.size();
if (num_chars < gram_num) {
return needle_gram_count;
}

// For UTF-8 mode, we use positions as indices in map_restore_helper
size_t gram_idx = 0;
for (size_t i = 0; i + gram_num <= num_chars; i++, gram_idx++) {
size_t start = positions[i];
size_t end = (i + gram_num < num_chars) ? positions[i + gram_num] : len;
size_t ngram_bytes = end - start;

NgramHash cur_hash = crc_hash_32(data + start, ngram_bytes, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);

if (map[cur_hash] > 0) {
needle_gram_count--;
map[cur_hash]--;
if constexpr (need_recovery_map) {
map_restore_helper[gram_idx] = cur_hash;
}
}
}
}

if constexpr (need_recovery_map) {
for (int j = 0; j < i; j++) {
if (map_restore_helper[j]) {
map[map_restore_helper[j]]++;
// reset map_restore_helper
map_restore_helper[j] = 0;
if constexpr (need_recovery_map) {
for (size_t j = 0; j < gram_idx; j++) {
if (map_restore_helper[j]) {
map[map_restore_helper[j]]++;
map_restore_helper[j] = 0;
}
}
}
} else {
// ASCII mode: iterate by bytes (original behavior)
size_t i;
for (i = 0; i + gram_num <= len; i++) {
NgramHash cur_hash = crc_hash_32(data + i, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
if (map[cur_hash] > 0) {
needle_gram_count--;
map[cur_hash]--;
if constexpr (need_recovery_map) {
map_restore_helper[i] = cur_hash;
}
}
}

if constexpr (need_recovery_map) {
for (size_t j = 0; j < i; j++) {
if (map_restore_helper[j]) {
map[map_restore_helper[j]]++;
map_restore_helper[j] = 0;
}
}
}
}

return needle_gram_count;
}

void inline static tolower(const Slice& str, std::string& buf) {
buf.assign(str.get_data(), str.get_size());
std::transform(buf.begin(), buf.end(), buf.begin(), [](unsigned char c) { return std::tolower(c); });
}

static NgramHash getAsciiHash(const Gram* ch, size_t gram_num) {
return crc_hash_32(ch, gram_num, CRC_HASH_SEEDS::CRC_HASH_SEED1) & (0xffffu);
}
};

// Wrapper functions that check the UTF-8 flag at runtime and dispatch to the correct implementation
StatusOr<ColumnPtr> StringFunctions::ngram_search(FunctionContext* context, const Columns& columns) {
auto state = reinterpret_cast<Ngramstate*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
if (state && state->use_utf8) {
return NgramFunctionImpl<false, true, char>::ngram_search_impl(context, columns);
}
return NgramFunctionImpl<false, false, char>::ngram_search_impl(context, columns);
}

StatusOr<ColumnPtr> StringFunctions::ngram_search_case_insensitive(FunctionContext* context, const Columns& columns) {
auto state = reinterpret_cast<Ngramstate*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
if (state && state->use_utf8) {
return NgramFunctionImpl<true, true, char>::ngram_search_impl(context, columns);
}
return NgramFunctionImpl<true, false, char>::ngram_search_impl(context, columns);
}

Status StringFunctions::ngram_search_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
if (context->state() && context->state()->ngram_search_support_utf8()) {
return NgramFunctionImpl<false, true, char>::ngram_search_prepare_impl(context, scope);
}
return NgramFunctionImpl<false, false, char>::ngram_search_prepare_impl(context, scope);
}

Status StringFunctions::ngram_search_case_insensitive_prepare(FunctionContext* context,
FunctionContext::FunctionStateScope scope) {
if (context->state() && context->state()->ngram_search_support_utf8()) {
return NgramFunctionImpl<true, true, char>::ngram_search_prepare_impl(context, scope);
}
return NgramFunctionImpl<true, false, char>::ngram_search_prepare_impl(context, scope);
}

Expand All @@ -325,4 +441,4 @@ Status StringFunctions::ngram_search_close(FunctionContext* context, FunctionCon
return Status::OK();
}

} // namespace starrocks
} // namespace starrocks
4 changes: 4 additions & 0 deletions be/src/runtime/runtime_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,10 @@ class RuntimeState {
return _query_options.__isset.lower_upper_support_utf8 && _query_options.lower_upper_support_utf8;
}

bool ngram_search_support_utf8() const {
return _query_options.__isset.ngram_search_support_utf8 && _query_options.ngram_search_support_utf8;
}

bool enable_global_late_materialization() const {
return _query_options.__isset.enable_global_late_materialization &&
_query_options.enable_global_late_materialization;
Expand Down
12 changes: 10 additions & 2 deletions be/src/storage/rowset/bloom_filter_index_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ class NgramBloomFilterIndexWriterImpl<field_type, std::enable_if_t<is_slice_type
for (int i = 0; i < count; ++i) {
std::vector<size_t> index;
size_t slice_gram_num = get_utf8_index(*cur_slice, &index);
// slice_gram_num can be used to judge if the cur_slice is an ASCII string or not.
// If slice_gram_num == cur_slice->get_size(), then it is.
bool is_ascii = (slice_gram_num == cur_slice->get_size());

size_t j;
for (j = 0; j + gram_num <= slice_gram_num; j++) {
Expand All @@ -226,9 +229,14 @@ class NgramBloomFilterIndexWriterImpl<field_type, std::enable_if_t<is_slice_type
if (this->_bf_options.case_sensitive) {
_values.insert(get_value<field_type>(&cur_ngram, this->_typeinfo, &this->_pool));
} else {
// todo::exist two copy of ngram, need to optimize
std::string lower_ngram;
Slice lower_ngram_slice = cur_ngram.tolower(lower_ngram);
Slice lower_ngram_slice;
if (is_ascii) {
lower_ngram_slice = cur_ngram.tolower(lower_ngram);
} else {
utf8_tolower(cur_ngram.get_data(), cur_ngram.get_size(), lower_ngram);
lower_ngram_slice = Slice(lower_ngram.data(), lower_ngram.size());
}
_values.insert(get_value<field_type>(&lower_ngram_slice, this->_typeinfo, &this->_pool));
}
}
Expand Down
Loading
Loading