Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 203 additions & 0 deletions be/src/exprs/string_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "exprs/string_functions.h"

#include "column/bytes.h"
#include "function_context.h"
#include "util/defer_op.h"

#ifdef __x86_64__
Expand Down Expand Up @@ -4446,6 +4447,208 @@ StatusOr<ColumnPtr> StringFunctions::regexp_count(FunctionContext* context, cons
}
}

Status StringFunctions::regexp_position_prepare(FunctionContext* context,
FunctionContext::FunctionStateScope scope) {
if (scope != FunctionContext::FRAGMENT_LOCAL) {
return Status::OK();
}

auto* state = new StringFunctionsState();
context->set_function_state(scope, state);

// check if pattern is constant
if (context->is_constant_column(1)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (context->is_constant_column(1)) {
if (!context->is_constant_column(1)) {
return Status::OK();
}

const auto pattern_col = context->get_constant_column(1);
if (!pattern_col->only_null()) {
Slice pattern = ColumnHelper::get_const_value<TYPE_VARCHAR>(pattern_col);
state->pattern = std::string(pattern.data, pattern.size);
state->const_pattern = true;

state->options = std::make_unique<re2::RE2::Options>();
state->options->set_log_errors(false);

state->regex = std::make_unique<re2::RE2>(state->pattern, *state->options);
if (!state->regex->ok()) {
std::stringstream error;
error << "Invalid regex expression: " << state->pattern;
context->set_error(error.str().c_str());
return Status::InvalidArgument(error.str());
}
}
}
return Status::OK();

}

Status StringFunctions::regexp_position_close(FunctionContext* context,
FunctionContext::FunctionStateScope scope) {
if (scope == FunctionContext::FRAGMENT_LOCAL) {
auto* state = reinterpret_cast<StringFunctionsState*>(
context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
delete state;
}
return Status::OK();
}

static ColumnPtr regexp_position_const_pattern(re2::RE2* const_re, const Columns& columns) {
auto str_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]);
auto start_viewer = ColumnViewer<TYPE_INT>(columns[2]);
auto occurrence_viewer = ColumnViewer<TYPE_INT>(columns[3]);

auto size = columns[0]->size();
ColumnBuilder<TYPE_INT> result(size);

for (int row = 0; row < size; ++row) {
if (str_viewer.is_null(row) || start_viewer.is_null(row) || occurrence_viewer.is_null(row)) {
result.append_null();
continue;
}

int start_pos = start_viewer.value(row);
int occurrence = occurrence_viewer.value(row);

if (start_pos < 1 || occurrence < 1) {
result.append(-1);
continue;
}

auto str_value = str_viewer.value(row);
// get the number of code points in the string
int utf8_length = utf8_len(str_value.data, str_value.data + str_value.size);

if (start_pos > utf8_length) {
result.append(-1);
continue;
}

const char* search_start = skip_leading_utf8(str_value.data, str_value.data + str_value.size, start_pos - 1);
int byte_offset = search_start - str_value.data;

int count = 0;
re2::StringPiece str_sp(str_value.data, str_value.size);
re2::StringPiece match;

bool found = false;
while (byte_offset <= str_value.size) {
if (const_re->Match(str_sp, byte_offset, str_value.size, re2::RE2::UNANCHORED, &match, 1)) {
count++;
if (count == occurrence) {
// convert back to one-based index
int code_point_pos = utf8_len(str_value.data, match.data()) + 1;
result.append(code_point_pos);
found = true;
break;
}

byte_offset = match.data() - str_value.data + match.size();
if (match.size() == 0) {
byte_offset++;
}
} else {
break;
}
}

if (!found) {
result.append(-1);
}
}

return result.build(ColumnHelper::is_all_const(columns));
}

static StatusOr<ColumnPtr> regexp_position_general(FunctionContext* context, re2::RE2::Options* options, const Columns& columns) {
auto str_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]);
auto pattern_viewer = ColumnViewer<TYPE_VARCHAR>(columns[1]);
auto start_viewer = ColumnViewer<TYPE_INT>(columns[2]);
auto occurrence_viewer = ColumnViewer<TYPE_INT>(columns[3]);

auto size = columns[0]->size();
ColumnBuilder<TYPE_INT> result(size);

for (int row = 0; row < size; ++row) {
if (str_viewer.is_null(row) || pattern_viewer.is_null(row) ||
start_viewer.is_null(row) || occurrence_viewer.is_null(row)) {
result.append_null();
continue;
}

int start_pos = start_viewer.value(row);
int occurrence = occurrence_viewer.value(row);

if (start_pos < 1 || occurrence < 1) {
result.append(-1);
continue;
}

auto str_value = str_viewer.value(row);
auto pattern_value = pattern_viewer.value(row);

std::string pattern_str = pattern_value.to_string();
// compile the pattern for each new row, keep in stack memory
re2::RE2 local_re(pattern_str, *options);
if (!local_re.ok()) {
return Status::InvalidArgument(
strings::Substitute("Invalid regex expression: $0", pattern_str));
}

int utf8_length = utf8_len(str_value.data, str_value.data + str_value.size);

if (start_pos > utf8_length) {
result.append(-1);
continue;
}

const char* search_start = skip_leading_utf8(str_value.data, str_value.data + str_value.size, start_pos - 1);
int byte_offset = search_start - str_value.data;

int count = 0;
re2::StringPiece str_sp(str_value.data, str_value.size);
re2::StringPiece match;

bool found = false;
while (byte_offset <= str_value.size) {
if (local_re.Match(str_sp, byte_offset, str_value.size, re2::RE2::UNANCHORED, &match, 1)) {
count++;
if (count == occurrence) {
int code_point_pos = utf8_len(str_value.data, match.data()) + 1;
result.append(code_point_pos);
found = true;
break;
}

byte_offset = match.data() - str_value.data + match.size();
if (match.size() == 0) {
byte_offset++;
}
} else {
break;
}
}

if (!found) {
result.append(-1);
}
}

return result.build(ColumnHelper::is_all_const(columns));
}

StatusOr<ColumnPtr> StringFunctions::regexp_position(FunctionContext* context, const Columns& columns) {
RETURN_IF_COLUMNS_ONLY_NULL(columns);

auto* state = reinterpret_cast<StringFunctionsState*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));

if (state && state->const_pattern && state->regex) {
// Const col
return regexp_position_const_pattern(state->get_or_prepare_regex(), columns);
} else {
re2::RE2::Options options;
options.set_log_errors(false);
return regexp_position_general(context, &options, columns);
}
}

struct ReplaceState {
bool only_null{false};

Expand Down
9 changes: 9 additions & 0 deletions be/src/exprs/string_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,15 @@ class StringFunctions {
*/
DEFINE_VECTORIZED_FN(regexp_count);

/**
* @param: [string_value, pattern_value, start_position, occurrence]
* @paramType: [BinaryColumn, BinaryColumn, IntColumn, IntColumn]
* @return: IntColumn
*/
DEFINE_VECTORIZED_FN(regexp_position);
static Status regexp_position_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope);
static Status regexp_position_close(FunctionContext* context, FunctionContext::FunctionStateScope scope);

/**
* @param: [string_value, pattern_value, replace_value]
* @paramType: [BinaryColumn, BinaryColumn, BinaryColumn]
Expand Down
Loading
Loading