Skip to content

ESQL: Limit Replace function memory usage #127924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127924.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127924
summary: Limit Replace function memory usage
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.List;

import static java.util.Collections.emptyList;
import static org.elasticsearch.common.unit.ByteSizeUnit.MB;

/**
* A {@code ScalarFunction} is a {@code Function} that takes values from some
Expand All @@ -22,6 +23,14 @@
*/
public abstract class ScalarFunction extends Function {

/**
* Maximum size, in bytes, of the {@code BytesRef} result a function may return.
* <p>
* To be used when there's no CircuitBreaking, as an arbitrary measure to limit memory usage.
* </p>
*/
public static final long MAX_BYTES_REF_RESULT_SIZE = MB.toBytes(1);

protected ScalarFunction(Source source) {
super(source, emptyList());
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import java.util.Arrays;
import java.util.List;

import static org.elasticsearch.common.unit.ByteSizeUnit.MB;
import static org.elasticsearch.compute.ann.Fixed.Scope.THREAD_LOCAL;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
Expand All @@ -40,8 +39,6 @@
public class Repeat extends EsqlScalarFunction implements OptionalArgument {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Repeat", Repeat::new);

static final long MAX_REPEATED_LENGTH = MB.toBytes(1);

private final Expression str;
private final Expression number;

Expand Down Expand Up @@ -123,9 +120,9 @@ static BytesRef process(

static BytesRef processInner(BreakingBytesRefBuilder scratch, BytesRef str, int number) {
int repeatedLen = str.length * number;
if (repeatedLen > MAX_REPEATED_LENGTH) {
if (repeatedLen > MAX_BYTES_REF_RESULT_SIZE) {
throw new IllegalArgumentException(
"Creating repeated strings with more than [" + MAX_REPEATED_LENGTH + "] bytes is not supported"
"Creating repeated strings with more than [" + MAX_BYTES_REF_RESULT_SIZE + "] bytes is not supported"
);
}
scratch.grow(repeatedLen);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;

Expand Down Expand Up @@ -121,24 +122,63 @@ public boolean foldable() {
return str.foldable() && regex.foldable() && newStr.foldable();
}

@Evaluator(extraName = "Constant", warnExceptions = PatternSyntaxException.class)
@Evaluator(extraName = "Constant", warnExceptions = IllegalArgumentException.class)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PatternSyntaxException is a subclass of IllegalArgumentException, so it is still covered by this annotation. If we choose a different exception type for the size limit in this PR, we will have to restore PatternSyntaxException here.

static BytesRef process(BytesRef str, @Fixed Pattern regex, BytesRef newStr) {
if (str == null || regex == null || newStr == null) {
return null;
}
return new BytesRef(regex.matcher(str.utf8ToString()).replaceAll(newStr.utf8ToString()));
return safeReplace(str, regex, newStr);
}

@Evaluator(warnExceptions = PatternSyntaxException.class)
@Evaluator(warnExceptions = IllegalArgumentException.class)
static BytesRef process(BytesRef str, BytesRef regex, BytesRef newStr) {
if (str == null) {
return null;
}

if (regex == null || newStr == null) {
return str;
}
return new BytesRef(str.utf8ToString().replaceAll(regex.utf8ToString(), newStr.utf8ToString()));
return safeReplace(str, Pattern.compile(regex.utf8ToString()), newStr);
}

/**
 * Executes a Replace without surpassing the memory limit.
 * <p>
 * Works like {@link Matcher#replaceAll(String)}, but before each replacement is appended it
 * estimates an upper bound for the length of the final result and aborts if that bound exceeds
 * {@code MAX_BYTES_REF_RESULT_SIZE}. The estimate assumes every {@code $} in the replacement is
 * a group reference that expands to, at most, the whole current match.
 * </p>
 * <p>
 * NOTE(review): lengths are counted in UTF-16 chars of the decoded strings, while the limit is
 * expressed in bytes; non-ASCII results may encode to more bytes than estimated - confirm this
 * is acceptable.
 * </p>
 *
 * @param strBytesRef    the text to search; returned as-is when {@code regex} never matches
 * @param regex          the compiled pattern to look for
 * @param newStrBytesRef the replacement, which may contain group references such as {@code $1}
 * @return the text with every match replaced
 * @throws IllegalArgumentException if the estimated result length exceeds the limit
 */
private static BytesRef safeReplace(BytesRef strBytesRef, Pattern regex, BytesRef newStrBytesRef) {
    String str = strBytesRef.utf8ToString();
    Matcher m = regex.matcher(str);
    if (false == m.find()) {
        return strBytesRef;
    }
    String newStr = newStrBytesRef.utf8ToString();

    // Count potential groups (E.g. "$1") used in the replacement.
    // Kept as longs so that the size estimation below cannot overflow.
    long constantReplacementLength = newStr.length();
    long groupsInReplacement = 0;
    for (int i = 0; i < newStr.length(); i++) {
        if (newStr.charAt(i) == '$') {
            groupsInReplacement++;
            // Discount the '$' and the character following it from the literal part
            constantReplacementLength -= 2;
            i++;
        }
    }

    // Initialize the buffer with an approximate size for the first replacement
    StringBuilder result = new StringBuilder(str.length() + newStr.length() + 8);
    // Index in str up to which the input was already consumed by appendReplacement
    int appendPosition = 0;
    do {
        long matchSize = m.end() - m.start();
        // appendReplacement also copies the unmatched text preceding the match
        long unmatchedPrefix = m.start() - appendPosition;
        // Each group can be, at most, as long as the whole match
        long potentialReplacementSize = constantReplacementLength + groupsInReplacement * matchSize;
        long remainingStr = str.length() - m.end();
        if (result.length() + unmatchedPrefix + potentialReplacementSize + remainingStr > MAX_BYTES_REF_RESULT_SIZE) {
            throw new IllegalArgumentException(
                "Creating strings with more than [" + MAX_BYTES_REF_RESULT_SIZE + "] bytes is not supported"
            );
        }

        m.appendReplacement(result, newStr);
        appendPosition = m.end();
    } while (m.find());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing this with the implementation of replaceAll, this seems to be equivalent, so I think we don't accidentally change the semantics. Nice.

    m.appendTail(result);
    return new BytesRef(result.toString());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ public final void testEvaluateInManyThreads() throws ExecutionException, Interru
if (testCase.getExpectedBuildEvaluatorWarnings() != null) {
assertWarnings(testCase.getExpectedBuildEvaluatorWarnings());
}

ExecutorService exec = Executors.newFixedThreadPool(threads);
try {
List<Future<?>> futures = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.compute.test.TestBlockFactory;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.EsField;
Expand All @@ -45,14 +46,14 @@ public class RepeatStaticTests extends ESTestCase {

public void testAlmostTooBig() {
String str = randomAlphaOfLength(1);
int number = (int) Repeat.MAX_REPEATED_LENGTH;
int number = (int) ScalarFunction.MAX_BYTES_REF_RESULT_SIZE;
String repeated = process(str, number);
assertThat(repeated, equalTo(str.repeat(number)));
}

public void testTooBig() {
String str = randomAlphaOfLength(1);
int number = (int) Repeat.MAX_REPEATED_LENGTH + 1;
int number = (int) ScalarFunction.MAX_BYTES_REF_RESULT_SIZE + 1;
String repeated = process(str, number);
assertNull(repeated);
assertWarnings(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.scalar.string;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.test.TestBlockFactory;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
import org.junit.After;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;

/**
 * Tests for {@link Replace} whose inputs or outputs are around 1MB in size.
 * <p>
 * They live outside of the classes extending AbstractScalarFunctionTestCase because
 * those rerun every case with many randomized inputs, and tests run with limited
 * memory: instantiating many copies of such large rows causes out of memory.
 * </p>
 */
public class ReplaceStaticTests extends ESTestCase {

    /** The result size limit, narrowed to an {@code int} to be usable as a string length. */
    private static final int LIMIT = (int) ScalarFunction.MAX_BYTES_REF_RESULT_SIZE;

    /**
     * The following fields and methods were borrowed from AbstractScalarFunctionTestCase
     */
    private final List<CircuitBreaker> breakers = Collections.synchronizedList(new ArrayList<>());

    /** A replacement that expands to exactly the limit is accepted. */
    public void testLimit() {
        int textLength = LIMIT / 10;
        String text = randomAlphaOfLength(textLength);
        String regex = "^(.+)$";

        // 10 times the original text + the remainder
        String extraString = "a".repeat(LIMIT % 10);
        assert textLength * 10 + extraString.length() == LIMIT;
        String newStr = "$0$0$0$0$0$0$0$0$0$0" + extraString;

        String expected = newStr.replaceAll("\\$\\d", text);
        assertThat(process(text, regex, newStr), equalTo(expected));
    }

    /** Replacing every character with the whole text yields null and a warning. */
    public void testTooBig() {
        String textAndNewStr = randomAlphaOfLength(LIMIT / 10);
        String regex = ".";

        assertNull(process(textAndNewStr, regex, textAndNewStr));
        assertTooBigWarnings();
    }

    /** A replacement made of groups that expands to one more than the limit yields null and a warning. */
    public void testTooBigWithGroups() {
        int textLength = LIMIT / 10;
        String text = randomAlphaOfLength(textLength);
        String regex = "(.+)";

        // 10 times the original text + the remainder + 1
        String extraString = "a".repeat(1 + LIMIT % 10);
        assert textLength * 10 + extraString.length() == LIMIT + 1;
        String newStr = "$0$1$0$1$0$1$0$1$0$1" + extraString;

        assertNull(process(text, regex, newStr));
        assertTooBigWarnings();
    }

    /**
     * Evaluates a {@link Replace} over a single row holding the given values.
     *
     * @return the result as a string, or {@code null} if the resulting block is null at position 0
     */
    public String process(String text, String regex, String newStr) {
        Replace replace = new Replace(
            Source.EMPTY,
            field("text", DataType.KEYWORD),
            field("regex", DataType.KEYWORD),
            field("newStr", DataType.KEYWORD)
        );
        try (
            var eval = AbstractScalarFunctionTestCase.evaluator(replace).get(driverContext());
            Block block = eval.eval(row(List.of(new BytesRef(text), new BytesRef(regex), new BytesRef(newStr))))
        ) {
            if (block.isNull(0)) {
                return null;
            }
            BytesRef value = (BytesRef) BlockUtils.toJavaObject(block, 0);
            return value.utf8ToString();
        }
    }

    /** Asserts the warnings emitted when the size check rejects a replacement. */
    private void assertTooBigWarnings() {
        assertWarnings(
            "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.",
            "Line -1:-1: java.lang.IllegalArgumentException: "
                + "Creating strings with more than ["
                + ScalarFunction.MAX_BYTES_REF_RESULT_SIZE
                + "] bytes is not supported"
        );
    }

    /** Wraps the values into a single-position page. */
    private static Page row(List<Object> values) {
        Block[] blocks = BlockUtils.fromListRow(TestBlockFactory.getNonBreakingInstance(), values);
        return new Page(1, blocks);
    }

    /** Builds a field attribute with the given name and type. */
    private static FieldAttribute field(String name, DataType type) {
        EsField esField = new EsField(name, type, Map.of(), true);
        return new FieldAttribute(Source.synthetic(name), name, esField);
    }

    /** Creates a driver context and records its breaker so it can be checked after the test. */
    private DriverContext driverContext() {
        BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
        CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
        breakers.add(breaker);
        BlockFactory blockFactory = new BlockFactory(breaker, bigArrays);
        return new DriverContext(bigArrays, blockFactory);
    }

    /** Checks that every breaker handed out by {@link #driverContext()} is back to zero used bytes. */
    @After
    public void allMemoryReleased() {
        breakers.forEach(breaker -> assertThat(breaker.getUsed(), equalTo(0L)));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,30 @@ public static Iterable<Object[]> parameters() {
)
);

// Groups
suppliers.add(fixedCase("Full group", "Cats are awesome", ".+", "<$0>", "<Cats are awesome>"));
suppliers.add(
fixedCase("Nested groups", "A cat is great, a cat is awesome", "\\b([Aa] (\\w+)) is (\\w+)\\b", "$1$2", "A catcat, a catcat")
);
suppliers.add(
fixedCase(
"Multiple groups",
"Cats are awesome",
"(\\w+) (.+)",
"$0 -> $1 and dogs $2",
"Cats are awesome -> Cats and dogs are awesome"
)
);

// Errors
suppliers.add(new TestCaseSupplier("syntax error", List.of(DataType.KEYWORD, DataType.KEYWORD, DataType.KEYWORD), () -> {
String text = randomAlphaOfLength(10);
String invalidRegex = "[";
String newStr = randomAlphaOfLength(5);
return new TestCaseSupplier.TestCase(
List.of(
new TestCaseSupplier.TypedData(new BytesRef(text), DataType.KEYWORD, "str"),
new TestCaseSupplier.TypedData(new BytesRef(invalidRegex), DataType.KEYWORD, "oldStr"),
new TestCaseSupplier.TypedData(new BytesRef(invalidRegex), DataType.KEYWORD, "regex"),
new TestCaseSupplier.TypedData(new BytesRef(newStr), DataType.KEYWORD, "newStr")
),
"ReplaceEvaluator[str=Attribute[channel=0], regex=Attribute[channel=1], newStr=Attribute[channel=2]]",
Expand Down
Loading