Skip to content

Commit b8d7f28

Browse files
Chew Zi Xuanchewbum
authored andcommitted
Implement Regexp_Position in C++ mode
1 parent 527756e commit b8d7f28

File tree

8 files changed

+555
-0
lines changed

8 files changed

+555
-0
lines changed

be/src/exprs/string_functions.cpp

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "exprs/string_functions.h"
1616

1717
#include "column/bytes.h"
18+
#include "function_context.h"
1819
#include "util/defer_op.h"
1920

2021
#ifdef __x86_64__
@@ -4368,6 +4369,208 @@ StatusOr<ColumnPtr> StringFunctions::regexp_count(FunctionContext* context, cons
43684369
}
43694370
}
43704371

4372+
Status StringFunctions::regexp_position_prepare(FunctionContext* context,
4373+
FunctionContext::FunctionStateScope scope) {
4374+
if (scope != FunctionContext::FRAGMENT_LOCAL) {
4375+
return Status::OK();
4376+
}
4377+
4378+
auto* state = new StringFunctionsState();
4379+
context->set_function_state(scope, state);
4380+
4381+
// check if pattern is constant
4382+
if (context->is_constant_column(1)) {
4383+
const auto pattern_col = context->get_constant_column(1);
4384+
if (!pattern_col->only_null()) {
4385+
Slice pattern = ColumnHelper::get_const_value<TYPE_VARCHAR>(pattern_col);
4386+
state->pattern = std::string(pattern.data, pattern.size);
4387+
state->const_pattern = true;
4388+
4389+
state->options = std::make_unique<re2::RE2::Options>();
4390+
state->options->set_log_errors(false);
4391+
4392+
state->regex = std::make_unique<re2::RE2>(state->pattern, *state->options);
4393+
if (!state->regex->ok()) {
4394+
std::stringstream error;
4395+
error << "Invalid regex expression: " << state->pattern;
4396+
context->set_error(error.str().c_str());
4397+
return Status::InvalidArgument(error.str());
4398+
}
4399+
}
4400+
}
4401+
return Status::OK();
4402+
4403+
}
4404+
4405+
Status StringFunctions::regexp_position_close(FunctionContext* context,
4406+
FunctionContext::FunctionStateScope scope) {
4407+
if (scope == FunctionContext::FRAGMENT_LOCAL) {
4408+
auto* state = reinterpret_cast<StringFunctionsState*>(
4409+
context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
4410+
delete state;
4411+
}
4412+
return Status::OK();
4413+
}
4414+
4415+
static ColumnPtr regexp_position_const_pattern(re2::RE2* const_re, const Columns& columns) {
4416+
auto str_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]);
4417+
auto start_viewer = ColumnViewer<TYPE_INT>(columns[2]);
4418+
auto occurrence_viewer = ColumnViewer<TYPE_INT>(columns[3]);
4419+
4420+
auto size = columns[0]->size();
4421+
ColumnBuilder<TYPE_INT> result(size);
4422+
4423+
for (int row = 0; row < size; ++row) {
4424+
if (str_viewer.is_null(row) || start_viewer.is_null(row) || occurrence_viewer.is_null(row)) {
4425+
result.append_null();
4426+
continue;
4427+
}
4428+
4429+
int start_pos = start_viewer.value(row);
4430+
int occurrence = occurrence_viewer.value(row);
4431+
4432+
if (start_pos < 1 || occurrence < 1) {
4433+
result.append(-1);
4434+
continue;
4435+
}
4436+
4437+
auto str_value = str_viewer.value(row);
4438+
// get the number of code points in the string
4439+
int utf8_length = utf8_len(str_value.data, str_value.data + str_value.size);
4440+
4441+
if (start_pos > utf8_length) {
4442+
result.append(-1);
4443+
continue;
4444+
}
4445+
4446+
const char* search_start = skip_leading_utf8(str_value.data, str_value.data + str_value.size, start_pos - 1);
4447+
int byte_offset = search_start - str_value.data;
4448+
4449+
int count = 0;
4450+
re2::StringPiece str_sp(str_value.data, str_value.size);
4451+
re2::StringPiece match;
4452+
4453+
bool found = false;
4454+
while (byte_offset <= str_value.size) {
4455+
if (const_re->Match(str_sp, byte_offset, str_value.size, re2::RE2::UNANCHORED, &match, 1)) {
4456+
count++;
4457+
if (count == occurrence) {
4458+
// convert back to one-based index
4459+
int code_point_pos = utf8_len(str_value.data, match.data()) + 1;
4460+
result.append(code_point_pos);
4461+
found = true;
4462+
break;
4463+
}
4464+
4465+
byte_offset = match.data() - str_value.data + match.size();
4466+
if (match.size() == 0) {
4467+
byte_offset++;
4468+
}
4469+
} else {
4470+
break;
4471+
}
4472+
}
4473+
4474+
if (!found) {
4475+
result.append(-1);
4476+
}
4477+
}
4478+
4479+
return result.build(ColumnHelper::is_all_const(columns));
4480+
}
4481+
4482+
static StatusOr<ColumnPtr> regexp_position_general(FunctionContext* context, re2::RE2::Options* options, const Columns& columns) {
4483+
auto str_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]);
4484+
auto pattern_viewer = ColumnViewer<TYPE_VARCHAR>(columns[1]);
4485+
auto start_viewer = ColumnViewer<TYPE_INT>(columns[2]);
4486+
auto occurrence_viewer = ColumnViewer<TYPE_INT>(columns[3]);
4487+
4488+
auto size = columns[0]->size();
4489+
ColumnBuilder<TYPE_INT> result(size);
4490+
4491+
for (int row = 0; row < size; ++row) {
4492+
if (str_viewer.is_null(row) || pattern_viewer.is_null(row) ||
4493+
start_viewer.is_null(row) || occurrence_viewer.is_null(row)) {
4494+
result.append_null();
4495+
continue;
4496+
}
4497+
4498+
int start_pos = start_viewer.value(row);
4499+
int occurrence = occurrence_viewer.value(row);
4500+
4501+
if (start_pos < 1 || occurrence < 1) {
4502+
result.append(-1);
4503+
continue;
4504+
}
4505+
4506+
auto str_value = str_viewer.value(row);
4507+
auto pattern_value = pattern_viewer.value(row);
4508+
4509+
std::string pattern_str = pattern_value.to_string();
4510+
// compile the pattern for each new row, keep in stack memory
4511+
re2::RE2 local_re(pattern_str, *options);
4512+
if (!local_re.ok()) {
4513+
return Status::InvalidArgument(
4514+
strings::Substitute("Invalid regex expression: $0", pattern_str));
4515+
}
4516+
4517+
int utf8_length = utf8_len(str_value.data, str_value.data + str_value.size);
4518+
4519+
if (start_pos > utf8_length) {
4520+
result.append(-1);
4521+
continue;
4522+
}
4523+
4524+
const char* search_start = skip_leading_utf8(str_value.data, str_value.data + str_value.size, start_pos - 1);
4525+
int byte_offset = search_start - str_value.data;
4526+
4527+
int count = 0;
4528+
re2::StringPiece str_sp(str_value.data, str_value.size);
4529+
re2::StringPiece match;
4530+
4531+
bool found = false;
4532+
while (byte_offset <= str_value.size) {
4533+
if (local_re.Match(str_sp, byte_offset, str_value.size, re2::RE2::UNANCHORED, &match, 1)) {
4534+
count++;
4535+
if (count == occurrence) {
4536+
int code_point_pos = utf8_len(str_value.data, match.data()) + 1;
4537+
result.append(code_point_pos);
4538+
found = true;
4539+
break;
4540+
}
4541+
4542+
byte_offset = match.data() - str_value.data + match.size();
4543+
if (match.size() == 0) {
4544+
byte_offset++;
4545+
}
4546+
} else {
4547+
break;
4548+
}
4549+
}
4550+
4551+
if (!found) {
4552+
result.append(-1);
4553+
}
4554+
}
4555+
4556+
return result.build(ColumnHelper::is_all_const(columns));
4557+
}
4558+
4559+
StatusOr<ColumnPtr> StringFunctions::regexp_position(FunctionContext* context, const Columns& columns) {
4560+
RETURN_IF_COLUMNS_ONLY_NULL(columns);
4561+
4562+
auto* state = reinterpret_cast<StringFunctionsState*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
4563+
4564+
if (state && state->const_pattern && state->regex) {
4565+
// Const col
4566+
return regexp_position_const_pattern(state->get_or_prepare_regex(), columns);
4567+
} else {
4568+
re2::RE2::Options options;
4569+
options.set_log_errors(false);
4570+
return regexp_position_general(context, &options, columns);
4571+
}
4572+
}
4573+
43714574
struct ReplaceState {
43724575
bool only_null{false};
43734576

be/src/exprs/string_functions.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,15 @@ class StringFunctions {
460460
*/
461461
DEFINE_VECTORIZED_FN(regexp_count);
462462

463+
/**
464+
* @param: [string_value, pattern_value, start_position, occurrence]
465+
* @paramType: [BinaryColumn, BinaryColumn, IntColumn, IntColumn]
466+
* @return: IntColumn
467+
*/
468+
DEFINE_VECTORIZED_FN(regexp_position);
469+
static Status regexp_position_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope);
470+
static Status regexp_position_close(FunctionContext* context, FunctionContext::FunctionStateScope scope);
471+
463472
/**
464473
* @param: [string_value, pattern_value, replace_value]
465474
* @paramType: [BinaryColumn, BinaryColumn, BinaryColumn]

0 commit comments

Comments
 (0)