Skip to content

Commit fd56aa5

Browse files
committed
Fix regex replacement parser Java spec gaps
Signed-off-by: Allen Xu <allxu@nvidia.com>
1 parent 1e10069 commit fd56aa5

5 files changed

Lines changed: 189 additions & 55 deletions

File tree

integration_tests/src/main/python/regexp_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,59 @@ def test_regexp_replace():
449449
'regexp_replace(a, "a|b|c", "A")'),
450450
conf=_regexp_conf)
451451

452+
453+
# https://github.com/NVIDIA/spark-rapids/issues/14742
454+
# Replacement-string parser must match java.util.regex.Matcher#appendReplacement.
455+
# Use the DataFrame API rather than selectExpr because Spark SQL variable substitution
456+
# expands ${...} inside SQL string literals before regexp_replace sees it.
457+
def test_regexp_replace_subbug1_backslash_digit_is_literal_14742():
458+
from pyspark.sql.functions import regexp_replace, col
459+
assert_gpu_and_cpu_are_equal_collect(
460+
lambda spark: spark.createDataFrame([("abc",)], ["a"]).select(
461+
regexp_replace(col("a"), "(a)", "\\1")),
462+
conf=_regexp_conf)
463+
464+
465+
@allow_non_gpu('ProjectExec', 'RegExpReplace')
466+
def test_regexp_replace_subbug2_trailing_backslash_throws_14742():
467+
from pyspark.sql.functions import regexp_replace, col
468+
assert_gpu_and_cpu_error(
469+
lambda spark: spark.createDataFrame([("a",)], ["a"]).select(
470+
regexp_replace(col("a"), "a", "\\")).collect(),
471+
conf=_regexp_conf,
472+
error_message="character to be escaped is missing")
473+
474+
475+
@allow_non_gpu('ProjectExec', 'RegExpReplace')
476+
def test_regexp_replace_subbug3_dollar_non_digit_throws_14742():
477+
from pyspark.sql.functions import regexp_replace, col
478+
assert_gpu_and_cpu_error(
479+
lambda spark: spark.createDataFrame([("a",)], ["a"]).select(
480+
regexp_replace(col("a"), "a", "$x")).collect(),
481+
conf=_regexp_conf,
482+
error_message="Illegal group reference")
483+
484+
485+
@allow_non_gpu('ProjectExec', 'RegExpReplace')
486+
def test_regexp_replace_subbug4_digit_leading_named_group_throws_14742():
487+
from pyspark.sql.functions import regexp_replace, col
488+
assert_gpu_and_cpu_error(
489+
lambda spark: spark.createDataFrame([("a",)], ["a"]).select(
490+
regexp_replace(col("a"), "(a)", "${1}")).collect(),
491+
conf=_regexp_conf,
492+
error_message="capturing group name {1} starts with digit character")
493+
494+
495+
@allow_non_gpu('ProjectExec', 'RegExpReplace')
496+
def test_regexp_replace_subbug5_unknown_named_group_throws_14742():
497+
from pyspark.sql.functions import regexp_replace, col
498+
assert_gpu_and_cpu_error(
499+
lambda spark: spark.createDataFrame([("a",)], ["a"]).select(
500+
regexp_replace(col("a"), "(a)", "${name}")).collect(),
501+
conf=_regexp_conf,
502+
error_message="No group with name")
503+
504+
452505
@pytest.mark.skipif(is_before_spark_320(), reason='regexp is synonym for RLike starting in Spark 3.2.0')
453506
def test_regexp():
454507
gen = mk_str_gen('[abcd]{1,3}')

sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -364,56 +364,57 @@ class RegexParser(pattern: String) {
364364
}
365365

366366
private def parseBackrefOrEscaped(): RegexAST = {
367-
val start = pos
368-
369-
consumeInt match {
370-
case Some(refNum) =>
371-
RegexBackref(refNum)
367+
peek() match {
372368
case None =>
373-
pos = start
374-
RegexChar('\\')
369+
throw new RegexUnsupportedException(
370+
"character to be escaped is missing", Some(pos))
371+
case Some(_) =>
372+
RegexSequence(ListBuffer(RegexChar('\\'), RegexChar(consume())))
375373
}
376374
}
377375

378376
private def parseBackrefOrLiteralDollar(): RegexAST = {
379-
val start = pos
380-
381-
def treatAsLiteralDollar() = {
382-
pos = start
383-
RegexChar('$')
384-
}
385-
386377
peek() match {
387378
case Some('{') =>
388379
consumeExpected('{')
389-
val num = consumeInt()
390-
if (peek().contains('}')) {
391-
consumeExpected('}')
392-
num match {
393-
case Some(_) =>
380+
peek() match {
381+
case Some(ch) if isAsciiDigit(ch) =>
382+
throw new RegexUnsupportedException(
383+
"Illegal group reference: group name starts with digit character",
384+
Some(pos))
385+
case Some(ch) if isLetter(ch) =>
386+
val nameStart = pos
387+
while (!eof() && peek().exists(c => isLetter(c) || isAsciiDigit(c))) {
388+
skip()
389+
}
390+
val name = pattern.substring(nameStart, pos)
391+
if (!peek().contains('}')) {
394392
throw new RegexUnsupportedException(
395-
"Numeric `${N}` backref in replacement string is not supported on GPU " +
396-
"(Java's Matcher.appendReplacement rejects this syntax)",
397-
Some(start))
398-
case _ =>
399-
treatAsLiteralDollar()
400-
}
401-
} else {
402-
treatAsLiteralDollar()
403-
}
404-
case Some(ch) if ch >= '1' && ch <= '9' =>
405-
val num = consumeInt()
406-
num match {
407-
case Some(n) =>
408-
RegexBackref(n)
393+
"Illegal group reference: malformed " + "$" + "{name} reference",
394+
Some(pos))
395+
}
396+
consumeExpected('}')
397+
throw new RegexUnsupportedException(
398+
s"named-group reference $${$name} is not supported on the GPU",
399+
Some(pos))
409400
case _ =>
410-
treatAsLiteralDollar()
401+
throw new RegexUnsupportedException(
402+
"Illegal group reference: empty or malformed " + "$" + "{name}",
403+
Some(pos))
411404
}
405+
case Some(ch) if isAsciiDigit(ch) =>
406+
RegexBackref(consumeInt().get)
412407
case _ =>
413-
treatAsLiteralDollar()
408+
throw new RegexUnsupportedException(
409+
"Illegal group reference", Some(pos))
414410
}
415411
}
416412

413+
private def isLetter(ch: Char): Boolean =
414+
(ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z')
415+
416+
private def isAsciiDigit(ch: Char): Boolean = ch >= '0' && ch <= '9'
417+
417418
private def parseEscapedCharacter(): RegexAST = {
418419
peek() match {
419420
case None =>
@@ -620,7 +621,8 @@ class RegexParser(pattern: String) {
620621

621622
private def consumeInt(): Option[Int] = {
622623
val start = pos
623-
while (!eof() && peek().exists(_.isDigit)) {
624+
// ASCII only, like Java's Matcher#appendReplacement; `_.isDigit` also accepts Unicode digits.
625+
while (!eof() && peek().exists(isAsciiDigit)) {
624626
skip()
625627
}
626628
if (start == pos) {

sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,13 +1073,12 @@ object GpuRegExpUtils {
10731073
}
10741074

10751075
/**
1076-
* Convert symbols of back-references if input string contains any.
1077-
* In spark's regex rule, there are two patterns of back-references:
1078-
* \group_index and \$group_index
1079-
* This method transforms above two patterns into cuDF pattern \${group_index}, except they are
1080-
* preceded by escape character.
1076+
* Convert numbered `$group_index` back-reference symbols into cuDF's `${group_index}` form.
1077+
* Java `Matcher#appendReplacement` treats `\digit` as the literal digit character, not as a
1078+
* back-reference. Escaped pairs are kept verbatim here and are later normalized by
1079+
* `unescapeReplaceString`.
10811080
*
1082-
* Java's `Matcher.appendReplacement` reads the digits after `$` or `\` one at a time and
1081+
* Java's `Matcher.appendReplacement` reads the digits after `$` one at a time and
10831082
* stops as soon as adding the next digit would make the running group index exceed the
10841083
* actual capture-group count; any further digits are treated as literal characters
10851084
* (greedy-with-backoff). When `numCaptureGroups` is negative the caller is opting out of the
@@ -1114,8 +1113,7 @@ object GpuRegExpUtils {
11141113
b.append(rep.charAt(i))
11151114
i += 1
11161115
}
1117-
} else if (Seq('$', '\\').contains(rep.charAt(i))
1118-
&& i + 1 < rep.length && rep.charAt(i + 1).isDigit) {
1116+
} else if (rep.charAt(i) == '$' && i + 1 < rep.length && rep.charAt(i + 1).isDigit) {
11191117

11201118
// Consume digits one at a time. If the running group index would exceed the actual
11211119
// capture-group count, stop and leave the remaining digits as literals. When no digit
@@ -1153,8 +1151,7 @@ object GpuRegExpUtils {
11531151
i = k
11541152
}
11551153
} else if (rep.charAt(i) == '\\' && i + 1 < rep.length) {
1156-
// skip potential escape sequences like `\$` or `\\`; `\digit` is handled by the
1157-
// greedy-with-backoff branch above.
1154+
// Keep `\X` pairs verbatim; `unescapeReplaceString` strips the leading backslash.
11581155
b.append('\\').append(rep.charAt(i + 1))
11591156
i += 2
11601157
} else {

tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,97 @@ class RegularExpressionParserSuite extends AnyFunSuite {
283283
RegexChar('$'))))
284284
}
285285

286-
test("replacement: numeric braced backref rejected (Java spec)") {
287-
val brace = "$" + "{"
288-
val cases = Seq(s"[${brace}2}]", s"${brace}1}", s"${brace}12}",
289-
s"a${brace}3}b", s"${brace}0}")
290-
for (rep <- cases) {
286+
test("issue-14742-subbug1: \\N in replacement is the literal character N, not a backref") {
287+
val repl = new RegexParser("\\1").parseReplacement(numCaptureGroups = 1)
288+
assert(repl.parts.toList === List(RegexChar('\\'), RegexChar('1')))
289+
}
290+
291+
test("issue-14742-subbug1: \\a in replacement is the literal character a") {
292+
val repl = new RegexParser("\\a").parseReplacement(numCaptureGroups = 0)
293+
assert(repl.parts.toList === List(RegexChar('\\'), RegexChar('a')))
294+
}
295+
296+
test("issue-14742-subbug2: trailing \\ in replacement throws") {
297+
val ex = intercept[RegexUnsupportedException] {
298+
new RegexParser("\\").parseReplacement(numCaptureGroups = 0)
299+
}
300+
assert(ex.getMessage.contains("character to be escaped is missing"))
301+
}
302+
303+
test("issue-14742-subbug3: bare $X for non-digit X throws") {
304+
val ex = intercept[RegexUnsupportedException] {
305+
new RegexParser("$x").parseReplacement(numCaptureGroups = 0)
306+
}
307+
assert(ex.getMessage.contains("Illegal group reference"))
308+
}
309+
310+
test("issue-14742-subbug3: trailing bare $ throws") {
311+
val ex = intercept[RegexUnsupportedException] {
312+
new RegexParser("$").parseReplacement(numCaptureGroups = 0)
313+
}
314+
assert(ex.getMessage.contains("Illegal group reference"))
315+
}
316+
317+
test("issue-14742-subbug4: dollar-brace-digit-brace throws") {
318+
val ex = intercept[RegexUnsupportedException] {
319+
new RegexParser("$" + "{1}").parseReplacement(numCaptureGroups = 1)
320+
}
321+
assert(ex.getMessage.contains("Illegal group reference"))
322+
assert(ex.getMessage.contains("digit"))
323+
}
324+
325+
test("issue-14742-subbug5: dollar-brace-name-brace for named group is not supported on GPU") {
326+
val ex = intercept[RegexUnsupportedException] {
327+
new RegexParser("$" + "{name}").parseReplacement(numCaptureGroups = 1)
328+
}
329+
assert(ex.getMessage.contains("named-group reference"))
330+
}
331+
332+
test("issue-14742-subbug5: dollar-brace-name with missing closing brace throws") {
333+
val ex = intercept[RegexUnsupportedException] {
334+
new RegexParser("$" + "{name").parseReplacement(numCaptureGroups = 0)
335+
}
336+
assert(ex.getMessage.contains("Illegal group reference"))
337+
}
338+
339+
test("issue-14742: dollar-brace with empty body throws") {
340+
val ex = intercept[RegexUnsupportedException] {
341+
new RegexParser("$" + "{}").parseReplacement(numCaptureGroups = 0)
342+
}
343+
assert(ex.getMessage.contains("Illegal group reference"))
344+
}
345+
346+
test("issue-14742: numbered backref $0 still works") {
347+
val repl = new RegexParser("$0").parseReplacement(numCaptureGroups = 0)
348+
assert(repl.parts.toList === List(RegexBackref(0)))
349+
}
350+
351+
test("issue-14742: numbered backref $1 still works") {
352+
val repl = new RegexParser("$1").parseReplacement(numCaptureGroups = 1)
353+
assert(repl.parts.toList === List(RegexBackref(1)))
354+
}
355+
356+
test("issue-14742: numbered backref $12 still consumes max digits") {
357+
val repl = new RegexParser("$12").parseReplacement(numCaptureGroups = 12)
358+
assert(repl.parts.toList === List(RegexBackref(12)))
359+
}
360+
361+
test("issue-14742: escaped metachar \\$ in replacement keeps the \\ pair") {
362+
val repl = new RegexParser("\\$").parseReplacement(numCaptureGroups = 0)
363+
assert(repl.parts.toList === List(RegexChar('\\'), RegexChar('$')))
364+
}
365+
366+
test("issue-14742: escaped backslash \\\\ in replacement keeps the \\ pair") {
367+
val repl = new RegexParser("\\\\").parseReplacement(numCaptureGroups = 0)
368+
assert(repl.parts.toList === List(RegexChar('\\'), RegexChar('\\')))
369+
}
370+
371+
test("issue-14742: non-ASCII Unicode digit after `$` triggers GPU fallback") {
372+
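// NOTE(review): first and last literals were lost in extraction; restored as `$` followed by an Arabic-Indic and a fullwidth digit -- confirm against the commit.
for (rep <- Seq("$\u0663", "$१", "$\uFF11")) {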
291373
val e = intercept[RegexUnsupportedException] {
292-
new RegexParser(rep).parseReplacement(4)
374+
new RegexParser(rep).parseReplacement(numCaptureGroups = 4)
293375
}
294-
assert(e.getMessage.contains("backref in replacement string is not supported"),
376+
assert(e.getMessage.startsWith("Illegal group reference"),
295377
s"unexpected message for replacement '$rep': ${e.getMessage}")
296378
}
297379
}

tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ class RegExpUtilsSuite extends AnyFunSuite {
239239
(2, "$2", true, open + "2}"),
240240
// 0 groups, "$1": legacy path -- emit ${1} so cuDF surfaces the error.
241241
(0, "$1", true, open + "1}"),
242-
// Same shape with backslash backref.
243-
(2, "\\12", true, open + "1}2"),
242+
// Java replacement strings treat `\digit` as the literal digit, not a backref.
243+
(2, "\\12", false, "\\12"),
244244
// No digits after `$` -- literal `$`.
245245
(2, "$a", false, "$a"),
246246
// `$0` is the whole-match backref and is always valid (cuDF supports group 0).

0 commit comments

Comments
 (0)