From 09959a7de163a96bacdf91ff55368829d4df57af Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 29 Apr 2026 15:59:11 -0600 Subject: [PATCH 1/9] feat: support Spark expression `regexp_extract` Implement regexp_extract using the Rust regex crate. The expression is marked Incompatible because the Rust regex engine differs from the Java engine that Spark uses; users must opt in via spark.comet.expression.RegExpExtract.allowIncompatible=true. --- native/spark-expr/src/comet_scalar_funcs.rs | 4 + native/spark-expr/src/string_funcs/mod.rs | 2 + .../src/string_funcs/regexp_extract.rs | 301 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../org/apache/comet/serde/strings.scala | 36 ++- .../expressions/string/regexp_extract.sql | 35 ++ .../string/regexp_extract_enabled.sql | 73 +++++ 7 files changed, 451 insertions(+), 1 deletion(-) create mode 100644 native/spark-expr/src/string_funcs/regexp_extract.rs create mode 100644 spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 0957868a60..9a2dd33f97 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -188,6 +188,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) } + "regexp_extract" => { + let func = Arc::new(crate::string_funcs::spark_regexp_extract); + make_comet_scalar_udf!("regexp_extract", func, without data_type) + } "get_json_object" => { let func = Arc::new(crate::string_funcs::spark_get_json_object); make_comet_scalar_udf!("get_json_object", func, without data_type) diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index bb785bdb44..6655866bd3 100644 
--- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -17,10 +17,12 @@ mod contains; mod get_json_object; +mod regexp_extract; mod split; mod substring; pub use contains::SparkContains; pub use get_json_object::spark_get_json_object; +pub use regexp_extract::spark_regexp_extract; pub use split::spark_split; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs new file mode 100644 index 0000000000..7364ef72a9 --- /dev/null +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -0,0 +1,301 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, GenericStringArray, GenericStringBuilder}; +use arrow::datatypes::DataType; +use datafusion::common::{ + cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult, + ScalarValue, +}; +use datafusion::logical_expr::ColumnarValue; +use regex::Regex; +use std::sync::Arc; + +/// Spark-compatible `regexp_extract(subject, pattern, idx)`. +/// +/// Returns the substring of `subject` matched by group `idx` of the first match of `pattern`. 
+/// `idx = 0` returns the entire match. Returns an empty string when there is no match or the +/// matched group is unset (optional group). Returns null when any input is null. Errors when +/// `idx` is out of range for the pattern's group count. +/// +/// Note: this uses the Rust `regex` crate, whose syntax differs from Java's regex engine in +/// some ways. The expression is therefore reported as Incompatible. +pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> { + if args.len() < 2 || args.len() > 3 { + return exec_err!( + "regexp_extract expects 2 or 3 arguments (subject, pattern, [idx]), got {}", + args.len() + ); + } + + let idx: i32 = if args.len() == 3 { + match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i, + ColumnarValue::Scalar(ScalarValue::Int32(None)) => { + return Ok(null_result(subject_len(&args[0]))); + } + _ => { + return exec_err!("regexp_extract idx must be an Int32 scalar"); + } + } + } else { + 1 + }; + + if idx < 0 { + return exec_err!("regexp_extract idx must be non-negative, got {}", idx); + } + + let pattern = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(p))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p.clone(), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + return Ok(null_result(subject_len(&args[0]))); + } + _ => { + return exec_err!("regexp_extract pattern must be a scalar string"); + } + }; + + let regex = Regex::new(&pattern).map_err(|e| { + DataFusionError::Execution(format!("Invalid regex pattern '{pattern}': {e}")) + })?; + + let group_count = regex.captures_len() as i32 - 1; + if idx > group_count { + return Err(DataFusionError::Execution(format!( + "Regex group count is {group_count}, but the specified group index is {idx}" + ))); + } + let group_idx = idx as usize; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let strings = 
as_generic_string_array::<i32>(array.as_ref())?; + Ok(ColumnarValue::Array(extract_array::<i32>( + strings, &regex, group_idx, + ))) + } + DataType::LargeUtf8 => { + let strings = as_generic_string_array::<i64>(array.as_ref())?; + Ok(ColumnarValue::Array(extract_array::<i64>( + strings, &regex, group_idx, + ))) + } + other => exec_err!( + "regexp_extract expects Utf8 or LargeUtf8 subject, got {:?}", + other + ), + }, + ColumnarValue::Scalar(ScalarValue::Utf8(s)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(s)) => match s { + None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + Some(s) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + extract_one(s, &regex, group_idx), + )))), + }, + _ => exec_err!("regexp_extract subject must be a string"), + } +} + +fn extract_array<O: arrow::array::OffsetSizeTrait>( + array: &GenericStringArray<O>, + regex: &Regex, + group_idx: usize, +) -> ArrayRef { + let mut builder = GenericStringBuilder::<i32>::with_capacity(array.len(), array.value_data().len()); + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(extract_one(array.value(i), regex, group_idx)); + } + } + Arc::new(builder.finish()) +} + +fn extract_one(input: &str, regex: &Regex, group_idx: usize) -> String { + match regex.captures(input) { + Some(caps) => caps + .get(group_idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default(), + None => String::new(), + } +} + +fn subject_len(value: &ColumnarValue) -> Option<usize> { + match value { + ColumnarValue::Array(a) => Some(a.len()), + ColumnarValue::Scalar(_) => None, + } +} + +fn null_result(len: Option<usize>) -> ColumnarValue { + match len { + Some(n) => { + let mut builder = GenericStringBuilder::<i32>::with_capacity(n, 0); + for _ in 0..n { + builder.append_null(); + } + ColumnarValue::Array(Arc::new(builder.finish())) + } + None => ColumnarValue::Scalar(ScalarValue::Utf8(None)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + fn run(args: Vec<ColumnarValue>) -> DataFusionResult<Vec<Option<String>>> { + let result = 
spark_regexp_extract(&args)?; + match result { + ColumnarValue::Array(arr) => { + let s = arr + .as_any() + .downcast_ref::<GenericStringArray<i32>>() + .expect("expected Utf8 array"); + Ok((0..s.len()) + .map(|i| { + if s.is_null(i) { + None + } else { + Some(s.value(i).to_string()) + } + }) + .collect()) + } + ColumnarValue::Scalar(ScalarValue::Utf8(v)) => Ok(vec![v]), + other => panic!("unexpected result: {other:?}"), + } + } + + fn array(values: Vec<Option<&str>>) -> ColumnarValue { + ColumnarValue::Array(Arc::new(StringArray::from(values))) + } + + fn pattern(p: &str) -> ColumnarValue { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(p.to_string()))) + } + + fn idx(i: i32) -> ColumnarValue { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) + } + + #[test] + fn basic_group_extraction() { + let result = run(vec![ + array(vec![Some("100-200"), Some("foo-bar"), Some("nodelim")]), + pattern(r"(\d+)-(\d+)"), + idx(1), + ]) + .unwrap(); + assert_eq!( + result, + vec![ + Some("100".to_string()), + Some(String::new()), + Some(String::new()), + ] + ); + } + + #[test] + fn idx_zero_returns_whole_match() { + let result = run(vec![ + array(vec![Some("abc123def456")]), + pattern(r"\d+"), + idx(0), + ]) + .unwrap(); + assert_eq!(result, vec![Some("123".to_string())]); + } + + #[test] + fn default_idx_is_one() { + let result = run(vec![array(vec![Some("100-200")]), pattern(r"(\d+)-(\d+)")]).unwrap(); + assert_eq!(result, vec![Some("100".to_string())]); + } + + #[test] + fn null_subject_returns_null() { + let result = run(vec![ + array(vec![Some("a1b"), None, Some("c2d")]), + pattern(r"(\d)"), + idx(1), + ]) + .unwrap(); + assert_eq!( + result, + vec![Some("1".to_string()), None, Some("2".to_string())] + ); + } + + #[test] + fn null_pattern_returns_null() { + let result = run(vec![ + array(vec![Some("abc")]), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + idx(1), + ]) + .unwrap(); + assert_eq!(result, vec![None]); + } + + #[test] + fn unmatched_optional_group_returns_empty_string() { + let result = 
run(vec![ + array(vec![Some("foo")]), + pattern(r"(foo)(bar)?"), + idx(2), + ]) + .unwrap(); + assert_eq!(result, vec![Some(String::new())]); + } + + #[test] + fn group_index_out_of_range_errors() { + let err = spark_regexp_extract(&[ + array(vec![Some("abc")]), + pattern(r"(a)(b)"), + idx(3), + ]) + .err() + .unwrap(); + assert!(err.to_string().contains("group count")); + } + + #[test] + fn negative_index_errors() { + let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)"), idx(-1)]) + .err() + .unwrap(); + assert!(err.to_string().contains("non-negative")); + } + + #[test] + fn invalid_regex_errors() { + let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(unclosed"), idx(0)]) + .err() + .unwrap(); + assert!(err.to_string().contains("Invalid regex")); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 448c2c2cb3..515ca35525 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -170,6 +170,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 968fe8cd69..cde6fedef3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, 
IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpExtract, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -350,6 +350,40 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] { } } +object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { + + override def getIncompatibleReasons(): Seq[String] = Seq( + "Uses Rust regexp engine, which has different behavior to Java regexp engine") + + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + if (!expr.regexp.isInstanceOf[Literal]) { + return Unsupported(Some("Only scalar regexp patterns are supported")) + } + if (!expr.idx.isInstanceOf[Literal]) { + return Unsupported(Some("idx must be an integer literal")) + } + Incompatible( + Some("Uses Rust regexp engine, which has different behavior to Java regexp engine")) + } + + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "regexp_extract", + expr.dataType, + failOnError = true, + subjectExpr, + patternExpr, + idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} + object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { override def getIncompatibleReasons(): Seq[String] = Seq( "Regexp pattern may not be compatible with Spark") 
diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql new file mode 100644 index 0000000000..6c125b27d0 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql @@ -0,0 +1,35 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_extract default behaviour: Comet marks the expression Incompatible +-- (Rust regex engine differs from Java) and should fall back to Spark unless the user +-- opts in via spark.comet.expression.RegExpExtract.allowIncompatible=true. 
+ +statement +CREATE TABLE test_regexp_extract(s string) USING parquet + +statement +INSERT INTO test_regexp_extract VALUES ('100-200'), ('abc'), (''), (NULL), ('phone 123-456-7890') + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 1) FROM test_regexp_extract + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 2) FROM test_regexp_extract + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract(s, '(\\d+)-(\\d+)') FROM test_regexp_extract diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql new file mode 100644 index 0000000000..70a371e132 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql @@ -0,0 +1,73 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_extract() with the per-expression allowIncompatible flag enabled (happy path). 
+-- Config: spark.comet.expression.RegExpExtract.allowIncompatible=true + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_regexp_extract_enabled(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_enabled VALUES + ('100-200'), + ('foo-bar'), + ('nodelim'), + ('12-34-56'), + (''), + (NULL), + ('phone 123-456-7890') + +-- group 1 of the first match +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 1) FROM test_regexp_extract_enabled + +-- group 2 of the first match +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 2) FROM test_regexp_extract_enabled + +-- idx = 0 returns the entire match +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 0) FROM test_regexp_extract_enabled + +-- default idx (no third arg) is 1 +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)') FROM test_regexp_extract_enabled + +-- single-group match; no match should produce empty string, NULL input -> NULL +query +SELECT regexp_extract(s, '(\\d+)', 1) FROM test_regexp_extract_enabled + +-- optional unmatched group should return empty string +query +SELECT regexp_extract(s, '(\\w+)( \\d+)?', 2) FROM test_regexp_extract_enabled + +-- anchors and character classes +query +SELECT regexp_extract(s, '^(\\w+)', 1) FROM test_regexp_extract_enabled + +query +SELECT regexp_extract(s, '(\\d+)$', 1) FROM test_regexp_extract_enabled + +-- literal arguments +query +SELECT + regexp_extract('alice@example.com', '^([\\w.+-]+)@([\\w.-]+)$', 1), + regexp_extract('alice@example.com', '^([\\w.+-]+)@([\\w.-]+)$', 2), + regexp_extract('not-an-email', '^([\\w.+-]+)@([\\w.-]+)$', 1), + regexp_extract(NULL, '(\\d+)', 1) From ba762330e62d86025d485350c97067ab449744b1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 29 Apr 2026 16:13:52 -0600 Subject: [PATCH 2/9] test: cover error and unicode cases for `regexp_extract` Audit follow-ups: - Align Rust error messages with Spark's `INVALID_PARAMETER_VALUE` templates so `expect_error` substrings can match both engines. 
- Override `getUnsupportedReasons` in `CometRegExpExtract` so the non-literal pattern and non-literal idx reasons are picked up by the Compatibility Guide generator. - Add Comet SQL test cases for: NULL pattern and NULL idx, idx=0 with no capture groups, multibyte / Unicode subjects, idx out of range, pattern with no groups + idx>=1, negative idx, invalid regex syntax, and a Java-only lookahead that Rust regex rejects (marked `ignore`). - Add fallback test cases for non-literal pattern and non-literal idx. - Mark the expression supported in `spark_expressions_support.md` with per-version audit notes. --- .../spark_expressions_support.md | 5 +- .../src/string_funcs/regexp_extract.rs | 34 +++++----- .../org/apache/comet/serde/strings.scala | 20 ++++-- .../expressions/string/regexp_extract.sql | 13 ++++ .../string/regexp_extract_enabled.sql | 62 +++++++++++++++++++ 5 files changed, 110 insertions(+), 24 deletions(-) diff --git a/docs/source/contributor-guide/spark_expressions_support.md b/docs/source/contributor-guide/spark_expressions_support.md index 1e4b4e34bc..36b0680582 100644 --- a/docs/source/contributor-guide/spark_expressions_support.md +++ b/docs/source/contributor-guide/spark_expressions_support.md @@ -439,7 +439,10 @@ - [ ] position - [ ] printf - [ ] regexp_count -- [ ] regexp_extract +- [x] regexp_extract + - Spark 3.4.3 audited 2026-04-29 (Incompatible: Rust regex engine differs from Java; `idx` out-of-range check happens at compile time in Comet vs per-row in Spark) + - Spark 3.5.8 audited 2026-04-29 (same as 3.4.3) + - Spark 4.0.1 audited 2026-04-29 (collation support added in Spark; Comet does not honour `UTF8_LCASE` and runs case-sensitively) - [ ] regexp_extract_all - [ ] regexp_instr - [ ] regexp_replace diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 7364ef72a9..7c5546bd78 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ 
b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -56,10 +56,6 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult p.clone(), @@ -73,13 +69,16 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult group_count { + if idx < 0 || idx > group_count { return Err(DataFusionError::Execution(format!( - "Regex group count is {group_count}, but the specified group index is {idx}" + "The value of parameter `idx` in `regexp_extract` is invalid: \ + Expects group index between 0 and {group_count}, but got {idx}." ))); } let group_idx = idx as usize; @@ -273,14 +272,13 @@ mod tests { #[test] fn group_index_out_of_range_errors() { - let err = spark_regexp_extract(&[ - array(vec![Some("abc")]), - pattern(r"(a)(b)"), - idx(3), - ]) - .err() - .unwrap(); - assert!(err.to_string().contains("group count")); + let err = + spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)(b)"), idx(3)]) + .err() + .unwrap(); + let msg = err.to_string(); + assert!(msg.contains("group index"), "{msg}"); + assert!(msg.contains("but got 3"), "{msg}"); } #[test] @@ -288,7 +286,9 @@ mod tests { let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)"), idx(-1)]) .err() .unwrap(); - assert!(err.to_string().contains("non-negative")); + let msg = err.to_string(); + assert!(msg.contains("group index"), "{msg}"); + assert!(msg.contains("but got -1"), "{msg}"); } #[test] @@ -296,6 +296,6 @@ mod tests { let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(unclosed"), idx(0)]) .err() .unwrap(); - assert!(err.to_string().contains("Invalid regex")); + assert!(err.to_string().contains("`regexp`")); } } diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index cde6fedef3..0407e59840 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -352,18 +352,26 @@ object 
CometStringLPad extends CometExpressionSerde[StringLPad] { object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { - override def getIncompatibleReasons(): Seq[String] = Seq( - "Uses Rust regexp engine, which has different behavior to Java regexp engine") + private val incompatReason: String = + "Uses Rust regexp engine, which has different behavior to Java regexp engine" + private val nonLiteralPatternReason: String = + "Only scalar regexp patterns are supported" + private val nonLiteralIdxReason: String = + "idx must be an integer literal" + + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getUnsupportedReasons(): Seq[String] = + Seq(nonLiteralPatternReason, nonLiteralIdxReason) override def getSupportLevel(expr: RegExpExtract): SupportLevel = { if (!expr.regexp.isInstanceOf[Literal]) { - return Unsupported(Some("Only scalar regexp patterns are supported")) + return Unsupported(Some(nonLiteralPatternReason)) } if (!expr.idx.isInstanceOf[Literal]) { - return Unsupported(Some("idx must be an integer literal")) + return Unsupported(Some(nonLiteralIdxReason)) } - Incompatible( - Some("Uses Rust regexp engine, which has different behavior to Java regexp engine")) + Incompatible(Some(incompatReason)) } override def convert( diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql index 6c125b27d0..ef4ac8aa78 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql @@ -33,3 +33,16 @@ SELECT regexp_extract(s, '(\\d+)-(\\d+)', 2) FROM test_regexp_extract query expect_fallback(Rust regexp engine) SELECT regexp_extract(s, '(\\d+)-(\\d+)') FROM test_regexp_extract + +-- Non-literal pattern: Comet falls back regardless of the allowIncompatible flag. 
+statement +CREATE TABLE test_regexp_extract_nonliteral(s string, p string, i int) USING parquet + +statement +INSERT INTO test_regexp_extract_nonliteral VALUES ('abc', '(a)(b)', 1), ('xyz', '(x)', 1) + +query expect_fallback(Only scalar regexp patterns) +SELECT regexp_extract(s, p, 1) FROM test_regexp_extract_nonliteral + +query expect_fallback(idx must be an integer literal) +SELECT regexp_extract(s, '(\\w+)', i) FROM test_regexp_extract_nonliteral diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql index 70a371e132..6e30d312c0 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql @@ -71,3 +71,65 @@ SELECT regexp_extract('alice@example.com', '^([\\w.+-]+)@([\\w.-]+)$', 2), regexp_extract('not-an-email', '^([\\w.+-]+)@([\\w.-]+)$', 1), regexp_extract(NULL, '(\\d+)', 1) + +-- NULL pattern propagates as NULL (Spark and Comet both return NULL) +query +SELECT regexp_extract(s, CAST(NULL AS STRING), 1) FROM test_regexp_extract_enabled + +-- NULL idx propagates as NULL +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)', CAST(NULL AS INT)) FROM test_regexp_extract_enabled + +-- idx = 0 with no capture groups returns the whole match +query +SELECT regexp_extract(s, '\\d+', 0) FROM test_regexp_extract_enabled + +-- multibyte / Unicode subject +statement +CREATE TABLE test_regexp_extract_unicode(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_unicode VALUES + ('café=42'), + ('café=99'), + ('世界=1'), + ('日本=東京'), + ('🔥=hot'), + ('मानक=हिन्दी') + +-- ASCII anchors and capture groups against multibyte data +query +SELECT regexp_extract(s, '^(.+)=(.+)$', 1) FROM test_regexp_extract_unicode + +query +SELECT regexp_extract(s, '^(.+)=(.+)$', 2) FROM test_regexp_extract_unicode + +-- digit class against multibyte 
data +query +SELECT regexp_extract(s, '=(\\d+)$', 1) FROM test_regexp_extract_unicode + +-- ERROR CASES +-- idx > groupCount (pattern has 2 groups, ask for 3) +query expect_error(group index) +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 3) FROM test_regexp_extract_enabled + +-- pattern with no capture groups but idx >= 1 +query expect_error(group index) +SELECT regexp_extract(s, '\\d+', 1) FROM test_regexp_extract_enabled + +-- negative idx +query expect_error(group index) +SELECT regexp_extract(s, '(\\d+)-(\\d+)', -1) FROM test_regexp_extract_enabled + +-- invalid regex syntax (unclosed group): both engines fail at pattern compile time. +-- Spark surfaces INVALID_PARAMETER_VALUE.PATTERN, Comet surfaces a regex parse error. +-- Both messages mention `regexp_extract`. +query expect_error(regexp_extract) +SELECT regexp_extract(s, '(unclosed', 1) FROM test_regexp_extract_enabled + +-- Java-only regex feature: lookahead. Rust regex rejects this at compile time; +-- Spark accepts it and returns "" for every row. This is one of the documented +-- incompatibilities behind the Incompatible support level, not an invariant we +-- test for cross-engine equivalence. +query ignore(Rust regex does not support lookahead, unlike Java regex) +SELECT regexp_extract(s, '(?=\\d)\\w+', 0) FROM test_regexp_extract_enabled From a143cf1e6ab65bfc2dd620ca47929535da1c3a5b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 29 Apr 2026 16:21:28 -0600 Subject: [PATCH 3/9] refactor: simplify `regexp_extract` Rust UDF Address review feedback: - Make `extract_array` build a `GenericStringBuilder` matching the input offset size so a `LargeUtf8` subject no longer silently outputs `Utf8` (avoids potential i32-offset overflow on >2GB inputs). - Inline group extraction so the per-row `String` allocation is gone; the only remaining `to_string` is on the rare scalar code path. - Replace the manual append-null loop in `null_result` with `StringArray::new_null(n)`. 
- Borrow the pattern as `&str` instead of cloning it before calling `Regex::new`. - Pass `failOnError = false` to the proto, matching `CometStringSplit`. The Rust UDF does not branch on this flag, so `true` was misleading. --- .../src/string_funcs/regexp_extract.rs | 42 +++++++++---------- .../org/apache/comet/serde/strings.scala | 2 +- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 7c5546bd78..7ac3473adb 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, GenericStringArray, GenericStringBuilder}; +use arrow::array::{ + Array, ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, +}; use arrow::datatypes::DataType; use datafusion::common::{ cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult, @@ -56,9 +58,9 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult p.clone(), + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p, ColumnarValue::Scalar(ScalarValue::Utf8(None)) | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { return Ok(null_result(subject_len(&args[0]))); @@ -68,7 +70,7 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult DataFusionResult match s { None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - Some(s) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - extract_one(s, ®ex, group_idx), - )))), + Some(s) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(extract_one( + s, ®ex, group_idx, + ))))), }, _ => exec_err!("regexp_extract subject must be a string"), } } -fn extract_array( +fn extract_array( array: &GenericStringArray, regex: &Regex, group_idx: usize, ) -> ArrayRef { - let mut 
builder = GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); + let mut builder = + GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); for i in 0..array.len() { if array.is_null(i) { builder.append_null(); } else { - builder.append_value(extract_one(array.value(i), regex, group_idx)); + let extracted = match regex.captures(array.value(i)) { + Some(caps) => caps.get(group_idx).map(|m| m.as_str()).unwrap_or(""), + None => "", + }; + builder.append_value(extracted); } } Arc::new(builder.finish()) @@ -148,13 +155,7 @@ fn subject_len(value: &ColumnarValue) -> Option { fn null_result(len: Option) -> ColumnarValue { match len { - Some(n) => { - let mut builder = GenericStringBuilder::::with_capacity(n, 0); - for _ in 0..n { - builder.append_null(); - } - ColumnarValue::Array(Arc::new(builder.finish())) - } + Some(n) => ColumnarValue::Array(Arc::new(StringArray::new_null(n))), None => ColumnarValue::Scalar(ScalarValue::Utf8(None)), } } @@ -272,10 +273,9 @@ mod tests { #[test] fn group_index_out_of_range_errors() { - let err = - spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)(b)"), idx(3)]) - .err() - .unwrap(); + let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)(b)"), idx(3)]) + .err() + .unwrap(); let msg = err.to_string(); assert!(msg.contains("group index"), "{msg}"); assert!(msg.contains("but got 3"), "{msg}"); diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 0407e59840..322977378c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -384,7 +384,7 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { val optExpr = scalarFunctionExprToProtoWithReturnType( "regexp_extract", expr.dataType, - failOnError = true, + failOnError = false, subjectExpr, patternExpr, idxExpr) From 
802b4d64e874fbf276046d3f7e0ce352c7b6aa16 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 12 May 2026 09:12:39 -0600 Subject: [PATCH 4/9] test: drop redundant dictionary ConfigMatrix from regexp_extract_enabled The test data has no duplicate rows, so the parquet.enable.dictionary matrix produces two identical runs. --- .../sql-tests/expressions/string/regexp_extract_enabled.sql | 2 -- 1 file changed, 2 deletions(-) diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql index 6e30d312c0..7fd947d7bb 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql @@ -18,8 +18,6 @@ -- Test regexp_extract() with the per-expression allowIncompatible flag enabled (happy path). -- Config: spark.comet.expression.RegExpExtract.allowIncompatible=true --- ConfigMatrix: parquet.enable.dictionary=false,true - statement CREATE TABLE test_regexp_extract_enabled(s string) USING parquet From d24b0a35b257033d80ecc6ff4c81068e3f1fbf77 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 12 May 2026 09:12:48 -0600 Subject: [PATCH 5/9] feat: support Spark expression `regexp_extract_all` Adds a native Rust UDF `spark_regexp_extract_all` and a `CometRegExpExtractAll` serde, paralleling the existing `regexp_extract` support. Returns `List` containing the matched group across every non-overlapping match of the pattern. Reported as Incompatible because the Rust regex engine differs from Java's; gated on `spark.comet.expression.RegExpExtractAll.allowIncompatible=true`. Falls back when the pattern or `idx` is non-literal. 
--- .../spark_expressions_support.md | 2 +- native/spark-expr/src/comet_scalar_funcs.rs | 4 + native/spark-expr/src/string_funcs/mod.rs | 2 + .../src/string_funcs/regexp_extract_all.rs | 389 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../org/apache/comet/serde/strings.scala | 44 +- .../expressions/string/regexp_extract_all.sql | 48 +++ .../string/regexp_extract_all_enabled.sql | 130 ++++++ 8 files changed, 618 insertions(+), 2 deletions(-) create mode 100644 native/spark-expr/src/string_funcs/regexp_extract_all.rs create mode 100644 spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all_enabled.sql diff --git a/docs/source/contributor-guide/spark_expressions_support.md b/docs/source/contributor-guide/spark_expressions_support.md index 36b0680582..19199a3587 100644 --- a/docs/source/contributor-guide/spark_expressions_support.md +++ b/docs/source/contributor-guide/spark_expressions_support.md @@ -443,7 +443,7 @@ - Spark 3.4.3 audited 2026-04-29 (Incompatible: Rust regex engine differs from Java; `idx` out-of-range check happens at compile time in Comet vs per-row in Spark) - Spark 3.5.8 audited 2026-04-29 (same as 3.4.3) - Spark 4.0.1 audited 2026-04-29 (collation support added in Spark; Comet does not honour `UTF8_LCASE` and runs case-sensitively) -- [ ] regexp_extract_all +- [x] regexp_extract_all - [ ] regexp_instr - [ ] regexp_replace - [ ] regexp_substr diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 9a2dd33f97..249d15a966 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -192,6 +192,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(crate::string_funcs::spark_regexp_extract); make_comet_scalar_udf!("regexp_extract", func, without data_type) } + "regexp_extract_all" => { + let func 
= Arc::new(crate::string_funcs::spark_regexp_extract_all); + make_comet_scalar_udf!("regexp_extract_all", func, without data_type) + } "get_json_object" => { let func = Arc::new(crate::string_funcs::spark_get_json_object); make_comet_scalar_udf!("get_json_object", func, without data_type) diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index 6655866bd3..40c73f88c8 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -18,11 +18,13 @@ mod contains; mod get_json_object; mod regexp_extract; +mod regexp_extract_all; mod split; mod substring; pub use contains::SparkContains; pub use get_json_object::spark_get_json_object; pub use regexp_extract::spark_regexp_extract; +pub use regexp_extract_all::spark_regexp_extract_all; pub use split::spark_split; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/regexp_extract_all.rs b/native/spark-expr/src/string_funcs/regexp_extract_all.rs new file mode 100644 index 0000000000..6c5d25d161 --- /dev/null +++ b/native/spark-expr/src/string_funcs/regexp_extract_all.rs @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ + Array, ArrayBuilder, ArrayRef, BooleanBufferBuilder, GenericStringArray, ListArray, + OffsetSizeTrait, StringArray, StringBuilder, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::{DataType, Field}; +use datafusion::common::{ + cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult, + ScalarValue, +}; +use datafusion::logical_expr::ColumnarValue; +use regex::Regex; +use std::sync::Arc; + +/// Spark-compatible `regexp_extract_all(subject, pattern, idx)`. +/// +/// Returns an array of all substrings of `subject` matched by group `idx` across every +/// non-overlapping match of `pattern`. `idx = 0` returns the entire match. An unmatched +/// optional group contributes the empty string. No matches yields an empty array. Returns +/// null when any input is null. Errors when `idx` is out of range for the pattern's group +/// count. +/// +/// Note: this uses the Rust `regex` crate, whose syntax differs from Java's regex engine in +/// some ways. The expression is therefore reported as Incompatible. 
+pub fn spark_regexp_extract_all(args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> { + if args.len() < 2 || args.len() > 3 { + return exec_err!( + "regexp_extract_all expects 2 or 3 arguments (subject, pattern, [idx]), got {}", + args.len() + ); + } + + let idx: i32 = if args.len() == 3 { + match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i, + ColumnarValue::Scalar(ScalarValue::Int32(None)) => { + return Ok(null_result(subject_len(&args[0]))); + } + _ => { + return exec_err!("regexp_extract_all idx must be an Int32 scalar"); + } + } + } else { + 1 + }; + + let pattern: &str = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(p))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p, + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + return Ok(null_result(subject_len(&args[0]))); + } + _ => { + return exec_err!("regexp_extract_all pattern must be a scalar string"); + } + }; + + let regex = Regex::new(pattern).map_err(|e| { + DataFusionError::Execution(format!( + "The value of parameter `regexp` in `regexp_extract_all` is invalid: \ + '{pattern}' ({e})" + )) + })?; + + let group_count = regex.captures_len() as i32 - 1; + if idx < 0 || idx > group_count { + return Err(DataFusionError::Execution(format!( + "The value of parameter `idx` in `regexp_extract_all` is invalid: \ + Expects group index between 0 and {group_count}, but got {idx}." 
+ ))); + } + let group_idx = idx as usize; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let strings = as_generic_string_array::<i32>(array.as_ref())?; + Ok(ColumnarValue::Array(extract_all_array( + strings, &regex, group_idx, + ))) + } + DataType::LargeUtf8 => { + let strings = as_generic_string_array::<i64>(array.as_ref())?; + Ok(ColumnarValue::Array(extract_all_array( + strings, &regex, group_idx, + ))) + } + other => exec_err!( + "regexp_extract_all expects Utf8 or LargeUtf8 subject, got {:?}", + other + ), + }, + ColumnarValue::Scalar(ScalarValue::Utf8(s)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(s)) => match s { + None => Ok(ColumnarValue::Scalar(scalar_null_list())), + Some(s) => { + let matches = extract_one(s, &regex, group_idx); + let values: Arc<dyn Array> = Arc::new(StringArray::from(matches)); + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let offsets = OffsetBuffer::new(vec![0i32, values.len() as i32].into()); + let list = ListArray::new(field, offsets, values, None); + Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new(list)))) + } + }, + _ => exec_err!("regexp_extract_all subject must be a string"), + } +} + +fn extract_all_array<O: OffsetSizeTrait>( + array: &GenericStringArray<O>, + regex: &Regex, + group_idx: usize, +) -> ArrayRef { + let mut values_builder = StringBuilder::new(); + let mut offsets: Vec<i32> = Vec::with_capacity(array.len() + 1); + let mut null_buffer = BooleanBufferBuilder::new(array.len()); + offsets.push(0); + + for i in 0..array.len() { + if array.is_null(i) { + offsets.push(values_builder.len() as i32); + null_buffer.append(false); + } else { + for caps in regex.captures_iter(array.value(i)) { + let s = caps.get(group_idx).map(|m| m.as_str()).unwrap_or(""); + values_builder.append_value(s); + } + offsets.push(values_builder.len() as i32); + null_buffer.append(true); + } + } + + let values = Arc::new(values_builder.finish()) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Utf8, 
true)); + let nulls = NullBuffer::new(null_buffer.finish()); + Arc::new(ListArray::new( + field, + OffsetBuffer::new(offsets.into()), + values, + Some(nulls), + )) +} + +fn extract_one(input: &str, regex: &Regex, group_idx: usize) -> Vec<String> { + regex + .captures_iter(input) + .map(|caps| { + caps.get(group_idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default() + }) + .collect() +} + +fn subject_len(value: &ColumnarValue) -> Option<usize> { + match value { + ColumnarValue::Array(a) => Some(a.len()), + ColumnarValue::Scalar(_) => None, + } +} + +fn null_result(len: Option<usize>) -> ColumnarValue { + match len { + Some(n) => ColumnarValue::Array(null_list_array(n)), + None => ColumnarValue::Scalar(scalar_null_list()), + } +} + +fn null_list_array(len: usize) -> ArrayRef { + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let values = Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef; + let offsets = OffsetBuffer::new(vec![0i32; len + 1].into()); + let nulls = NullBuffer::new_null(len); + Arc::new(ListArray::new(field, offsets, values, Some(nulls))) +} + +fn scalar_null_list() -> ScalarValue { + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let values = Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef; + let offsets = OffsetBuffer::new(vec![0i32, 0].into()); + let nulls = NullBuffer::new_null(1); + ScalarValue::List(Arc::new(ListArray::new( + field, + offsets, + values, + Some(nulls), + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + fn run(args: Vec<ColumnarValue>) -> DataFusionResult<Vec<Option<Vec<String>>>> { + let result = spark_regexp_extract_all(&args)?; + let list = match result { + ColumnarValue::Array(arr) => arr, + ColumnarValue::Scalar(ScalarValue::List(arr)) => arr as ArrayRef, + other => panic!("unexpected result: {other:?}"), + }; + let list = list + .as_any() + .downcast_ref::<ListArray>() + .expect("expected ListArray"); + Ok((0..list.len()) + .map(|i| { + if list.is_null(i) { + None + } else { + let inner = 
list.value(i); + let strs = inner + .as_any() + .downcast_ref::<StringArray>() + .expect("expected inner StringArray"); + Some((0..strs.len()).map(|j| strs.value(j).to_string()).collect()) + } + }) + .collect()) + } + + fn array(values: Vec<Option<&str>>) -> ColumnarValue { + ColumnarValue::Array(Arc::new(StringArray::from(values))) + } + + fn pattern(p: &str) -> ColumnarValue { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(p.to_string()))) + } + + fn idx(i: i32) -> ColumnarValue { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) + } + + #[test] + fn basic_group_extraction() { + let result = run(vec![ + array(vec![ + Some("100-200, 300-400"), + Some("foo-bar"), + Some("nodelim"), + ]), + pattern(r"(\d+)-(\d+)"), + idx(1), + ]) + .unwrap(); + assert_eq!( + result, + vec![ + Some(vec!["100".to_string(), "300".to_string()]), + Some(vec![]), + Some(vec![]), + ] + ); + } + + #[test] + fn second_group() { + let result = run(vec![ + array(vec![Some("100-200, 300-400")]), + pattern(r"(\d+)-(\d+)"), + idx(2), + ]) + .unwrap(); + assert_eq!( + result, + vec![Some(vec!["200".to_string(), "400".to_string()])] + ); + } + + #[test] + fn idx_zero_returns_whole_matches() { + let result = run(vec![ + array(vec![Some("abc123def456")]), + pattern(r"\d+"), + idx(0), + ]) + .unwrap(); + assert_eq!( + result, + vec![Some(vec!["123".to_string(), "456".to_string()])] + ); + } + + #[test] + fn default_idx_is_one() { + let result = run(vec![ + array(vec![Some("100-200, 300-400")]), + pattern(r"(\d+)-(\d+)"), + ]) + .unwrap(); + assert_eq!( + result, + vec![Some(vec!["100".to_string(), "300".to_string()])] + ); + } + + #[test] + fn no_match_returns_empty_array() { + let result = run(vec![array(vec![Some("abc")]), pattern(r"(\d+)"), idx(1)]).unwrap(); + assert_eq!(result, vec![Some(vec![])]); + } + + #[test] + fn null_subject_returns_null() { + let result = run(vec![ + array(vec![Some("1 2 3"), None, Some("4 5")]), + pattern(r"(\d)"), + idx(1), + ]) + .unwrap(); + assert_eq!( + result, + vec![ + 
Some(vec!["1".to_string(), "2".to_string(), "3".to_string()]), + None, + Some(vec!["4".to_string(), "5".to_string()]), + ] + ); + } + + #[test] + fn null_pattern_returns_null() { + let result = run(vec![ + array(vec![Some("abc")]), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + idx(1), + ]) + .unwrap(); + assert_eq!(result, vec![None]); + } + + #[test] + fn unmatched_optional_group_returns_empty_string() { + let result = run(vec![ + array(vec![Some("foo foo")]), + pattern(r"(foo)(bar)?"), + idx(2), + ]) + .unwrap(); + assert_eq!(result, vec![Some(vec![String::new(), String::new()])]); + } + + #[test] + fn group_index_out_of_range_errors() { + let err = spark_regexp_extract_all(&[array(vec![Some("abc")]), pattern(r"(a)(b)"), idx(3)]) + .err() + .unwrap(); + let msg = err.to_string(); + assert!(msg.contains("group index"), "{msg}"); + assert!(msg.contains("but got 3"), "{msg}"); + } + + #[test] + fn negative_index_errors() { + let err = spark_regexp_extract_all(&[array(vec![Some("abc")]), pattern(r"(a)"), idx(-1)]) + .err() + .unwrap(); + let msg = err.to_string(); + assert!(msg.contains("group index"), "{msg}"); + assert!(msg.contains("but got -1"), "{msg}"); + } + + #[test] + fn invalid_regex_errors() { + let err = + spark_regexp_extract_all(&[array(vec![Some("abc")]), pattern(r"(unclosed"), idx(0)]) + .err() + .unwrap(); + assert!(err.to_string().contains("`regexp`")); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 515ca35525..71b691c0b7 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -171,6 +171,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), classOf[RegExpExtract] -> CometRegExpExtract, + classOf[RegExpExtractAll] -> 
CometRegExpExtractAll, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 322977378c..120a961ca5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpExtract, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpExtract, RegExpExtractAll, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -392,6 +392,48 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { } } +object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { + + private val incompatReason: String = + "Uses Rust regexp engine, which has different behavior to Java regexp engine" + private val nonLiteralPatternReason: String = + "Only scalar regexp patterns are supported" + private val nonLiteralIdxReason: String = + "idx must be an integer literal" + + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getUnsupportedReasons(): Seq[String] = + Seq(nonLiteralPatternReason, nonLiteralIdxReason) + + override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { + if (!expr.regexp.isInstanceOf[Literal]) { + return 
Unsupported(Some(nonLiteralPatternReason)) + } + if (!expr.idx.isInstanceOf[Literal]) { + return Unsupported(Some(nonLiteralIdxReason)) + } + Incompatible(Some(incompatReason)) + } + + override def convert( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "regexp_extract_all", + expr.dataType, + failOnError = false, + subjectExpr, + patternExpr, + idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} + object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { override def getIncompatibleReasons(): Seq[String] = Seq( "Regexp pattern may not be compatible with Spark") diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql new file mode 100644 index 0000000000..b212990053 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql @@ -0,0 +1,48 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. 
See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_extract_all default behaviour: Comet marks the expression Incompatible +-- (Rust regex engine differs from Java) and should fall back to Spark unless the user +-- opts in via spark.comet.expression.RegExpExtractAll.allowIncompatible=true. + +statement +CREATE TABLE test_regexp_extract_all(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_all VALUES ('100-200, 300-400'), ('abc'), (''), (NULL), ('phone 123-456-7890') + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)', 1) FROM test_regexp_extract_all + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)', 2) FROM test_regexp_extract_all + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)') FROM test_regexp_extract_all + +-- Non-literal pattern: Comet falls back regardless of the allowIncompatible flag. +statement +CREATE TABLE test_regexp_extract_all_nonliteral(s string, p string, i int) USING parquet + +statement +INSERT INTO test_regexp_extract_all_nonliteral VALUES ('abc', '(a)(b)', 1), ('xyz', '(x)', 1) + +query expect_fallback(Only scalar regexp patterns) +SELECT regexp_extract_all(s, p, 1) FROM test_regexp_extract_all_nonliteral + +query expect_fallback(idx must be an integer literal) +SELECT regexp_extract_all(s, '(\\w+)', i) FROM test_regexp_extract_all_nonliteral diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all_enabled.sql new file mode 100644 index 0000000000..3c46cb2389 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all_enabled.sql @@ -0,0 +1,130 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. 
See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_extract_all() with the per-expression allowIncompatible flag enabled (happy path). +-- Config: spark.comet.expression.RegExpExtractAll.allowIncompatible=true + +statement +CREATE TABLE test_regexp_extract_all_enabled(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_all_enabled VALUES + ('100-200, 300-400'), + ('foo-bar'), + ('nodelim'), + ('12-34-56'), + (''), + (NULL), + ('phone 123-456-7890') + +-- group 1 across every match +query +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)', 1) FROM test_regexp_extract_all_enabled + +-- group 2 across every match +query +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)', 2) FROM test_regexp_extract_all_enabled + +-- idx = 0 returns every entire match +query +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)', 0) FROM test_regexp_extract_all_enabled + +-- default idx (no third arg) is 1 +query +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)') FROM test_regexp_extract_all_enabled + +-- single-group match; no match should produce empty array, NULL input -> NULL +query +SELECT regexp_extract_all(s, '(\\d+)', 1) FROM test_regexp_extract_all_enabled + +-- optional unmatched group should contribute the empty string +query +SELECT regexp_extract_all(s, '(\\w+)( \\d+)?', 2) FROM test_regexp_extract_all_enabled + +-- 
anchors and character classes +query +SELECT regexp_extract_all(s, '^(\\w+)', 1) FROM test_regexp_extract_all_enabled + +query +SELECT regexp_extract_all(s, '(\\d+)$', 1) FROM test_regexp_extract_all_enabled + +-- literal arguments +query +SELECT + regexp_extract_all('alice@example.com, bob@example.org', '([\\w.+-]+)@([\\w.-]+)', 1), + regexp_extract_all('alice@example.com, bob@example.org', '([\\w.+-]+)@([\\w.-]+)', 2), + regexp_extract_all('not-an-email', '([\\w.+-]+)@([\\w.-]+)', 1), + regexp_extract_all(NULL, '(\\d+)', 1) + +-- NULL pattern propagates as NULL (Spark and Comet both return NULL) +query +SELECT regexp_extract_all(s, CAST(NULL AS STRING), 1) FROM test_regexp_extract_all_enabled + +-- NULL idx propagates as NULL +query +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)', CAST(NULL AS INT)) FROM test_regexp_extract_all_enabled + +-- idx = 0 with no capture groups returns every whole match +query +SELECT regexp_extract_all(s, '\\d+', 0) FROM test_regexp_extract_all_enabled + +-- multibyte / Unicode subject +statement +CREATE TABLE test_regexp_extract_all_unicode(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_all_unicode VALUES + ('café=42, hot=99'), + ('世界=1, 東京=2'), + ('🔥=hot, ❄=cold'), + ('मानक=हिन्दी') + +-- ASCII anchors and capture groups against multibyte data +query +SELECT regexp_extract_all(s, '(\\S+)=(\\S+)', 1) FROM test_regexp_extract_all_unicode + +query +SELECT regexp_extract_all(s, '(\\S+)=(\\S+)', 2) FROM test_regexp_extract_all_unicode + +-- digit class against multibyte data +query +SELECT regexp_extract_all(s, '=(\\d+)', 1) FROM test_regexp_extract_all_unicode + +-- ERROR CASES +-- idx > groupCount (pattern has 2 groups, ask for 3) +query expect_error(group index) +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)', 3) FROM test_regexp_extract_all_enabled + +-- pattern with no capture groups but idx >= 1 +query expect_error(group index) +SELECT regexp_extract_all(s, '\\d+', 1) FROM test_regexp_extract_all_enabled + +-- 
negative idx +query expect_error(group index) +SELECT regexp_extract_all(s, '(\\d+)-(\\d+)', -1) FROM test_regexp_extract_all_enabled + +-- invalid regex syntax (unclosed group): both engines fail at pattern compile time. +-- Spark surfaces INVALID_PARAMETER_VALUE.PATTERN, Comet surfaces a regex parse error. +-- Both messages mention `regexp_extract_all`. +query expect_error(regexp_extract_all) +SELECT regexp_extract_all(s, '(unclosed', 1) FROM test_regexp_extract_all_enabled + +-- Java-only regex feature: lookahead. Rust regex rejects this at compile time; +-- Spark accepts it. This is one of the documented incompatibilities behind the +-- Incompatible support level, not an invariant we test for cross-engine equivalence. +query ignore(Rust regex does not support lookahead, unlike Java regex) +SELECT regexp_extract_all(s, '(?=\\d)\\w+', 0) FROM test_regexp_extract_all_enabled From c20ec7c89ec520cba22f59a5bc3e9c5118ee626b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Jun 2026 08:45:13 -0600 Subject: [PATCH 6/9] chore: add microbenchmark for regexp_extract / regexp_extract_all Mirrors CometRegExpBenchmark and exercises four execution modes per pattern: Spark, Comet (Scan only), Comet (Exec, native Rust regex via allowIncompatible), and Comet (Exec, JVM regex via codegen dispatcher). Patterns avoid Java-only constructs so the native path is actually benchmarked rather than falling through to the dispatcher. 
--- .../CometRegExpExtractBenchmark.scala | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala new file mode 100644 index 0000000000..538ea5bd42 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.benchmark.Benchmark + +import org.apache.comet.CometConf + +/** + * Configuration for a single regexp_extract pattern under benchmark. + * + * @param name + * short label for the pattern + * @param pattern + * the regex literal supplied to regexp_extract / regexp_extract_all + * @param idx + * capture-group index (0 for the whole match, 1 for the first group, ...) 
+ */ +case class RegExpExtractPattern(name: String, pattern: String, idx: Int) + +/** + * Benchmark `regexp_extract` and `regexp_extract_all` across all execution modes: + * - Spark + * - Comet (Scan only) + * - Comet (Scan + Exec, native Rust regex) + * - Comet (Scan + Exec, JVM-side java.util.regex via codegen dispatcher) + * + * To run: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 \ + * make benchmark-org.apache.spark.sql.benchmark.CometRegExpExtractBenchmark + * }}} + * + * Results land in `spark/benchmarks/CometRegExpExtractBenchmark-**results.txt`. + */ +object CometRegExpExtractBenchmark extends CometBenchmarkBase { + + // Patterns chosen to span common shapes that both engines accept. Avoid Java-only constructs + // (backreferences, lookaround, possessive quantifiers, embedded flags) so the native (Rust) + // path is actually exercised rather than falling through to the codegen dispatcher. + private val patterns = List( + RegExpExtractPattern("single_group", "([0-9]+)", 1), + RegExpExtractPattern("two_groups_first", "([a-z]+)([0-9]+)", 1), + RegExpExtractPattern("two_groups_second", "([a-z]+)([0-9]+)", 2), + RegExpExtractPattern("whole_match", "[a-z]+[0-9]+", 0), + RegExpExtractPattern("anchored", "^([0-9]+)", 1), + RegExpExtractPattern("alternation", "(abc|def|ghi)", 1)) + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + runBenchmarkWithTable("regexp_extract modes", 1024) { v => + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl")) + + patterns.foreach { p => + val extractQuery = + s"select regexp_extract(c1, '${p.pattern}', ${p.idx}) from parquetV1Table" + runBenchmark(s"regexp_extract / ${p.name}") { + runModes("RegExpExtract", p.name, v, extractQuery) + } + + val extractAllQuery = + s"select regexp_extract_all(c1, '${p.pattern}', ${p.idx}) from parquetV1Table" + runBenchmark(s"regexp_extract_all / ${p.name}") { + 
runModes("RegExpExtractAll", p.name, v, extractAllQuery) + } + } + } + } + } + } + + /** Runs all four modes for a single regexp_extract / regexp_extract_all query. */ + private def runModes( + exprClassName: String, + name: String, + cardinality: Long, + query: String): Unit = { + val benchmark = new Benchmark(name, cardinality, output = output) + + benchmark.addCase("Spark") { _ => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark.sql(query).noop() + } + } + + benchmark.addCase("Comet (Scan)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark.sql(query).noop() + } + } + + val baseExec = Map( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + "spark.sql.optimizer.constantFolding.enabled" -> "false") + + benchmark.addCase("Comet (Exec, native Rust regex)") { _ => + val configs = + baseExec ++ Map(CometConf.getExprAllowIncompatConfigKey(exprClassName) -> "true") + withSQLConf(configs.toSeq: _*) { + spark.sql(query).noop() + } + } + + benchmark.addCase("Comet (Exec, JVM regex)") { _ => + // The codegen dispatcher is enabled by default, so no extra config is needed. + withSQLConf(baseExec.toSeq: _*) { + spark.sql(query).noop() + } + } + + benchmark.run() + } +} From 1de8c0ffc2af7104251f3f313aff44e65b90ceb2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Jun 2026 10:39:01 -0600 Subject: [PATCH 7/9] chore: bump regexp_extract benchmark cardinality to 1M, fix labels and data - Increase rows from 1024 to 1024 * 1024 so each case runs long enough that per-batch overhead doesn't dominate. - Pass the prefixed name (e.g. "regexp_extract / single_group") to the inner Benchmark so the output table identifies which expression was tested. - Switch the synthetic subject from REPEAT(digits, 10) to CONCAT('abc', long_value, 'def') so every pattern in the matrix actually finds a match instead of failing fast on no-match. 
--- .../CometRegExpExtractBenchmark.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala index 538ea5bd42..e00063d82d 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala @@ -64,24 +64,31 @@ object CometRegExpExtractBenchmark extends CometBenchmarkBase { RegExpExtractPattern("alternation", "(abc|def|ghi)", 1)) override def runCometBenchmark(mainArgs: Array[String]): Unit = { - runBenchmarkWithTable("regexp_extract modes", 1024) { v => + runBenchmarkWithTable("regexp_extract modes", 1024 * 1024) { v => withTempPath { dir => withTempTable("parquetV1Table") { + // Build a realistic alphanumeric subject so the patterns actually match. The + // underlying `tbl` view holds Long values (some negative, some positive); we + // sandwich them between two short letter runs so every pattern in `patterns` + // (single digit group, `[a-z]+[0-9]+` whole match, alternation against `abc`, + // etc.) finds a match. 
prepareTable( dir, - spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl")) + spark.sql(s"SELECT CONCAT('abc', CAST(ABS(value) AS STRING), 'def') AS c1 FROM $tbl")) patterns.foreach { p => + val extractName = s"regexp_extract / ${p.name}" val extractQuery = s"select regexp_extract(c1, '${p.pattern}', ${p.idx}) from parquetV1Table" - runBenchmark(s"regexp_extract / ${p.name}") { - runModes("RegExpExtract", p.name, v, extractQuery) + runBenchmark(extractName) { + runModes("RegExpExtract", extractName, v, extractQuery) } + val extractAllName = s"regexp_extract_all / ${p.name}" val extractAllQuery = s"select regexp_extract_all(c1, '${p.pattern}', ${p.idx}) from parquetV1Table" - runBenchmark(s"regexp_extract_all / ${p.name}") { - runModes("RegExpExtractAll", p.name, v, extractAllQuery) + runBenchmark(extractAllName) { + runModes("RegExpExtractAll", extractAllName, v, extractAllQuery) } } } From ebd011c90e0d2fa05d04dc7999a3352cb1212836 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Jun 2026 11:11:18 -0600 Subject: [PATCH 8/9] chore: force extension registration in regexp_extract benchmark CometBenchmarkBase wires CometSparkSessionExtensions via `withExtensions`, but that call is silently dropped when `SparkSession.builder.getOrCreate()` returns an existing session, so the benchmark was running plain Spark in all four "modes" -- the EXPLAIN plan was just `Project + ColumnarToRow + FileScan parquet` with no CometScan or CometProject. Override `getSparkSession` to set `spark.sql.extensions` on the SparkConf (plus the off-heap and shuffle-manager configs CometTestBase uses) so Comet planning rules actually fire. The native Rust mode now shows up to 2.5x over Spark on patterns with many matches (e.g. regexp_extract_all / alternation), and 1.2-1.3x on the simpler shapes. 
--- .../CometRegExpExtractBenchmark.scala | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala index e00063d82d..1d04ae49df 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpExtractBenchmark.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.benchmark +import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf import org.apache.comet.CometConf @@ -52,6 +55,34 @@ case class RegExpExtractPattern(name: String, pattern: String, idx: Int) */ object CometRegExpExtractBenchmark extends CometBenchmarkBase { + // CometBenchmarkBase wires `CometSparkSessionExtensions` via `withExtensions`, but that call + // is silently dropped when `SparkSession.builder.getOrCreate()` returns an existing session + // (the `SqlBasedBenchmark.spark` field can construct one before the override runs). Setting + // `spark.sql.extensions` on the SparkConf forces extension registration regardless. The + // off-heap and shuffle-manager configs match what CometTestBase sets so Comet's planning + // rules don't bail out early. 
+ override def getSparkSession: SparkSession = { + val conf = new SparkConf() + .setAppName("CometRegExpExtractBenchmark") + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "2g") + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + + val sparkSession = SparkSession.builder.config(conf).getOrCreate() + sparkSession.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(CometConf.COMET_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") + sparkSession.conf.set(SQLConf.ANSI_ENABLED.key, "false") + sparkSession + } + // Patterns chosen to span common shapes that both engines accept. Avoid Java-only constructs // (backreferences, lookaround, possessive quantifiers, embedded flags) so the native (Rust) // path is actually exercised rather than falling through to the codegen dispatcher. From 367e07cd31713dc69af99f187a07d53e31599f86 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Jun 2026 11:23:04 -0600 Subject: [PATCH 9/9] fix: regexp_extract always return Utf8, dedup arg parsing Two issues raised in PR review: 1. extract_array uses `GenericStringBuilder`, so a LargeUtf8 input produces a LargeStringArray. The Comet serde declares the protobuf return type as `expr.dataType` (StringType / Utf8) for both regexp_extract and regexp_extract_all, so handing back i64-offset data would be a type mismatch. Always build a `StringArray` (i32 offsets) -- `&str` slices are width-agnostic so the copy is safe. Same fix for regexp_extract_all's inner value array. 2. 
The two UDFs duplicated arg-count check, idx parsing, pattern parsing, regex compile, and group-index validation verbatim. Lift that into a shared `parse_args` helper in `regexp_extract_common.rs` returning `ParsedArgs::Parsed { regex, group_idx, subject }` or `ParsedArgs::NullResult { len }`; each UDF then translates the null short-circuit into its own per-function shape (Utf8 null vs ListArray of nulls). Adds two new unit tests, one per UDF, locking down that LargeUtf8 subjects still produce a 32-bit-offset StringArray. --- native/spark-expr/src/string_funcs/mod.rs | 1 + .../src/string_funcs/regexp_extract.rs | 115 ++++++++---------- .../src/string_funcs/regexp_extract_all.rs | 100 +++++++-------- .../src/string_funcs/regexp_extract_common.rs | 112 +++++++++++++++++ 4 files changed, 206 insertions(+), 122 deletions(-) create mode 100644 native/spark-expr/src/string_funcs/regexp_extract_common.rs diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index 40c73f88c8..ce1b8009a8 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -19,6 +19,7 @@ mod contains; mod get_json_object; mod regexp_extract; mod regexp_extract_all; +mod regexp_extract_common; mod split; mod substring; diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 7ac3473adb..156241d6a1 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -16,17 +16,18 @@ // under the License. 
use arrow::array::{ - Array, ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, + Array, ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray, StringBuilder, }; use arrow::datatypes::DataType; use datafusion::common::{ - cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult, - ScalarValue, + cast::as_generic_string_array, exec_err, Result as DataFusionResult, ScalarValue, }; use datafusion::logical_expr::ColumnarValue; use regex::Regex; use std::sync::Arc; +use super::regexp_extract_common::{parse_args, ParsedArgs}; + /// Spark-compatible `regexp_extract(subject, pattern, idx)`. /// /// Returns the substring of `subject` matched by group `idx` of the first match of `pattern`. @@ -37,65 +38,26 @@ use std::sync::Arc; /// Note: this uses the Rust `regex` crate, whose syntax differs from Java's regex engine in /// some ways. The expression is therefore reported as Incompatible. pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> { - if args.len() < 2 || args.len() > 3 { - return exec_err!( - "regexp_extract expects 2 or 3 arguments (subject, pattern, [idx]), got {}", - args.len() - ); - } - - let idx: i32 = if args.len() == 3 { - match &args[2] { - ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i, - ColumnarValue::Scalar(ScalarValue::Int32(None)) => { - return Ok(null_result(subject_len(&args[0]))); - } - _ => { - return exec_err!("regexp_extract idx must be an Int32 scalar"); - } - } - } else { - 1 + let (regex, group_idx, subject) = match parse_args("regexp_extract", args)? 
{ + ParsedArgs::Parsed { + regex, + group_idx, + subject, + } => (regex, group_idx, subject), + ParsedArgs::NullResult { len } => return Ok(null_result(len)), }; - let pattern: &str = match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(p))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { - return Ok(null_result(subject_len(&args[0]))); - } - _ => { - return exec_err!("regexp_extract pattern must be a scalar string"); - } - }; - - let regex = Regex::new(pattern).map_err(|e| { - DataFusionError::Execution(format!( - "The value of parameter `regexp` in `regexp_extract` is invalid: '{pattern}' ({e})" - )) - })?; - - let group_count = regex.captures_len() as i32 - 1; - if idx < 0 || idx > group_count { - return Err(DataFusionError::Execution(format!( - "The value of parameter `idx` in `regexp_extract` is invalid: \ - Expects group index between 0 and {group_count}, but got {idx}." 
- ))); - } - let group_idx = idx as usize; - - match &args[0] { + match subject { ColumnarValue::Array(array) => match array.data_type() { DataType::Utf8 => { let strings = as_generic_string_array::(array.as_ref())?; - Ok(ColumnarValue::Array(extract_array::( + Ok(ColumnarValue::Array(extract_array( strings, ®ex, group_idx, ))) } DataType::LargeUtf8 => { let strings = as_generic_string_array::(array.as_ref())?; - Ok(ColumnarValue::Array(extract_array::( + Ok(ColumnarValue::Array(extract_array( strings, ®ex, group_idx, ))) } @@ -115,13 +77,16 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult( array: &GenericStringArray, regex: &Regex, group_idx: usize, ) -> ArrayRef { - let mut builder = - GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); + let mut builder = StringBuilder::with_capacity(array.len(), array.value_data().len()); for i in 0..array.len() { if array.is_null(i) { builder.append_null(); @@ -146,13 +111,6 @@ fn extract_one(input: &str, regex: &Regex, group_idx: usize) -> String { } } -fn subject_len(value: &ColumnarValue) -> Option { - match value { - ColumnarValue::Array(a) => Some(a.len()), - ColumnarValue::Scalar(_) => None, - } -} - fn null_result(len: Option) -> ColumnarValue { match len { Some(n) => ColumnarValue::Array(Arc::new(StringArray::new_null(n))), @@ -163,7 +121,8 @@ fn null_result(len: Option) -> ColumnarValue { #[cfg(test)] mod tests { use super::*; - use arrow::array::StringArray; + use arrow::array::{LargeStringArray, StringArray}; + use datafusion::common::DataFusionError; fn run(args: Vec) -> DataFusionResult>> { let result = spark_regexp_extract(&args)?; @@ -171,8 +130,8 @@ mod tests { ColumnarValue::Array(arr) => { let s = arr .as_any() - .downcast_ref::>() - .expect("expected Utf8 array"); + .downcast_ref::() + .expect("expected Utf8 array (regexp_extract must always return StringArray)"); Ok((0..s.len()) .map(|i| { if s.is_null(i) { @@ -298,4 +257,32 @@ mod tests { .unwrap(); 
assert!(err.to_string().contains("`regexp`")); } + + /// `LargeUtf8` subject must still produce a `StringArray` (i32 offsets) so the result type + /// matches Spark's `RegExpExtract.dataType` = `StringType`. Regression for the bug where + /// `extract_array::` used to build a `LargeStringArray` and trip a type mismatch. + #[test] + fn large_utf8_subject_returns_utf8_array() { + let array = ColumnarValue::Array(Arc::new(LargeStringArray::from(vec![ + Some("100-200"), + None, + Some("foo-bar"), + ]))); + let result = spark_regexp_extract(&[array, pattern(r"(\d+)-(\d+)"), idx(1)]).unwrap(); + match result { + ColumnarValue::Array(arr) => { + arr.as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "expected StringArray, got {:?}", + arr.data_type() + )) + }) + .unwrap(); + assert_eq!(arr.len(), 3); + } + other => panic!("unexpected result: {other:?}"), + } + } } diff --git a/native/spark-expr/src/string_funcs/regexp_extract_all.rs b/native/spark-expr/src/string_funcs/regexp_extract_all.rs index 6c5d25d161..adf56820b8 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract_all.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract_all.rs @@ -22,13 +22,14 @@ use arrow::array::{ use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::datatypes::{DataType, Field}; use datafusion::common::{ - cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult, - ScalarValue, + cast::as_generic_string_array, exec_err, Result as DataFusionResult, ScalarValue, }; use datafusion::logical_expr::ColumnarValue; use regex::Regex; use std::sync::Arc; +use super::regexp_extract_common::{parse_args, ParsedArgs}; + /// Spark-compatible `regexp_extract_all(subject, pattern, idx)`. /// /// Returns an array of all substrings of `subject` matched by group `idx` across every @@ -40,56 +41,16 @@ use std::sync::Arc; /// Note: this uses the Rust `regex` crate, whose syntax differs from Java's regex engine in /// some ways. 
The expression is therefore reported as Incompatible. pub fn spark_regexp_extract_all(args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> { - if args.len() < 2 || args.len() > 3 { - return exec_err!( - "regexp_extract_all expects 2 or 3 arguments (subject, pattern, [idx]), got {}", - args.len() - ); - } - - let idx: i32 = if args.len() == 3 { - match &args[2] { - ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i, - ColumnarValue::Scalar(ScalarValue::Int32(None)) => { - return Ok(null_result(subject_len(&args[0]))); - } - _ => { - return exec_err!("regexp_extract_all idx must be an Int32 scalar"); - } - } - } else { - 1 + let (regex, group_idx, subject) = match parse_args("regexp_extract_all", args)? { + ParsedArgs::Parsed { + regex, + group_idx, + subject, + } => (regex, group_idx, subject), + ParsedArgs::NullResult { len } => return Ok(null_result(len)), }; - let pattern: &str = match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(p))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { - return Ok(null_result(subject_len(&args[0]))); - } - _ => { - return exec_err!("regexp_extract_all pattern must be a scalar string"); - } - }; - - let regex = Regex::new(pattern).map_err(|e| { - DataFusionError::Execution(format!( - "The value of parameter `regexp` in `regexp_extract_all` is invalid: \ - '{pattern}' ({e})" - )) - })?; - - let group_count = regex.captures_len() as i32 - 1; - if idx < 0 || idx > group_count { - return Err(DataFusionError::Execution(format!( - "The value of parameter `idx` in `regexp_extract_all` is invalid: \ - Expects group index between 0 and {group_count}, but got {idx}." 
- ))); - } - let group_idx = idx as usize; - - match &args[0] { + match subject { ColumnarValue::Array(array) => match array.data_type() { DataType::Utf8 => { let strings = as_generic_string_array::(array.as_ref())?; @@ -124,6 +85,9 @@ pub fn spark_regexp_extract_all(args: &[ColumnarValue]) -> DataFusionResult( array: &GenericStringArray, regex: &Regex, @@ -170,13 +134,6 @@ fn extract_one(input: &str, regex: &Regex, group_idx: usize) -> Vec { .collect() } -fn subject_len(value: &ColumnarValue) -> Option { - match value { - ColumnarValue::Array(a) => Some(a.len()), - ColumnarValue::Scalar(_) => None, - } -} - fn null_result(len: Option) -> ColumnarValue { match len { Some(n) => ColumnarValue::Array(null_list_array(n)), @@ -208,7 +165,7 @@ fn scalar_null_list() -> ScalarValue { #[cfg(test)] mod tests { use super::*; - use arrow::array::StringArray; + use arrow::array::{LargeStringArray, StringArray}; fn run(args: Vec) -> DataFusionResult>>> { let result = spark_regexp_extract_all(&args)?; @@ -386,4 +343,31 @@ mod tests { .unwrap(); assert!(err.to_string().contains("`regexp`")); } + + /// Regression: `LargeUtf8` subject must still produce a `ListArray` whose inner values + /// are a `StringArray` (i32 offsets), matching Spark's `RegExpExtractAll.dataType` = + /// `ArrayType(StringType)`. 
+ #[test] + fn large_utf8_subject_returns_inner_utf8() { + let array = ColumnarValue::Array(Arc::new(LargeStringArray::from(vec![ + Some("1 2 3"), + None, + Some("4 5"), + ]))); + let result = spark_regexp_extract_all(&[array, pattern(r"(\d)"), idx(1)]).unwrap(); + let list = match result { + ColumnarValue::Array(arr) => arr, + other => panic!("unexpected result: {other:?}"), + }; + let list = list + .as_any() + .downcast_ref::<ListArray>() + .expect("expected ListArray"); + assert_eq!(list.len(), 3); + // Inner values must be StringArray, not LargeStringArray + list.values() + .as_any() + .downcast_ref::<StringArray>() + .expect("inner values must be StringArray"); + } } diff --git a/native/spark-expr/src/string_funcs/regexp_extract_common.rs b/native/spark-expr/src/string_funcs/regexp_extract_common.rs new file mode 100644 index 0000000000..5c02a46b5f --- /dev/null +++ b/native/spark-expr/src/string_funcs/regexp_extract_common.rs @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +//! Shared helpers for `regexp_extract` and `regexp_extract_all`. Both UDFs accept +//! `(subject, pattern, [idx])` and need to validate the same things in the same way; this +//! 
module centralizes that parsing and the null short-circuit so each UDF is left with +//! its own per-row loop. + +use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; +use regex::Regex; + +/// Result of parsing the `(subject, pattern, [idx])` arguments shared by `regexp_extract` +/// and `regexp_extract_all`. +pub(super) enum ParsedArgs<'a> { + /// Arguments validated; the caller can run its per-row extraction loop. + Parsed { + regex: Regex, + group_idx: usize, + subject: &'a ColumnarValue, + }, + /// A scalar pattern or idx was null; the result is null for every input row. `len` is + /// `Some(n)` when the subject is an array (one null per row) and `None` for a scalar + /// subject. + NullResult { len: Option<usize> }, +} + +/// Validate `(subject, pattern, [idx])` and return either a compiled `Regex` and group index +/// ready for extraction, or a request to short-circuit to a null result. +pub(super) fn parse_args<'a>( + fn_name: &'static str, + args: &'a [ColumnarValue], +) -> DataFusionResult<ParsedArgs<'a>> { + if args.len() < 2 || args.len() > 3 { + return exec_err!( + "{fn_name} expects 2 or 3 arguments (subject, pattern, [idx]), got {}", + args.len() + ); + } + + let idx: i32 = if args.len() == 3 { + match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i, + ColumnarValue::Scalar(ScalarValue::Int32(None)) => { + return Ok(ParsedArgs::NullResult { + len: subject_len(&args[0]), + }); + } + _ => { + return exec_err!("{fn_name} idx must be an Int32 scalar"); + } + } + } else { + 1 + }; + + let pattern: &str = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(p))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p, + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + return Ok(ParsedArgs::NullResult { + len: subject_len(&args[0]), + }); + } + _ => { + 
return exec_err!("{fn_name} pattern must 
be a scalar string"); + } + }; + + let regex = Regex::new(pattern).map_err(|e| { + DataFusionError::Execution(format!( + "The value of parameter `regexp` in `{fn_name}` is invalid: '{pattern}' ({e})" + )) + })?; + + let group_count = regex.captures_len() as i32 - 1; + if idx < 0 || idx > group_count { + return Err(DataFusionError::Execution(format!( + "The value of parameter `idx` in `{fn_name}` is invalid: \ + Expects group index between 0 and {group_count}, but got {idx}." + ))); + } + + Ok(ParsedArgs::Parsed { + regex, + group_idx: idx as usize, + subject: &args[0], + }) +} + +/// Returns `Some(row_count)` for an array subject and `None` for a scalar. +pub(super) fn subject_len(value: &ColumnarValue) -> Option<usize> { + match value { + ColumnarValue::Array(a) => Some(a.len()), + ColumnarValue::Scalar(_) => None, + } +}