Skip to content

Commit 82f2462

Browse files
authored
feat: add native implementations of regexp_extract and regexp_extract_all (#4146)
1 parent d1ea99d commit 82f2462

11 files changed

Lines changed: 1315 additions & 18 deletions

File tree

docs/source/user-guide/latest/compatibility/regex.md

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,21 @@ Comet evaluates Spark regular-expression expressions (`rlike`, `regexp_replace`,
2929
expressions fall back to Spark.
3030
- **Native (rust) engine** — the Rust [`regex`] crate, run natively with no JNI overhead. It is
3131
faster but has different semantics from Java regex (see below), so it is **opt-in per expression**
32-
via that expression's `allowIncompatible` flag. `rlike`, `regexp_replace`, and `split` have a
33-
native implementation; `regexp_extract`, `regexp_extract_all`, and `regexp_instr` do not and
34-
always run through the codegen dispatcher.
35-
36-
| SQL | Native (rust) opt-in config |
37-
| ---------------- | -------------------------------------------------------- |
38-
| `rlike` | `spark.comet.expression.RLike.allowIncompatible` |
39-
| `regexp_replace` | `spark.comet.expression.RegExpReplace.allowIncompatible` |
40-
| `split` | `spark.comet.expression.StringSplit.allowIncompatible` |
32+
via that expression's `allowIncompatible` flag. `rlike`, `regexp_replace`, `split`,
33+
`regexp_extract`, and `regexp_extract_all` have a native implementation; `regexp_instr` does not
34+
and always runs through the codegen dispatcher.
35+
36+
| SQL | Native (rust) opt-in config |
37+
| -------------------- | ----------------------------------------------------------- |
38+
| `rlike` | `spark.comet.expression.RLike.allowIncompatible` |
39+
| `regexp_replace` | `spark.comet.expression.RegExpReplace.allowIncompatible` |
40+
| `regexp_extract` | `spark.comet.expression.RegExpExtract.allowIncompatible` |
41+
| `regexp_extract_all` | `spark.comet.expression.RegExpExtractAll.allowIncompatible` |
42+
| `split` | `spark.comet.expression.StringSplit.allowIncompatible` |
4143

4244
When the native path is opted in but a case has no native implementation (for example a non-scalar
43-
`rlike` pattern, or `regexp_replace` with a non-1 offset), Comet routes that case through the
44-
codegen dispatcher.
45+
`rlike` pattern, `regexp_replace` with a non-1 offset, or `regexp_extract` with a non-literal
46+
pattern or idx), Comet routes that case through the codegen dispatcher.
4547

4648
## Disabling Comet for individual regex expressions
4749

@@ -64,7 +66,7 @@ the engine selector:
6466
| | Rust engine | Codegen dispatcher (default) |
6567
| -------------------- | ------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- |
6668
| **Compatibility** | Differs from Java regex (see below) | 100% compatible with Spark |
67-
| **Feature coverage** | `rlike`, `regexp_replace`, `split` natively; `regexp_extract`, `regexp_extract_all`, `regexp_instr` via fallthrough | All regexp expressions (`rlike`, `regexp_extract`, `regexp_extract_all`, `regexp_instr`, `regexp_replace`, `split`) |
69+
| **Feature coverage** | `rlike`, `regexp_replace`, `split`, `regexp_extract`, `regexp_extract_all` natively; `regexp_instr` via fallthrough | All regexp expressions (`rlike`, `regexp_extract`, `regexp_extract_all`, `regexp_instr`, `regexp_replace`, `split`) |
6870
| **Performance** | Fully native, no JNI overhead | One JNI round-trip per batch (Arrow vectors stay columnar) |
6971
| **Pattern support** | Linear-time subset only | All Java regex features (backreferences, lookaround, etc.) |
7072

docs/source/user-guide/latest/expressions.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,8 @@ expression-level). The `outer` variants are wired but marked `Incompatible`; the
571571
| `position` || |
572572
| `printf` || |
573573
| `regexp_count` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) |
574-
| `regexp_extract` || Routed through the JVM codegen dispatcher |
575-
| `regexp_extract_all` || Routed through the JVM codegen dispatcher |
574+
| `regexp_extract` || |
575+
| `regexp_extract_all` || |
576576
| `regexp_instr` || Routed through the JVM codegen dispatcher |
577577
| `regexp_replace` || |
578578
| `regexp_substr` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) |

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,14 @@ pub fn create_comet_physical_fun_with_eval_mode(
198198
let func = Arc::new(crate::string_funcs::spark_split);
199199
make_comet_scalar_udf!("split", func, without data_type)
200200
}
201+
"regexp_extract" => {
202+
let func = Arc::new(crate::string_funcs::spark_regexp_extract);
203+
make_comet_scalar_udf!("regexp_extract", func, without data_type)
204+
}
205+
"regexp_extract_all" => {
206+
let func = Arc::new(crate::string_funcs::spark_regexp_extract_all);
207+
make_comet_scalar_udf!("regexp_extract_all", func, without data_type)
208+
}
201209
"get_json_object" => {
202210
let func = Arc::new(crate::string_funcs::spark_get_json_object);
203211
make_comet_scalar_udf!("get_json_object", func, without data_type)

native/spark-expr/src/string_funcs/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717

1818
mod contains;
1919
mod get_json_object;
20+
mod regexp_extract;
21+
mod regexp_extract_all;
22+
mod regexp_extract_common;
2023
mod split;
2124
mod substring;
2225

2326
pub use contains::SparkContains;
2427
pub use get_json_object::spark_get_json_object;
28+
pub use regexp_extract::spark_regexp_extract;
29+
pub use regexp_extract_all::spark_regexp_extract_all;
2530
pub use split::spark_split;
2631
pub use substring::SubstringExpr;
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
Array, ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray, StringBuilder,
20+
};
21+
use arrow::datatypes::DataType;
22+
use datafusion::common::{
23+
cast::as_generic_string_array, exec_err, Result as DataFusionResult, ScalarValue,
24+
};
25+
use datafusion::logical_expr::ColumnarValue;
26+
use regex::Regex;
27+
use std::sync::Arc;
28+
29+
use super::regexp_extract_common::{parse_args, ParsedArgs};
30+
31+
/// Spark-compatible `regexp_extract(subject, pattern, idx)`.
32+
///
33+
/// Returns the substring of `subject` matched by group `idx` of the first match of `pattern`.
34+
/// `idx = 0` returns the entire match. Returns an empty string when there is no match or the
35+
/// matched group is unset (optional group). Returns null when any input is null. Errors when
36+
/// `idx` is out of range for the pattern's group count.
37+
///
38+
/// Note: this uses the Rust `regex` crate, whose syntax differs from Java's regex engine in
39+
/// some ways. The expression is therefore reported as Incompatible.
40+
pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
41+
let (regex, group_idx, subject) = match parse_args("regexp_extract", args)? {
42+
ParsedArgs::Parsed {
43+
regex,
44+
group_idx,
45+
subject,
46+
} => (regex, group_idx, subject),
47+
ParsedArgs::NullResult { len } => return Ok(null_result(len)),
48+
};
49+
50+
match subject {
51+
ColumnarValue::Array(array) => match array.data_type() {
52+
DataType::Utf8 => {
53+
let strings = as_generic_string_array::<i32>(array.as_ref())?;
54+
Ok(ColumnarValue::Array(extract_array(
55+
strings, &regex, group_idx,
56+
)))
57+
}
58+
DataType::LargeUtf8 => {
59+
let strings = as_generic_string_array::<i64>(array.as_ref())?;
60+
Ok(ColumnarValue::Array(extract_array(
61+
strings, &regex, group_idx,
62+
)))
63+
}
64+
other => exec_err!(
65+
"regexp_extract expects Utf8 or LargeUtf8 subject, got {:?}",
66+
other
67+
),
68+
},
69+
ColumnarValue::Scalar(ScalarValue::Utf8(s))
70+
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(s)) => match s {
71+
None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
72+
Some(s) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(extract_one(
73+
s, &regex, group_idx,
74+
))))),
75+
},
76+
_ => exec_err!("regexp_extract subject must be a string"),
77+
}
78+
}
79+
80+
/// Always produces a `StringArray` (i32 offsets) regardless of the input offset width:
81+
/// Spark's `RegExpExtract.dataType` is `StringType` and the Comet serde serializes that as
82+
/// the protobuf return type, so handing back a `LargeStringArray` would be a type mismatch.
83+
/// `&str` slices are width-agnostic, so it is safe to copy them into a 32-bit-offset builder.
84+
fn extract_array<O: OffsetSizeTrait>(
85+
array: &GenericStringArray<O>,
86+
regex: &Regex,
87+
group_idx: usize,
88+
) -> ArrayRef {
89+
let mut builder = StringBuilder::with_capacity(array.len(), array.value_data().len());
90+
for i in 0..array.len() {
91+
if array.is_null(i) {
92+
builder.append_null();
93+
} else {
94+
let extracted = match regex.captures(array.value(i)) {
95+
Some(caps) => caps.get(group_idx).map(|m| m.as_str()).unwrap_or(""),
96+
None => "",
97+
};
98+
builder.append_value(extracted);
99+
}
100+
}
101+
Arc::new(builder.finish())
102+
}
103+
104+
fn extract_one(input: &str, regex: &Regex, group_idx: usize) -> String {
105+
match regex.captures(input) {
106+
Some(caps) => caps
107+
.get(group_idx)
108+
.map(|m| m.as_str().to_string())
109+
.unwrap_or_default(),
110+
None => String::new(),
111+
}
112+
}
113+
114+
fn null_result(len: Option<usize>) -> ColumnarValue {
115+
match len {
116+
Some(n) => ColumnarValue::Array(Arc::new(StringArray::new_null(n))),
117+
None => ColumnarValue::Scalar(ScalarValue::Utf8(None)),
118+
}
119+
}
120+
121+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{LargeStringArray, StringArray};
    use datafusion::common::DataFusionError;

    /// Evaluates `spark_regexp_extract` and flattens the outcome into one `Option<String>` per
    /// row; a scalar result becomes a single-element vector. Panics when the result is not
    /// Utf8, since the function is expected to always return 32-bit-offset strings.
    fn run(args: Vec<ColumnarValue>) -> DataFusionResult<Vec<Option<String>>> {
        match spark_regexp_extract(&args)? {
            ColumnarValue::Array(arr) => {
                let strings = arr
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .expect("expected Utf8 array (regexp_extract must always return StringArray)");
                Ok(strings.iter().map(|cell| cell.map(String::from)).collect())
            }
            ColumnarValue::Scalar(ScalarValue::Utf8(v)) => Ok(vec![v]),
            other => panic!("unexpected result: {other:?}"),
        }
    }

    /// Wraps the given values in a Utf8 array argument.
    fn array(values: Vec<Option<&str>>) -> ColumnarValue {
        let strings = StringArray::from(values);
        ColumnarValue::Array(Arc::new(strings))
    }

    /// Builds a scalar Utf8 pattern argument.
    fn pattern(p: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(p.to_string())))
    }

    /// Builds a scalar Int32 group-index argument.
    fn idx(i: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int32(Some(i)))
    }

    /// Converts borrowed expected values into the owned shape returned by `run`.
    fn owned(cells: &[Option<&str>]) -> Vec<Option<String>> {
        cells.iter().map(|cell| cell.map(String::from)).collect()
    }

    #[test]
    fn basic_group_extraction() {
        let subject = array(vec![Some("100-200"), Some("foo-bar"), Some("nodelim")]);
        let got = run(vec![subject, pattern(r"(\d+)-(\d+)"), idx(1)]).unwrap();
        // Rows that do not match the pattern produce empty strings, not nulls.
        assert_eq!(got, owned(&[Some("100"), Some(""), Some("")]));
    }

    #[test]
    fn idx_zero_returns_whole_match() {
        let subject = array(vec![Some("abc123def456")]);
        let got = run(vec![subject, pattern(r"\d+"), idx(0)]).unwrap();
        assert_eq!(got, owned(&[Some("123")]));
    }

    #[test]
    fn default_idx_is_one() {
        // Only subject and pattern are supplied; the group index is omitted.
        let subject = array(vec![Some("100-200")]);
        let got = run(vec![subject, pattern(r"(\d+)-(\d+)")]).unwrap();
        assert_eq!(got, owned(&[Some("100")]));
    }

    #[test]
    fn null_subject_returns_null() {
        let subject = array(vec![Some("a1b"), None, Some("c2d")]);
        let got = run(vec![subject, pattern(r"(\d)"), idx(1)]).unwrap();
        assert_eq!(got, owned(&[Some("1"), None, Some("2")]));
    }

    #[test]
    fn null_pattern_returns_null() {
        let subject = array(vec![Some("abc")]);
        let null_pattern = ColumnarValue::Scalar(ScalarValue::Utf8(None));
        let got = run(vec![subject, null_pattern, idx(1)]).unwrap();
        assert_eq!(got, owned(&[None]));
    }

    #[test]
    fn unmatched_optional_group_returns_empty_string() {
        let subject = array(vec![Some("foo")]);
        let got = run(vec![subject, pattern(r"(foo)(bar)?"), idx(2)]).unwrap();
        assert_eq!(got, owned(&[Some("")]));
    }

    #[test]
    fn group_index_out_of_range_errors() {
        let args = [array(vec![Some("abc")]), pattern(r"(a)(b)"), idx(3)];
        let msg = spark_regexp_extract(&args).unwrap_err().to_string();
        assert!(msg.contains("group index"), "{msg}");
        assert!(msg.contains("but got 3"), "{msg}");
    }

    #[test]
    fn negative_index_errors() {
        let args = [array(vec![Some("abc")]), pattern(r"(a)"), idx(-1)];
        let msg = spark_regexp_extract(&args).unwrap_err().to_string();
        assert!(msg.contains("group index"), "{msg}");
        assert!(msg.contains("but got -1"), "{msg}");
    }

    #[test]
    fn invalid_regex_errors() {
        let args = [array(vec![Some("abc")]), pattern(r"(unclosed"), idx(0)];
        let msg = spark_regexp_extract(&args).unwrap_err().to_string();
        assert!(msg.contains("`regexp`"));
    }

    /// A `LargeUtf8` subject must still come back as a `StringArray` (i32 offsets) so the
    /// result type matches Spark's `RegExpExtract.dataType` = `StringType`. Regression test
    /// for the bug where `extract_array::<i64>` built a `LargeStringArray` and tripped a type
    /// mismatch.
    #[test]
    fn large_utf8_subject_returns_utf8_array() {
        let values = vec![Some("100-200"), None, Some("foo-bar")];
        let subject = ColumnarValue::Array(Arc::new(LargeStringArray::from(values)));
        let result = spark_regexp_extract(&[subject, pattern(r"(\d+)-(\d+)"), idx(1)]).unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let downcast = arr.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
                    DataFusionError::Internal(format!(
                        "expected StringArray, got {:?}",
                        arr.data_type()
                    ))
                });
                downcast.unwrap();
                assert_eq!(arr.len(), 3);
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }
}

0 commit comments

Comments
 (0)