
Commit 67293ba

ESQL: Speed up VALUES for many buckets (#123073)
Speeds up the VALUES agg when collecting from many buckets. Specifically, this speeds up the algorithm used to `finish` the aggregation. Most specifically, this makes the algorithm more tolerant of large numbers of groups being collected. The old algorithm was `O(n^2)` in the number of groups; the new one is `O(n)`:

```
(groups)
     1      219.683 ±    1.069 ->    223.477 ±    1.990 ms/op
  1000      426.323 ±   75.963 ->    463.670 ±    7.275 ms/op
100000    36690.871 ± 4656.350 ->   7800.332 ± 2775.869 ms/op
200000    89422.113 ± 2972.606 ->  21920.288 ± 3427.962 ms/op
400000    timed out at 10 minutes -> 40051.524 ± 2011.706 ms/op
```

The `1` group case was not changed at all; the difference there is just measurement noise. The small bump in the `1000` case is real, but almost certainly worth it. The huge drop in the `100000` case is quite real.
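To make the complexity claim concrete, here is a minimal, self-contained sketch of the shape of such a change. It is not the actual VALUES aggregator code: it assumes the aggregator has buffered a flat list of (group, value) pairs before `finish`, and the `Pair`, `finishQuadratic`, and `finishLinear` names are illustrative only.

```
import java.util.ArrayList;
import java.util.List;

class ValuesFinishSketch {
    record Pair(int group, long value) {}

    // Old shape: rescan every buffered pair once per group -> O(groups * pairs),
    // which is quadratic when the number of pairs grows with the number of groups.
    static List<List<Long>> finishQuadratic(List<Pair> pairs, int groupCount) {
        List<List<Long>> out = new ArrayList<>(groupCount);
        for (int group = 0; group < groupCount; group++) {
            List<Long> values = new ArrayList<>();
            for (Pair pair : pairs) {
                if (pair.group() == group) {
                    values.add(pair.value());
                }
            }
            out.add(values);
        }
        return out;
    }

    // New shape: bucket every pair by its group in a single pass -> O(pairs + groups).
    static List<List<Long>> finishLinear(List<Pair> pairs, int groupCount) {
        List<List<Long>> out = new ArrayList<>(groupCount);
        for (int group = 0; group < groupCount; group++) {
            out.add(new ArrayList<>());
        }
        for (Pair pair : pairs) {
            out.get(pair.group()).add(pair.value());
        }
        return out;
    }
}
```

Both produce the same per-group value lists; only the second avoids rescanning the whole buffer for every group, which is why the gap in the table above widens as the group count grows.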
1 parent 5b90305

9 files changed, +1031 −177 lines

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java (+3)

```
@@ -60,6 +60,9 @@
 import java.util.stream.LongStream;
 import java.util.stream.Stream;

+/**
+ * Benchmark for many different kinds of aggregator and groupings.
+ */
 @Warmup(iterations = 5)
 @Measurement(iterations = 7)
 @BenchmarkMode(Mode.AverageTime)
```
benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java (new file, +339)

```
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.benchmark.compute.operator;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

/**
 * Benchmark for the {@code VALUES} aggregator that supports grouping by many many
 * many values.
 */
@Warmup(iterations = 5)
@Measurement(iterations = 7)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Fork(1)
public class ValuesAggregatorBenchmark {
    static final int MIN_BLOCK_LENGTH = 8 * 1024;
    private static final int OP_COUNT = 1024;
    private static final int UNIQUE_VALUES = 6;
    private static final BytesRef[] KEYWORDS = new BytesRef[] {
        new BytesRef("Tokyo"),
        new BytesRef("Delhi"),
        new BytesRef("Shanghai"),
        new BytesRef("São Paulo"),
        new BytesRef("Mexico City"),
        new BytesRef("Cairo") };
    static {
        assert KEYWORDS.length == UNIQUE_VALUES;
    }

    private static final BlockFactory blockFactory = BlockFactory.getInstance(
        new NoopCircuitBreaker("noop"),
        BigArrays.NON_RECYCLING_INSTANCE // TODO real big arrays?
    );

    static {
        // Smoke test all the expected values and force loading subclasses more like prod
        try {
            for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) {
                for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) {
                    run(Integer.parseInt(groups), dataType, 10);
                }
            }
        } catch (NoSuchFieldException e) {
            throw new AssertionError();
        }
    }

    private static final String BYTES_REF = "BytesRef";
    private static final String INT = "int";
    private static final String LONG = "long";

    @Param({ "1", "1000", /*"1000000"*/ })
    public int groups;

    @Param({ BYTES_REF, INT, LONG })
    public String dataType;

    private static Operator operator(DriverContext driverContext, int groups, String dataType) {
        if (groups == 1) {
            return new AggregationOperator(
                List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
                driverContext
            );
        }
        List<BlockHash.GroupSpec> groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
        return new HashAggregationOperator(
            List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
            () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
            driverContext
        );
    }

    private static AggregatorFunctionSupplier supplier(String dataType) {
        return switch (dataType) {
            case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier();
            case INT -> new ValuesIntAggregatorFunctionSupplier();
            case LONG -> new ValuesLongAggregatorFunctionSupplier();
            default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
        };
    }

    private static void checkExpected(int groups, String dataType, Page page) {
        String prefix = String.format("[%s][%s]", groups, dataType);
        int positions = page.getPositionCount();
        if (positions != groups) {
            throw new IllegalArgumentException(prefix + " expected " + groups + " positions, got " + positions);
        }
        if (groups == 1) {
            checkUngrouped(prefix, dataType, page);
            return;
        }
        checkGrouped(prefix, groups, dataType, page);
    }

    private static void checkGrouped(String prefix, int groups, String dataType, Page page) {
        LongVector groupsVector = page.<LongBlock>getBlock(0).asVector();
        for (int p = 0; p < groups; p++) {
            long group = groupsVector.getLong(p);
            if (group != p) {
                throw new IllegalArgumentException(prefix + "[" + p + "] expected group " + p + " but was " + groups);
            }
        }
        switch (dataType) {
            case BYTES_REF -> {
                // Build the expected values
                List<Set<BytesRef>> expected = new ArrayList<>(groups);
                for (int g = 0; g < groups; g++) {
                    expected.add(new HashSet<>(KEYWORDS.length));
                }
                int blockLength = blockLength(groups);
                for (int p = 0; p < blockLength; p++) {
                    expected.get(p % groups).add(KEYWORDS[p % KEYWORDS.length]);
                }

                // Check them
                BytesRefBlock values = page.getBlock(1);
                for (int p = 0; p < groups; p++) {
                    checkExpectedBytesRef(prefix, values, p, expected.get(p));
                }
            }
            case INT -> {
                // Build the expected values
                List<Set<Integer>> expected = new ArrayList<>(groups);
                for (int g = 0; g < groups; g++) {
                    expected.add(new HashSet<>(UNIQUE_VALUES));
                }
                int blockLength = blockLength(groups);
                for (int p = 0; p < blockLength; p++) {
                    expected.get(p % groups).add(p % KEYWORDS.length);
                }

                // Check them
                IntBlock values = page.getBlock(1);
                for (int p = 0; p < groups; p++) {
                    checkExpectedInt(prefix, values, p, expected.get(p));
                }
            }
            case LONG -> {
                // Build the expected values
                List<Set<Long>> expected = new ArrayList<>(groups);
                for (int g = 0; g < groups; g++) {
                    expected.add(new HashSet<>(UNIQUE_VALUES));
                }
                int blockLength = blockLength(groups);
                for (int p = 0; p < blockLength; p++) {
                    expected.get(p % groups).add((long) p % KEYWORDS.length);
                }

                // Check them
                LongBlock values = page.getBlock(1);
                for (int p = 0; p < groups; p++) {
                    checkExpectedLong(prefix, values, p, expected.get(p));
                }
            }
            default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
        }
    }

    private static void checkUngrouped(String prefix, String dataType, Page page) {
        switch (dataType) {
            case BYTES_REF -> {
                BytesRefBlock values = page.getBlock(0);
                checkExpectedBytesRef(prefix, values, 0, Set.of(KEYWORDS));
            }
            case INT -> {
                IntBlock values = page.getBlock(0);
                checkExpectedInt(prefix, values, 0, IntStream.range(0, UNIQUE_VALUES).boxed().collect(Collectors.toSet()));
            }
            case LONG -> {
                LongBlock values = page.getBlock(0);
                checkExpectedLong(prefix, values, 0, LongStream.range(0, UNIQUE_VALUES).boxed().collect(Collectors.toSet()));
            }
            default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
        }
    }

    private static int checkExpectedBlock(String prefix, Block values, int position, Set<?> expected) {
        int valueCount = values.getValueCount(position);
        if (valueCount != expected.size()) {
            throw new IllegalArgumentException(
                prefix + "[" + position + "] expected " + expected.size() + " values but count was " + valueCount
            );
        }
        return valueCount;
    }

    private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, int position, Set<BytesRef> expected) {
        int valueCount = checkExpectedBlock(prefix, values, position, expected);
        BytesRef scratch = new BytesRef();
        for (int i = values.getFirstValueIndex(position); i < valueCount; i++) {
            BytesRef v = values.getBytesRef(i, scratch);
            if (expected.contains(v) == false) {
                throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
            }
        }
    }

    private static void checkExpectedInt(String prefix, IntBlock values, int position, Set<Integer> expected) {
        int valueCount = checkExpectedBlock(prefix, values, position, expected);
        for (int i = values.getFirstValueIndex(position); i < valueCount; i++) {
            Integer v = values.getInt(i);
            if (expected.contains(v) == false) {
                throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
            }
        }
    }

    private static void checkExpectedLong(String prefix, LongBlock values, int position, Set<Long> expected) {
        int valueCount = checkExpectedBlock(prefix, values, position, expected);
        for (int i = values.getFirstValueIndex(position); i < valueCount; i++) {
            Long v = values.getLong(i);
            if (expected.contains(v) == false) {
                throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
            }
        }
    }

    private static Page page(int groups, String dataType) {
        Block dataBlock = dataBlock(groups, dataType);
        if (groups == 1) {
            return new Page(dataBlock);
        }
        return new Page(groupingBlock(groups), dataBlock);
    }

    private static Block dataBlock(int groups, String dataType) {
        int blockLength = blockLength(groups);
        return switch (dataType) {
            case BYTES_REF -> {
                try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) {
                    for (int i = 0; i < blockLength; i++) {
                        builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]);
                    }
                    yield builder.build();
                }
            }
            case INT -> {
                try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(blockLength)) {
                    for (int i = 0; i < blockLength; i++) {
                        builder.appendInt(i % UNIQUE_VALUES);
                    }
                    yield builder.build();
                }
            }
            case LONG -> {
                try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(blockLength)) {
                    for (int i = 0; i < blockLength; i++) {
                        builder.appendLong(i % UNIQUE_VALUES);
                    }
                    yield builder.build();
                }
            }
            default -> throw new IllegalArgumentException("unsupported data type " + dataType);
        };
    }

    private static Block groupingBlock(int groups) {
        int blockLength = blockLength(groups);
        try (LongVector.Builder builder = blockFactory.newLongVectorBuilder(blockLength)) {
            for (int i = 0; i < blockLength; i++) {
                builder.appendLong(i % groups);
            }
            return builder.build().asBlock();
        }
    }

    @Benchmark
    public void run() {
        run(groups, dataType, OP_COUNT);
    }

    private static void run(int groups, String dataType, int opCount) {
        DriverContext driverContext = driverContext();
        try (Operator operator = operator(driverContext, groups, dataType)) {
            Page page = page(groups, dataType);
            for (int i = 0; i < opCount; i++) {
                operator.addInput(page.shallowCopy());
            }
            operator.finish();
            checkExpected(groups, dataType, operator.getOutput());
        }
    }

    static DriverContext driverContext() {
        return new DriverContext(BigArrays.NON_RECYCLING_INSTANCE, blockFactory);
    }

    static int blockLength(int groups) {
        return Math.max(MIN_BLOCK_LENGTH, groups);
    }
}
```

docs/changelog/123073.yaml (new file, +5)

```
pr: 123073
summary: Speed up VALUES for many buckets
area: ES|QL
type: bug
issues: []
```
