|
15 | 15 | #include "exprs/string_functions.h" |
16 | 16 |
|
17 | 17 | #include "column/bytes.h" |
| 18 | +#include "function_context.h" |
18 | 19 | #include "util/defer_op.h" |
19 | 20 |
|
20 | 21 | #ifdef __x86_64__ |
@@ -4368,6 +4369,208 @@ StatusOr<ColumnPtr> StringFunctions::regexp_count(FunctionContext* context, cons |
4368 | 4369 | } |
4369 | 4370 | } |
4370 | 4371 |
|
| 4372 | +Status StringFunctions::regexp_position_prepare(FunctionContext* context, |
| 4373 | + FunctionContext::FunctionStateScope scope) { |
| 4374 | + if (scope != FunctionContext::FRAGMENT_LOCAL) { |
| 4375 | + return Status::OK(); |
| 4376 | + } |
| 4377 | + |
| 4378 | + auto* state = new StringFunctionsState(); |
| 4379 | + context->set_function_state(scope, state); |
| 4380 | + |
| 4381 | + // check if pattern is constant |
| 4382 | + if (context->is_constant_column(1)) { |
| 4383 | + const auto pattern_col = context->get_constant_column(1); |
| 4384 | + if (!pattern_col->only_null()) { |
| 4385 | + Slice pattern = ColumnHelper::get_const_value<TYPE_VARCHAR>(pattern_col); |
| 4386 | + state->pattern = std::string(pattern.data, pattern.size); |
| 4387 | + state->const_pattern = true; |
| 4388 | + |
| 4389 | + state->options = std::make_unique<re2::RE2::Options>(); |
| 4390 | + state->options->set_log_errors(false); |
| 4391 | + |
| 4392 | + state->regex = std::make_unique<re2::RE2>(state->pattern, *state->options); |
| 4393 | + if (!state->regex->ok()) { |
| 4394 | + std::stringstream error; |
| 4395 | + error << "Invalid regex expression: " << state->pattern; |
| 4396 | + context->set_error(error.str().c_str()); |
| 4397 | + return Status::InvalidArgument(error.str()); |
| 4398 | + } |
| 4399 | + } |
| 4400 | + } |
| 4401 | + return Status::OK(); |
| 4402 | + |
| 4403 | +} |
| 4404 | + |
| 4405 | +Status StringFunctions::regexp_position_close(FunctionContext* context, |
| 4406 | + FunctionContext::FunctionStateScope scope) { |
| 4407 | + if (scope == FunctionContext::FRAGMENT_LOCAL) { |
| 4408 | + auto* state = reinterpret_cast<StringFunctionsState*>( |
| 4409 | + context->get_function_state(FunctionContext::FRAGMENT_LOCAL)); |
| 4410 | + delete state; |
| 4411 | + } |
| 4412 | + return Status::OK(); |
| 4413 | +} |
| 4414 | + |
| 4415 | +static ColumnPtr regexp_position_const_pattern(re2::RE2* const_re, const Columns& columns) { |
| 4416 | + auto str_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]); |
| 4417 | + auto start_viewer = ColumnViewer<TYPE_INT>(columns[2]); |
| 4418 | + auto occurrence_viewer = ColumnViewer<TYPE_INT>(columns[3]); |
| 4419 | + |
| 4420 | + auto size = columns[0]->size(); |
| 4421 | + ColumnBuilder<TYPE_INT> result(size); |
| 4422 | + |
| 4423 | + for (int row = 0; row < size; ++row) { |
| 4424 | + if (str_viewer.is_null(row) || start_viewer.is_null(row) || occurrence_viewer.is_null(row)) { |
| 4425 | + result.append_null(); |
| 4426 | + continue; |
| 4427 | + } |
| 4428 | + |
| 4429 | + int start_pos = start_viewer.value(row); |
| 4430 | + int occurrence = occurrence_viewer.value(row); |
| 4431 | + |
| 4432 | + if (start_pos < 1 || occurrence < 1) { |
| 4433 | + result.append(-1); |
| 4434 | + continue; |
| 4435 | + } |
| 4436 | + |
| 4437 | + auto str_value = str_viewer.value(row); |
| 4438 | + // get the number of code points in the string |
| 4439 | + int utf8_length = utf8_len(str_value.data, str_value.data + str_value.size); |
| 4440 | + |
| 4441 | + if (start_pos > utf8_length) { |
| 4442 | + result.append(-1); |
| 4443 | + continue; |
| 4444 | + } |
| 4445 | + |
| 4446 | + const char* search_start = skip_leading_utf8(str_value.data, str_value.data + str_value.size, start_pos - 1); |
| 4447 | + int byte_offset = search_start - str_value.data; |
| 4448 | + |
| 4449 | + int count = 0; |
| 4450 | + re2::StringPiece str_sp(str_value.data, str_value.size); |
| 4451 | + re2::StringPiece match; |
| 4452 | + |
| 4453 | + bool found = false; |
| 4454 | + while (byte_offset <= str_value.size) { |
| 4455 | + if (const_re->Match(str_sp, byte_offset, str_value.size, re2::RE2::UNANCHORED, &match, 1)) { |
| 4456 | + count++; |
| 4457 | + if (count == occurrence) { |
| 4458 | + // convert back to one-based index |
| 4459 | + int code_point_pos = utf8_len(str_value.data, match.data()) + 1; |
| 4460 | + result.append(code_point_pos); |
| 4461 | + found = true; |
| 4462 | + break; |
| 4463 | + } |
| 4464 | + |
| 4465 | + byte_offset = match.data() - str_value.data + match.size(); |
| 4466 | + if (match.size() == 0) { |
| 4467 | + byte_offset++; |
| 4468 | + } |
| 4469 | + } else { |
| 4470 | + break; |
| 4471 | + } |
| 4472 | + } |
| 4473 | + |
| 4474 | + if (!found) { |
| 4475 | + result.append(-1); |
| 4476 | + } |
| 4477 | + } |
| 4478 | + |
| 4479 | + return result.build(ColumnHelper::is_all_const(columns)); |
| 4480 | +} |
| 4481 | + |
| 4482 | +static StatusOr<ColumnPtr> regexp_position_general(FunctionContext* context, re2::RE2::Options* options, const Columns& columns) { |
| 4483 | + auto str_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]); |
| 4484 | + auto pattern_viewer = ColumnViewer<TYPE_VARCHAR>(columns[1]); |
| 4485 | + auto start_viewer = ColumnViewer<TYPE_INT>(columns[2]); |
| 4486 | + auto occurrence_viewer = ColumnViewer<TYPE_INT>(columns[3]); |
| 4487 | + |
| 4488 | + auto size = columns[0]->size(); |
| 4489 | + ColumnBuilder<TYPE_INT> result(size); |
| 4490 | + |
| 4491 | + for (int row = 0; row < size; ++row) { |
| 4492 | + if (str_viewer.is_null(row) || pattern_viewer.is_null(row) || |
| 4493 | + start_viewer.is_null(row) || occurrence_viewer.is_null(row)) { |
| 4494 | + result.append_null(); |
| 4495 | + continue; |
| 4496 | + } |
| 4497 | + |
| 4498 | + int start_pos = start_viewer.value(row); |
| 4499 | + int occurrence = occurrence_viewer.value(row); |
| 4500 | + |
| 4501 | + if (start_pos < 1 || occurrence < 1) { |
| 4502 | + result.append(-1); |
| 4503 | + continue; |
| 4504 | + } |
| 4505 | + |
| 4506 | + auto str_value = str_viewer.value(row); |
| 4507 | + auto pattern_value = pattern_viewer.value(row); |
| 4508 | + |
| 4509 | + std::string pattern_str = pattern_value.to_string(); |
| 4510 | + // compile the pattern for each new row, keep in stack memory |
| 4511 | + re2::RE2 local_re(pattern_str, *options); |
| 4512 | + if (!local_re.ok()) { |
| 4513 | + return Status::InvalidArgument( |
| 4514 | + strings::Substitute("Invalid regex expression: $0", pattern_str)); |
| 4515 | + } |
| 4516 | + |
| 4517 | + int utf8_length = utf8_len(str_value.data, str_value.data + str_value.size); |
| 4518 | + |
| 4519 | + if (start_pos > utf8_length) { |
| 4520 | + result.append(-1); |
| 4521 | + continue; |
| 4522 | + } |
| 4523 | + |
| 4524 | + const char* search_start = skip_leading_utf8(str_value.data, str_value.data + str_value.size, start_pos - 1); |
| 4525 | + int byte_offset = search_start - str_value.data; |
| 4526 | + |
| 4527 | + int count = 0; |
| 4528 | + re2::StringPiece str_sp(str_value.data, str_value.size); |
| 4529 | + re2::StringPiece match; |
| 4530 | + |
| 4531 | + bool found = false; |
| 4532 | + while (byte_offset <= str_value.size) { |
| 4533 | + if (local_re.Match(str_sp, byte_offset, str_value.size, re2::RE2::UNANCHORED, &match, 1)) { |
| 4534 | + count++; |
| 4535 | + if (count == occurrence) { |
| 4536 | + int code_point_pos = utf8_len(str_value.data, match.data()) + 1; |
| 4537 | + result.append(code_point_pos); |
| 4538 | + found = true; |
| 4539 | + break; |
| 4540 | + } |
| 4541 | + |
| 4542 | + byte_offset = match.data() - str_value.data + match.size(); |
| 4543 | + if (match.size() == 0) { |
| 4544 | + byte_offset++; |
| 4545 | + } |
| 4546 | + } else { |
| 4547 | + break; |
| 4548 | + } |
| 4549 | + } |
| 4550 | + |
| 4551 | + if (!found) { |
| 4552 | + result.append(-1); |
| 4553 | + } |
| 4554 | + } |
| 4555 | + |
| 4556 | + return result.build(ColumnHelper::is_all_const(columns)); |
| 4557 | +} |
| 4558 | + |
| 4559 | +StatusOr<ColumnPtr> StringFunctions::regexp_position(FunctionContext* context, const Columns& columns) { |
| 4560 | + RETURN_IF_COLUMNS_ONLY_NULL(columns); |
| 4561 | + |
| 4562 | + auto* state = reinterpret_cast<StringFunctionsState*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL)); |
| 4563 | + |
| 4564 | + if (state && state->const_pattern && state->regex) { |
| 4565 | + // Const col |
| 4566 | + return regexp_position_const_pattern(state->get_or_prepare_regex(), columns); |
| 4567 | + } else { |
| 4568 | + re2::RE2::Options options; |
| 4569 | + options.set_log_errors(false); |
| 4570 | + return regexp_position_general(context, &options, columns); |
| 4571 | + } |
| 4572 | +} |
| 4573 | + |
4371 | 4574 | struct ReplaceState { |
4372 | 4575 | bool only_null{false}; |
4373 | 4576 |
|
|
0 commit comments